Merge pull request #6 from bmaltais/dev

v17.1
This commit is contained in:
bmaltais 2022-12-17 11:53:33 -05:00 committed by GitHub
commit b946be390d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 492 additions and 663 deletions

311
README.md
View File

@ -77,6 +77,20 @@ pip install --upgrade -r requirements.txt
Once the commands have completed successfully you should be ready to use the new version.
## GUI
There is now support for GUI based training using gradio. You can start the GUI interface by running:
```powershell
python .\dreambooth_gui.py
```
## Quickstart screencast
You can find a screen cast on how to use the GUI at the following location:
https://youtu.be/RlvqEKj03WI
## Folders configuration
Refer to the note to understand how to create the folde structure. In short it should look like:
@ -106,305 +120,16 @@ my_asd_dog_dreambooth
`- dog8.png
```
## GUI
There is now support for GUI based training using gradio. You can start the GUI interface by running:
```powershell
python .\dreambooth_gui.py
```
## Support
Drop by the discord server for support: https://discord.com/channels/1041518562487058594/1041518563242020906
## Manual Script Execution
### SD1.5 example
Edit and paste the following in a Powershell terminal:
```powershell
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed.py `
--pretrained_model_name_or_path="D:\models\last.ckpt" `
--train_data_dir="D:\dreambooth\train_bernard\train_man" `
--reg_data_dir="D:\dreambooth\train_bernard\reg_man" `
--output_dir="D:\dreambooth\train_bernard" `
--prior_loss_weight=1.0 `
--resolution=512 `
--train_batch_size=1 `
--learning_rate=1e-6 `
--max_train_steps=2100 `
--use_8bit_adam `
--xformers `
--mixed_precision="fp16" `
--cache_latents `
--gradient_checkpointing `
--save_every_n_epochs=1
```
### SD2.0 512 Base example
```powershell
# variable values
$pretrained_model_name_or_path = "D:\models\512-base-ema.ckpt"
$data_dir = "D:\models\dariusz_zawadzki\kohya_reg\data"
$reg_data_dir = "D:\models\dariusz_zawadzki\kohya_reg\reg"
$logging_dir = "D:\models\dariusz_zawadzki\logs"
$output_dir = "D:\models\dariusz_zawadzki\train_db_fixed_model_reg_v2"
$resolution = "512,512"
$lr_scheduler="polynomial"
$cache_latents = 1 # 1 = true, 0 = false
$image_num = Get-ChildItem $data_dir -Recurse -File -Include *.png, *.jpg, *.webp | Measure-Object | %{$_.Count}
Write-Output "image_num: $image_num"
$dataset_repeats = 200
$learning_rate = 2e-6
$train_batch_size = 4
$epoch = 1
$save_every_n_epochs=1
$mixed_precision="bf16"
$num_cpu_threads_per_process=6
# You should not have to change values past this point
if ($cache_latents -eq 1) {
$cache_latents_value="--cache_latents"
}
else {
$cache_latents_value=""
}
$repeats = $image_num * $dataset_repeats
$mts = [Math]::Ceiling($repeats / $train_batch_size * $epoch)
Write-Output "Repeats: $repeats"
cd D:\kohya_ss
.\venv\Scripts\activate
accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process train_db_fixed.py `
--v2 `
--pretrained_model_name_or_path=$pretrained_model_name_or_path `
--train_data_dir=$data_dir `
--output_dir=$output_dir `
--resolution=$resolution `
--train_batch_size=$train_batch_size `
--learning_rate=$learning_rate `
--max_train_steps=$mts `
--use_8bit_adam `
--xformers `
--mixed_precision=$mixed_precision `
$cache_latents_value `
--save_every_n_epochs=$save_every_n_epochs `
--logging_dir=$logging_dir `
--save_precision="fp16" `
--reg_data_dir=$reg_data_dir `
--seed=494481440 `
--lr_scheduler=$lr_scheduler
# Add the inference yaml file along with the model for proper loading. Need to have the same name as model... Most likelly "last.yaml" in our case.
cp v2_inference\v2-inference.yaml $output_dir"\last.yaml"
```
### SD2.0 768v Base example
```powershell
# variable values
$pretrained_model_name_or_path = "C:\Users\berna\Downloads\768-v-ema.ckpt"
$data_dir = "D:\dreambooth\train_paper_artwork\kohya\data"
$logging_dir = "D:\dreambooth\train_paper_artwork"
$output_dir = "D:\models\paper_artwork\train_db_fixed_model_v2_768v"
$resolution = "768,768"
$lr_scheduler="polynomial"
$cache_latents = 1 # 1 = true, 0 = false
$image_num = Get-ChildItem $data_dir -Recurse -File -Include *.png, *.jpg, *.webp | Measure-Object | %{$_.Count}
Write-Output "image_num: $image_num"
$dataset_repeats = 200
$learning_rate = 2e-6
$train_batch_size = 4
$epoch = 1
$save_every_n_epochs=1
$mixed_precision="bf16"
$num_cpu_threads_per_process=6
# You should not have to change values past this point
if ($cache_latents -eq 1) {
$cache_latents_value="--cache_latents"
}
else {
$cache_latents_value=""
}
$repeats = $image_num * $dataset_repeats
$mts = [Math]::Ceiling($repeats / $train_batch_size * $epoch)
Write-Output "Repeats: $repeats"
cd D:\kohya_ss
.\venv\Scripts\activate
accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process train_db_fixed.py `
--v2 `
--v_parameterization `
--pretrained_model_name_or_path=$pretrained_model_name_or_path `
--train_data_dir=$data_dir `
--output_dir=$output_dir `
--resolution=$resolution `
--train_batch_size=$train_batch_size `
--learning_rate=$learning_rate `
--max_train_steps=$mts `
--use_8bit_adam `
--xformers `
--mixed_precision=$mixed_precision `
$cache_latents_value `
--save_every_n_epochs=$save_every_n_epochs `
--logging_dir=$logging_dir `
--save_precision="fp16" `
--seed=494481440 `
--lr_scheduler=$lr_scheduler
# Add the inference 768v yaml file along with the model for proper loading. Need to have the same name as model... Most likelly "last.yaml" in our case.
cp v2_inference\v2-inference-v.yaml $output_dir"\last.yaml"
```
## Finetuning
If you would rather use model finetuning rather than the dreambooth method you can use a command similat to the following. The advantage of fine tuning is that you do not need to worry about regularization images... but you need to provide captions for every images. The caption will be used to train the model. You can use auto1111 to preprocess your training images and add either BLIP or danbooru captions to them. You then need to edit those to add the name of the model and correct any wrong description.
```
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed-ber.py `
--pretrained_model_name_or_path="D:\models\alexandrine_teissier_and_bernard_maltais-400-kohya-sd15-v1.ckpt" `
--train_data_dir="D:\dreambooth\source\alet_et_bernard\landscape-pp" `
--output_dir="D:\dreambooth\train_alex_and_bernard" `
--resolution="640,448" `
--train_batch_size=1 `
--learning_rate=1e-6 `
--max_train_steps=550 `
--use_8bit_adam `
--xformers `
--mixed_precision="fp16" `
--cache_latents `
--save_every_n_epochs=1 `
--fine_tuning `
--enable_bucket `
--dataset_repeats=200 `
--seed=23 `
---save_precision="fp16"
```
Refer to this url for more details about finetuning: https://note.com/kohya_ss/n/n1269f1e1a54e
## Options list
```txt
usage: train_db_fixed.py [-h] [--v2] [--v_parameterization] [--pretrained_model_name_or_path PRETRAINED_MODEL_NAME_OR_PATH]
[--fine_tuning] [--shuffle_caption] [--caption_extention CAPTION_EXTENTION]
[--caption_extension CAPTION_EXTENSION] [--train_data_dir TRAIN_DATA_DIR]
[--reg_data_dir REG_DATA_DIR] [--dataset_repeats DATASET_REPEATS] [--output_dir OUTPUT_DIR]
[--use_safetensors] [--save_every_n_epochs SAVE_EVERY_N_EPOCHS] [--save_state] [--resume RESUME]
[--prior_loss_weight PRIOR_LOSS_WEIGHT] [--no_token_padding]
[--stop_text_encoder_training STOP_TEXT_ENCODER_TRAINING] [--color_aug] [--flip_aug]
[--face_crop_aug_range FACE_CROP_AUG_RANGE] [--random_crop] [--debug_dataset]
[--resolution RESOLUTION] [--train_batch_size TRAIN_BATCH_SIZE] [--use_8bit_adam] [--mem_eff_attn]
[--xformers] [--vae VAE] [--cache_latents] [--enable_bucket] [--min_bucket_reso MIN_BUCKET_RESO]
[--max_bucket_reso MAX_BUCKET_RESO] [--learning_rate LEARNING_RATE]
[--max_train_steps MAX_TRAIN_STEPS] [--seed SEED] [--gradient_checkpointing]
[--mixed_precision {no,fp16,bf16}] [--full_fp16] [--save_precision {None,float,fp16,bf16}]
[--clip_skip CLIP_SKIP] [--logging_dir LOGGING_DIR] [--log_prefix LOG_PREFIX]
[--lr_scheduler LR_SCHEDULER] [--lr_warmup_steps LR_WARMUP_STEPS]
options:
-h, --help show this help message and exit
--v2 load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む
--v_parameterization enable v-parameterization training / v-parameterization学習を有効にする
--pretrained_model_name_or_path PRETRAINED_MODEL_NAME_OR_PATH
pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint /
学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル
--fine_tuning fine tune the model instead of DreamBooth / DreamBoothではなくfine tuningする
--shuffle_caption shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする
--caption_extention CAPTION_EXTENTION
extension of caption files (backward compatiblity) / 読み込むcaptionファイルの拡張子スペルミスを残してあります
--caption_extension CAPTION_EXTENSION
extension of caption files / 読み込むcaptionファイルの拡張子
--train_data_dir TRAIN_DATA_DIR
directory for train images / 学習画像データのディレクトリ
--reg_data_dir REG_DATA_DIR
directory for regularization images / 正則化画像データのディレクトリ
--dataset_repeats DATASET_REPEATS
repeat dataset in fine tuning / fine tuning時にデータセットを繰り返す回数
--output_dir OUTPUT_DIR
directory to output trained model / 学習後のモデル出力先ディレクトリ
--use_safetensors use safetensors format to save / checkpoint、モデルをsafetensors形式で保存する
--save_every_n_epochs SAVE_EVERY_N_EPOCHS
save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する
--save_state save training state additionally (including optimizer states etc.) /
optimizerなど学習状態も含めたstateを追加で保存する
--resume RESUME saved state to resume training / 学習再開するモデルのstate
--prior_loss_weight PRIOR_LOSS_WEIGHT
loss weight for regularization images / 正則化画像のlossの重み
--no_token_padding disable token padding (same as Diffuser's DreamBooth) /
トークンのpaddingを無効にするDiffusers版DreamBoothと同じ動作
--stop_text_encoder_training STOP_TEXT_ENCODER_TRAINING
steps to stop text encoder training / Text Encoderの学習を止めるステップ数
--color_aug enable weak color augmentation / 学習時に色合いのaugmentationを有効にする
--flip_aug enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする
--face_crop_aug_range FACE_CROP_AUG_RANGE
enable face-centered crop augmentation and its range (e.g. 2.0,4.0) /
学習時に顔を中心とした切り出しaugmentationを有効にするときは倍率を指定する2.0,4.0
--random_crop enable random crop (for style training in face-centered crop augmentation) /
ランダムな切り出しを有効にする顔を中心としたaugmentationを行うときに画風の学習用に指定する
--debug_dataset show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)
--resolution RESOLUTION
resolution in training ('size' or 'width,height') / 学習時の画像解像度('サイズ'指定、または'幅,高さ' 指定)
--train_batch_size TRAIN_BATCH_SIZE
batch size for training (1 means one train or reg data, not train/reg pair) /
学習時のバッチサイズ1でtrain/regをそれぞれ1件ずつ学習
--use_8bit_adam use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使うbitsandbytesのインス トールが必要)
--mem_eff_attn use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う
--xformers use xformers for CrossAttention / CrossAttentionにxformersを使う
--vae VAE path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ
--cache_latents cache latents to reduce memory (augmentations must be disabled) /
メモリ削減のためにlatentをcacheするaugmentationは使用不可
--enable_bucket enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする
--min_bucket_reso MIN_BUCKET_RESO
minimum resolution for buckets / bucketの最小解像度
--max_bucket_reso MAX_BUCKET_RESO
maximum resolution for buckets / bucketの最小解像度
--learning_rate LEARNING_RATE
learning rate / 学習率
--max_train_steps MAX_TRAIN_STEPS
training steps / 学習ステップ数
--seed SEED random seed for training / 学習時の乱数のseed
--gradient_checkpointing
enable gradient checkpointing / grandient checkpointingを有効にする
--mixed_precision {no,fp16,bf16}
use mixed precision / 混合精度を使う場合、その精度
--full_fp16 fp16 training including gradients / 勾配も含めてfp16で学習する
--save_precision {None,float,fp16,bf16}
precision in saving (available in StableDiffusion checkpoint) /
保存時に精度を変更して保存するStableDiffusion形式での保存時のみ有効
--clip_skip CLIP_SKIP
use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いるnは1以上
--logging_dir LOGGING_DIR
enable logging and output TensorBoard log to this directory /
ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する
--log_prefix LOG_PREFIX
add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列
--lr_scheduler LR_SCHEDULER
scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial,
constant (default), constant_with_warmup
--lr_warmup_steps LR_WARMUP_STEPS
Number of steps for the warmup in the lr scheduler (default is 0) /
学習率のスケジューラをウォームアップするステップ数デフォルト0
```
## Change history
* 12/17 (v17.1) update:
- Adding GUI for kohya_ss called dreambooth_gui.py
- removing support for `--finetuning` as there is now a dedicated python repo for that. `--fine-tuning` is still there behind the scene until kohya_ss remove it in a future code release.
- removing cli examples as I will now focus on the GUI for training. People who prefer cli based training can still do that.
* 12/13 (v17) update:
- Added support for learning to fp16 gradient (experimental function). SD1.x can be trained with 8GB of VRAM. Specify full_fp16 options.
* 12/06 (v16) update:

View File

@ -10,12 +10,22 @@ import os
import subprocess
import pathlib
import shutil
from dreambooth_gui.dreambooth_folder_creation import gradio_dreambooth_folder_creation_tab
from dreambooth_gui.dreambooth_folder_creation import (
gradio_dreambooth_folder_creation_tab,
)
from dreambooth_gui.caption_gui import gradio_caption_gui_tab
from dreambooth_gui.common_gui import get_folder_path, remove_doublequote, get_file_path
from dreambooth_gui.common_gui import (
get_folder_path,
remove_doublequote,
get_file_path,
)
from easygui import filesavebox, msgbox
# sys.path.insert(0, './dreambooth_gui')
folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
save_style_symbol = '\U0001f4be' # 💾
document_symbol = '\U0001F4C4' # 📄
def save_configuration(
save_as,
@ -53,22 +63,22 @@ def save_configuration(
):
original_file_path = file_path
save_as_bool = True if save_as.get("label") == "True" else False
save_as_bool = True if save_as.get('label') == 'True' else False
if save_as_bool:
print("Save as...")
print('Save as...')
file_path = filesavebox(
"Select the config file to save",
default="finetune.json",
filetypes="*.json",
'Select the config file to save',
default='finetune.json',
filetypes='*.json',
)
else:
print("Save...")
if file_path == None or file_path == "":
print('Save...')
if file_path == None or file_path == '':
file_path = filesavebox(
"Select the config file to save",
default="finetune.json",
filetypes="*.json",
'Select the config file to save',
default='finetune.json',
filetypes='*.json',
)
if file_path == None:
@ -76,40 +86,40 @@ def save_configuration(
# Return the values of the variables as a dictionary
variables = {
"pretrained_model_name_or_path": pretrained_model_name_or_path,
"v2": v2,
"v_parameterization": v_parameterization,
"logging_dir": logging_dir,
"train_data_dir": train_data_dir,
"reg_data_dir": reg_data_dir,
"output_dir": output_dir,
"max_resolution": max_resolution,
"learning_rate": learning_rate,
"lr_scheduler": lr_scheduler,
"lr_warmup": lr_warmup,
"train_batch_size": train_batch_size,
"epoch": epoch,
"save_every_n_epochs": save_every_n_epochs,
"mixed_precision": mixed_precision,
"save_precision": save_precision,
"seed": seed,
"num_cpu_threads_per_process": num_cpu_threads_per_process,
"convert_to_safetensors": convert_to_safetensors,
"convert_to_ckpt": convert_to_ckpt,
"cache_latent": cache_latent,
"caption_extention": caption_extention,
"use_safetensors": use_safetensors,
"enable_bucket": enable_bucket,
"gradient_checkpointing": gradient_checkpointing,
"full_fp16": full_fp16,
"no_token_padding": no_token_padding,
"stop_text_encoder_training": stop_text_encoder_training,
"use_8bit_adam": use_8bit_adam,
"xformers": xformers,
'pretrained_model_name_or_path': pretrained_model_name_or_path,
'v2': v2,
'v_parameterization': v_parameterization,
'logging_dir': logging_dir,
'train_data_dir': train_data_dir,
'reg_data_dir': reg_data_dir,
'output_dir': output_dir,
'max_resolution': max_resolution,
'learning_rate': learning_rate,
'lr_scheduler': lr_scheduler,
'lr_warmup': lr_warmup,
'train_batch_size': train_batch_size,
'epoch': epoch,
'save_every_n_epochs': save_every_n_epochs,
'mixed_precision': mixed_precision,
'save_precision': save_precision,
'seed': seed,
'num_cpu_threads_per_process': num_cpu_threads_per_process,
'convert_to_safetensors': convert_to_safetensors,
'convert_to_ckpt': convert_to_ckpt,
'cache_latent': cache_latent,
'caption_extention': caption_extention,
'use_safetensors': use_safetensors,
'enable_bucket': enable_bucket,
'gradient_checkpointing': gradient_checkpointing,
'full_fp16': full_fp16,
'no_token_padding': no_token_padding,
'stop_text_encoder_training': stop_text_encoder_training,
'use_8bit_adam': use_8bit_adam,
'xformers': xformers,
}
# Save the data to the selected file
with open(file_path, "w") as file:
with open(file_path, 'w') as file:
json.dump(variables, file)
return file_path
@ -152,10 +162,10 @@ def open_configuration(
original_file_path = file_path
file_path = get_file_path(file_path)
if file_path != "" and file_path != None:
if file_path != '' and file_path != None:
print(file_path)
# load variables from JSON file
with open(file_path, "r") as f:
with open(file_path, 'r') as f:
my_data = json.load(f)
else:
file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action
@ -164,36 +174,40 @@ def open_configuration(
# Return the values of the variables as a dictionary
return (
file_path,
my_data.get("pretrained_model_name_or_path", pretrained_model_name_or_path),
my_data.get("v2", v2),
my_data.get("v_parameterization", v_parameterization),
my_data.get("logging_dir", logging_dir),
my_data.get("train_data_dir", train_data_dir),
my_data.get("reg_data_dir", reg_data_dir),
my_data.get("output_dir", output_dir),
my_data.get("max_resolution", max_resolution),
my_data.get("learning_rate", learning_rate),
my_data.get("lr_scheduler", lr_scheduler),
my_data.get("lr_warmup", lr_warmup),
my_data.get("train_batch_size", train_batch_size),
my_data.get("epoch", epoch),
my_data.get("save_every_n_epochs", save_every_n_epochs),
my_data.get("mixed_precision", mixed_precision),
my_data.get("save_precision", save_precision),
my_data.get("seed", seed),
my_data.get("num_cpu_threads_per_process", num_cpu_threads_per_process),
my_data.get("convert_to_safetensors", convert_to_safetensors),
my_data.get("convert_to_ckpt", convert_to_ckpt),
my_data.get("cache_latent", cache_latent),
my_data.get("caption_extention", caption_extention),
my_data.get("use_safetensors", use_safetensors),
my_data.get("enable_bucket", enable_bucket),
my_data.get("gradient_checkpointing", gradient_checkpointing),
my_data.get("full_fp16", full_fp16),
my_data.get("no_token_padding", no_token_padding),
my_data.get("stop_text_encoder_training", stop_text_encoder_training),
my_data.get("use_8bit_adam", use_8bit_adam),
my_data.get("xformers", xformers),
my_data.get(
'pretrained_model_name_or_path', pretrained_model_name_or_path
),
my_data.get('v2', v2),
my_data.get('v_parameterization', v_parameterization),
my_data.get('logging_dir', logging_dir),
my_data.get('train_data_dir', train_data_dir),
my_data.get('reg_data_dir', reg_data_dir),
my_data.get('output_dir', output_dir),
my_data.get('max_resolution', max_resolution),
my_data.get('learning_rate', learning_rate),
my_data.get('lr_scheduler', lr_scheduler),
my_data.get('lr_warmup', lr_warmup),
my_data.get('train_batch_size', train_batch_size),
my_data.get('epoch', epoch),
my_data.get('save_every_n_epochs', save_every_n_epochs),
my_data.get('mixed_precision', mixed_precision),
my_data.get('save_precision', save_precision),
my_data.get('seed', seed),
my_data.get(
'num_cpu_threads_per_process', num_cpu_threads_per_process
),
my_data.get('convert_to_safetensors', convert_to_safetensors),
my_data.get('convert_to_ckpt', convert_to_ckpt),
my_data.get('cache_latent', cache_latent),
my_data.get('caption_extention', caption_extention),
my_data.get('use_safetensors', use_safetensors),
my_data.get('enable_bucket', enable_bucket),
my_data.get('gradient_checkpointing', gradient_checkpointing),
my_data.get('full_fp16', full_fp16),
my_data.get('no_token_padding', no_token_padding),
my_data.get('stop_text_encoder_training', stop_text_encoder_training),
my_data.get('use_8bit_adam', use_8bit_adam),
my_data.get('xformers', xformers),
)
@ -229,46 +243,46 @@ def train_model(
use_8bit_adam,
xformers,
):
def save_inference_file(output_dir, v2, v_parameterization):
# Copy inference model for v2 if required
if v2 and v_parameterization:
print(f"Saving v2-inference-v.yaml as {output_dir}/last.yaml")
print(f'Saving v2-inference-v.yaml as {output_dir}/last.yaml')
shutil.copy(
f"./v2_inference/v2-inference-v.yaml",
f"{output_dir}/last.yaml",
f'./v2_inference/v2-inference-v.yaml',
f'{output_dir}/last.yaml',
)
elif v2:
print(f"Saving v2-inference.yaml as {output_dir}/last.yaml")
print(f'Saving v2-inference.yaml as {output_dir}/last.yaml')
shutil.copy(
f"./v2_inference/v2-inference.yaml",
f"{output_dir}/last.yaml",
f'./v2_inference/v2-inference.yaml',
f'{output_dir}/last.yaml',
)
if pretrained_model_name_or_path == "":
msgbox("Source model information is missing")
if pretrained_model_name_or_path == '':
msgbox('Source model information is missing')
return
if train_data_dir == "":
msgbox("Image folder path is missing")
if train_data_dir == '':
msgbox('Image folder path is missing')
return
if not os.path.exists(train_data_dir):
msgbox("Image folder does not exist")
msgbox('Image folder does not exist')
return
if reg_data_dir != "":
if reg_data_dir != '':
if not os.path.exists(reg_data_dir):
msgbox("Regularisation folder does not exist")
msgbox('Regularisation folder does not exist')
return
if output_dir == "":
msgbox("Output folder path is missing")
if output_dir == '':
msgbox('Output folder path is missing')
return
# Get a list of all subfolders in train_data_dir
subfolders = [
f for f in os.listdir(train_data_dir)
f
for f in os.listdir(train_data_dir)
if os.path.isdir(os.path.join(train_data_dir, f))
]
@ -277,115 +291,127 @@ def train_model(
# Loop through each subfolder and extract the number of repeats
for folder in subfolders:
# Extract the number of repeats from the folder name
repeats = int(folder.split("_")[0])
repeats = int(folder.split('_')[0])
# Count the number of images in the folder
num_images = len([
f for f in os.listdir(os.path.join(train_data_dir, folder))
if f.endswith(".jpg") or f.endswith(".jpeg") or f.endswith(".png")
or f.endswith(".webp")
])
num_images = len(
[
f
for f in os.listdir(os.path.join(train_data_dir, folder))
if f.endswith('.jpg')
or f.endswith('.jpeg')
or f.endswith('.png')
or f.endswith('.webp')
]
)
# Calculate the total number of steps for this folder
steps = repeats * num_images
total_steps += steps
# Print the result
print(f"Folder {folder}: {steps} steps")
print(f'Folder {folder}: {steps} steps')
# Print the result
# print(f"{total_steps} total steps")
if reg_data_dir == "":
if reg_data_dir == '':
reg_factor = 1
else:
print(
"Regularisation images are used... Will double the number of steps required..."
'Regularisation images are used... Will double the number of steps required...'
)
reg_factor = 2
# calculate max_train_steps
max_train_steps = int(
math.ceil(
float(total_steps) / int(train_batch_size) * int(epoch) *
int(reg_factor)))
print(f"max_train_steps = {max_train_steps}")
float(total_steps)
/ int(train_batch_size)
* int(epoch)
* int(reg_factor)
)
)
print(f'max_train_steps = {max_train_steps}')
# calculate stop encoder training
if stop_text_encoder_training_pct == None:
stop_text_encoder_training = 0
else:
stop_text_encoder_training = math.ceil(
float(max_train_steps) / 100 * int(stop_text_encoder_training_pct))
print(f"stop_text_encoder_training = {stop_text_encoder_training}")
float(max_train_steps) / 100 * int(stop_text_encoder_training_pct)
)
print(f'stop_text_encoder_training = {stop_text_encoder_training}')
lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100))
print(f"lr_warmup_steps = {lr_warmup_steps}")
print(f'lr_warmup_steps = {lr_warmup_steps}')
run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "train_db_fixed.py"'
if v2:
run_cmd += " --v2"
run_cmd += ' --v2'
if v_parameterization:
run_cmd += " --v_parameterization"
run_cmd += ' --v_parameterization'
if cache_latent:
run_cmd += " --cache_latents"
run_cmd += ' --cache_latents'
if use_safetensors:
run_cmd += " --use_safetensors"
run_cmd += ' --use_safetensors'
if enable_bucket:
run_cmd += " --enable_bucket"
run_cmd += ' --enable_bucket'
if gradient_checkpointing:
run_cmd += " --gradient_checkpointing"
run_cmd += ' --gradient_checkpointing'
if full_fp16:
run_cmd += " --full_fp16"
run_cmd += ' --full_fp16'
if no_token_padding:
run_cmd += " --no_token_padding"
run_cmd += ' --no_token_padding'
if use_8bit_adam:
run_cmd += " --use_8bit_adam"
run_cmd += ' --use_8bit_adam'
if xformers:
run_cmd += " --xformers"
run_cmd += f" --pretrained_model_name_or_path={pretrained_model_name_or_path}"
run_cmd += ' --xformers'
run_cmd += (
f' --pretrained_model_name_or_path={pretrained_model_name_or_path}'
)
run_cmd += f' --train_data_dir="{train_data_dir}"'
if len(reg_data_dir):
run_cmd += f' --reg_data_dir="{reg_data_dir}"'
run_cmd += f" --resolution={max_resolution}"
run_cmd += f" --output_dir={output_dir}"
run_cmd += f" --train_batch_size={train_batch_size}"
run_cmd += f" --learning_rate={learning_rate}"
run_cmd += f" --lr_scheduler={lr_scheduler}"
run_cmd += f" --lr_warmup_steps={lr_warmup_steps}"
run_cmd += f" --max_train_steps={max_train_steps}"
run_cmd += f" --use_8bit_adam"
run_cmd += f" --xformers"
run_cmd += f" --mixed_precision={mixed_precision}"
run_cmd += f" --save_every_n_epochs={save_every_n_epochs}"
run_cmd += f" --seed={seed}"
run_cmd += f" --save_precision={save_precision}"
run_cmd += f" --logging_dir={logging_dir}"
run_cmd += f" --caption_extention={caption_extention}"
run_cmd += f" --stop_text_encoder_training={stop_text_encoder_training}"
run_cmd += f' --resolution={max_resolution}'
run_cmd += f' --output_dir={output_dir}'
run_cmd += f' --train_batch_size={train_batch_size}'
run_cmd += f' --learning_rate={learning_rate}'
run_cmd += f' --lr_scheduler={lr_scheduler}'
run_cmd += f' --lr_warmup_steps={lr_warmup_steps}'
run_cmd += f' --max_train_steps={max_train_steps}'
run_cmd += f' --use_8bit_adam'
run_cmd += f' --xformers'
run_cmd += f' --mixed_precision={mixed_precision}'
run_cmd += f' --save_every_n_epochs={save_every_n_epochs}'
run_cmd += f' --seed={seed}'
run_cmd += f' --save_precision={save_precision}'
run_cmd += f' --logging_dir={logging_dir}'
run_cmd += f' --caption_extention={caption_extention}'
run_cmd += f' --stop_text_encoder_training={stop_text_encoder_training}'
print(run_cmd)
# Run the command
subprocess.run(run_cmd)
# check if output_dir/last is a directory... therefore it is a diffuser model
last_dir = pathlib.Path(f"{output_dir}/last")
last_dir = pathlib.Path(f'{output_dir}/last')
print(last_dir)
if last_dir.is_dir():
if convert_to_ckpt:
print(f"Converting diffuser model {last_dir} to {last_dir}.ckpt")
print(f'Converting diffuser model {last_dir} to {last_dir}.ckpt')
os.system(
f"python ./tools/convert_diffusers20_original_sd.py {last_dir} {last_dir}.ckpt --{save_precision}"
f'python ./tools/convert_diffusers20_original_sd.py {last_dir} {last_dir}.ckpt --{save_precision}'
)
save_inference_file(output_dir, v2, v_parameterization)
if convert_to_safetensors:
print(
f"Converting diffuser model {last_dir} to {last_dir}.safetensors"
f'Converting diffuser model {last_dir} to {last_dir}.safetensors'
)
os.system(
f"python ./tools/convert_diffusers20_original_sd.py {last_dir} {last_dir}.safetensors --{save_precision}"
f'python ./tools/convert_diffusers20_original_sd.py {last_dir} {last_dir}.safetensors --{save_precision}'
)
save_inference_file(output_dir, v2, v_parameterization)
@ -400,13 +426,13 @@ def train_model(
def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
# define a list of substrings to search for
substrings_v2 = [
"stabilityai/stable-diffusion-2-1-base",
"stabilityai/stable-diffusion-2-base",
'stabilityai/stable-diffusion-2-1-base',
'stabilityai/stable-diffusion-2-base',
]
# check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list
if str(value) in substrings_v2:
print("SD v2 model detected. Setting --v2 parameter")
print('SD v2 model detected. Setting --v2 parameter')
v2 = True
v_parameterization = False
@ -414,14 +440,14 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
# define a list of substrings to search for v-objective
substrings_v_parameterization = [
"stabilityai/stable-diffusion-2-1",
"stabilityai/stable-diffusion-2",
'stabilityai/stable-diffusion-2-1',
'stabilityai/stable-diffusion-2',
]
# check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v_parameterization list
if str(value) in substrings_v_parameterization:
print(
"SD v2 v_parameterization detected. Setting --v2 parameter and --v_parameterization"
'SD v2 v_parameterization detected. Setting --v2 parameter and --v_parameterization'
)
v2 = True
v_parameterization = True
@ -430,8 +456,8 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
# define a list of substrings to v1.x
substrings_v1_model = [
"CompVis/stable-diffusion-v1-4",
"runwayml/stable-diffusion-v1-5",
'CompVis/stable-diffusion-v1-4',
'runwayml/stable-diffusion-v1-5',
]
if str(value) in substrings_v1_model:
@ -440,62 +466,79 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
return value, v2, v_parameterization
if value == "custom":
value = ""
if value == 'custom':
value = ''
v2 = False
v_parameterization = False
return value, v2, v_parameterization
css = ""
css = ''
if os.path.exists("./style.css"):
with open(os.path.join("./style.css"), "r", encoding="utf8") as file:
print("Load CSS...")
css += file.read() + "\n"
if os.path.exists('./style.css'):
with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
print('Load CSS...')
css += file.read() + '\n'
interface = gr.Blocks(css=css)
with interface:
dummy_true = gr.Label(value=True, visible=False)
dummy_false = gr.Label(value=False, visible=False)
gr.Markdown("Enter kohya finetuner parameter using this interface.")
with gr.Accordion("Configuration File Load/Save", open=False):
gr.Markdown('Enter kohya finetuner parameter using this interface.')
with gr.Accordion('Configuration File Load/Save', open=False):
with gr.Row():
button_open_config = gr.Button("Open 📂", elem_id="open_folder")
button_save_config = gr.Button("Save 💾", elem_id="open_folder")
button_save_as_config = gr.Button("Save as... 💾",
elem_id="open_folder")
button_open_config = gr.Button('Open 📂', elem_id='open_folder')
button_save_config = gr.Button('Save 💾', elem_id='open_folder')
button_save_as_config = gr.Button(
'Save as... 💾', elem_id='open_folder'
)
config_file_name = gr.Textbox(
label="", placeholder="type the configuration file path or use the 'Open' button above to select it...")
config_file_name.change(remove_doublequote,
inputs=[config_file_name],
outputs=[config_file_name])
with gr.Tab("Source model"):
label='',
placeholder="type the configuration file path or use the 'Open' button above to select it...",
)
config_file_name.change(
remove_doublequote,
inputs=[config_file_name],
outputs=[config_file_name],
)
with gr.Tab('Source model'):
# Define the input elements
with gr.Row():
pretrained_model_name_or_path_input = gr.Textbox(
label="Pretrained model name or path",
placeholder=
"enter the path to custom model or name of pretrained model",
label='Pretrained model name or path',
placeholder='enter the path to custom model or name of pretrained model',
)
pretrained_model_name_or_path_fille = gr.Button(
document_symbol, elem_id='open_folder_small'
)
pretrained_model_name_or_path_fille.click(
get_file_path, inputs=[pretrained_model_name_or_path_input], outputs=pretrained_model_name_or_path_input
)
pretrained_model_name_or_path_folder = gr.Button(
folder_symbol, elem_id='open_folder_small'
)
pretrained_model_name_or_path_folder.click(
get_folder_path, outputs=pretrained_model_name_or_path_input
)
model_list = gr.Dropdown(
label="(Optional) Model Quick Pick",
label='(Optional) Model Quick Pick',
choices=[
"custom",
"stabilityai/stable-diffusion-2-1-base",
"stabilityai/stable-diffusion-2-base",
"stabilityai/stable-diffusion-2-1",
"stabilityai/stable-diffusion-2",
"runwayml/stable-diffusion-v1-5",
"CompVis/stable-diffusion-v1-4",
'custom',
'stabilityai/stable-diffusion-2-1-base',
'stabilityai/stable-diffusion-2-base',
'stabilityai/stable-diffusion-2-1',
'stabilityai/stable-diffusion-2',
'runwayml/stable-diffusion-v1-5',
'CompVis/stable-diffusion-v1-4',
],
)
with gr.Row():
v2_input = gr.Checkbox(label="v2", value=True)
v_parameterization_input = gr.Checkbox(label="v_parameterization",
value=False)
v2_input = gr.Checkbox(label='v2', value=True)
v_parameterization_input = gr.Checkbox(
label='v_parameterization', value=False
)
pretrained_model_name_or_path_input.change(
remove_doublequote,
inputs=[pretrained_model_name_or_path_input],
@ -511,44 +554,49 @@ with interface:
],
)
with gr.Tab("Directories"):
with gr.Tab('Directories'):
with gr.Row():
train_data_dir_input = gr.Textbox(
label="Image folder",
placeholder=
"Directory where the training folders containing the images are located",
label='Image folder',
placeholder='Directory where the training folders containing the images are located',
)
train_data_dir_input_folder = gr.Button(
"📂", elem_id="open_folder_small")
train_data_dir_input_folder.click(get_folder_path,
outputs=train_data_dir_input)
reg_data_dir_input = gr.Textbox(
label="Regularisation folder",
placeholder=
"(Optional) Directory where where the regularization folders containing the images are located",
'📂', elem_id='open_folder_small'
)
train_data_dir_input_folder.click(
get_folder_path, outputs=train_data_dir_input
)
reg_data_dir_input = gr.Textbox(
label='Regularisation folder',
placeholder='(Optional) Directory where where the regularization folders containing the images are located',
)
reg_data_dir_input_folder = gr.Button(
'📂', elem_id='open_folder_small'
)
reg_data_dir_input_folder.click(
get_folder_path, outputs=reg_data_dir_input
)
reg_data_dir_input_folder = gr.Button("📂",
elem_id="open_folder_small")
reg_data_dir_input_folder.click(get_folder_path,
outputs=reg_data_dir_input)
with gr.Row():
output_dir_input = gr.Textbox(
label="Output directory",
placeholder="Directory to output trained model",
label='Output directory',
placeholder='Directory to output trained model',
)
output_dir_input_folder = gr.Button(
'📂', elem_id='open_folder_small'
)
output_dir_input_folder.click(
get_folder_path, outputs=output_dir_input
)
output_dir_input_folder = gr.Button("📂",
elem_id="open_folder_small")
output_dir_input_folder.click(get_folder_path,
outputs=output_dir_input)
logging_dir_input = gr.Textbox(
label="Logging directory",
placeholder=
"Optional: enable logging and output TensorBoard log to this directory",
label='Logging directory',
placeholder='Optional: enable logging and output TensorBoard log to this directory',
)
logging_dir_input_folder = gr.Button(
'📂', elem_id='open_folder_small'
)
logging_dir_input_folder.click(
get_folder_path, outputs=logging_dir_input
)
logging_dir_input_folder = gr.Button("📂",
elem_id="open_folder_small")
logging_dir_input_folder.click(get_folder_path,
outputs=logging_dir_input)
train_data_dir_input.change(
remove_doublequote,
inputs=[train_data_dir_input],
@ -559,111 +607,130 @@ with interface:
inputs=[reg_data_dir_input],
outputs=[reg_data_dir_input],
)
output_dir_input.change(remove_doublequote,
inputs=[output_dir_input],
outputs=[output_dir_input])
logging_dir_input.change(remove_doublequote,
inputs=[logging_dir_input],
outputs=[logging_dir_input])
with gr.Tab("Training parameters"):
output_dir_input.change(
remove_doublequote,
inputs=[output_dir_input],
outputs=[output_dir_input],
)
logging_dir_input.change(
remove_doublequote,
inputs=[logging_dir_input],
outputs=[logging_dir_input],
)
with gr.Tab('Training parameters'):
with gr.Row():
learning_rate_input = gr.Textbox(label="Learning rate", value=1e-6)
learning_rate_input = gr.Textbox(label='Learning rate', value=1e-6)
lr_scheduler_input = gr.Dropdown(
label="LR Scheduler",
label='LR Scheduler',
choices=[
"constant",
"constant_with_warmup",
"cosine",
"cosine_with_restarts",
"linear",
"polynomial",
'constant',
'constant_with_warmup',
'cosine',
'cosine_with_restarts',
'linear',
'polynomial',
],
value="constant",
value='constant',
)
lr_warmup_input = gr.Textbox(label="LR warmup", value=0)
lr_warmup_input = gr.Textbox(label='LR warmup', value=0)
with gr.Row():
train_batch_size_input = gr.Slider(minimum=1,
maximum=32,
label="Train batch size",
value=1,
step=1)
epoch_input = gr.Textbox(label="Epoch", value=1)
save_every_n_epochs_input = gr.Textbox(label="Save every N epochs",
value=1)
train_batch_size_input = gr.Slider(
minimum=1,
maximum=32,
label='Train batch size',
value=1,
step=1,
)
epoch_input = gr.Textbox(label='Epoch', value=1)
save_every_n_epochs_input = gr.Textbox(
label='Save every N epochs', value=1
)
with gr.Row():
mixed_precision_input = gr.Dropdown(
label="Mixed precision",
label='Mixed precision',
choices=[
"no",
"fp16",
"bf16",
'no',
'fp16',
'bf16',
],
value="fp16",
value='fp16',
)
save_precision_input = gr.Dropdown(
label="Save precision",
label='Save precision',
choices=[
"float",
"fp16",
"bf16",
'float',
'fp16',
'bf16',
],
value="fp16",
value='fp16',
)
num_cpu_threads_per_process_input = gr.Slider(
minimum=1,
maximum=os.cpu_count(),
step=1,
label="Number of CPU threads per process",
label='Number of CPU threads per process',
value=os.cpu_count(),
)
with gr.Row():
seed_input = gr.Textbox(label="Seed", value=1234)
max_resolution_input = gr.Textbox(label="Max resolution",
value="512,512",
placeholder="512,512")
seed_input = gr.Textbox(label='Seed', value=1234)
max_resolution_input = gr.Textbox(
label='Max resolution', value='512,512', placeholder='512,512'
)
with gr.Row():
caption_extention_input = gr.Textbox(
label="Caption Extension",
placeholder=
"(Optional) Extension for caption files. default: .caption",
label='Caption Extension',
placeholder='(Optional) Extension for caption files. default: .caption',
)
stop_text_encoder_training_input = gr.Slider(
minimum=0,
maximum=100,
value=0,
step=1,
label="Stop text encoder training",
label='Stop text encoder training',
)
with gr.Row():
full_fp16_input = gr.Checkbox(
label="Full fp16 training (experimental)", value=False)
no_token_padding_input = gr.Checkbox(label="No token padding",
value=False)
label='Full fp16 training (experimental)', value=False
)
no_token_padding_input = gr.Checkbox(
label='No token padding', value=False
)
use_safetensors_input = gr.Checkbox(
label="Use safetensor when saving", value=False)
label='Use safetensor when saving', value=False
)
gradient_checkpointing_input = gr.Checkbox(
label="Gradient checkpointing", value=False)
label='Gradient checkpointing', value=False
)
with gr.Row():
enable_bucket_input = gr.Checkbox(label="Enable buckets",
value=True)
cache_latent_input = gr.Checkbox(label="Cache latent", value=True)
use_8bit_adam_input = gr.Checkbox(label="Use 8bit adam",
value=True)
xformers_input = gr.Checkbox(label="Use xformers", value=True)
enable_bucket_input = gr.Checkbox(
label='Enable buckets', value=True
)
cache_latent_input = gr.Checkbox(label='Cache latent', value=True)
use_8bit_adam_input = gr.Checkbox(
label='Use 8bit adam', value=True
)
xformers_input = gr.Checkbox(label='Use xformers', value=True)
with gr.Tab("Model conversion"):
with gr.Tab('Model conversion'):
convert_to_safetensors_input = gr.Checkbox(
label="Convert to SafeTensors", value=True)
convert_to_ckpt_input = gr.Checkbox(label="Convert to CKPT",
value=False)
with gr.Tab("Utilities"):
label='Convert to SafeTensors', value=True
)
convert_to_ckpt_input = gr.Checkbox(
label='Convert to CKPT', value=False
)
with gr.Tab('Utilities'):
# Dreambooth folder creation tab
gradio_dreambooth_folder_creation_tab(train_data_dir_input, reg_data_dir_input, output_dir_input, logging_dir_input)
gradio_dreambooth_folder_creation_tab(
train_data_dir_input,
reg_data_dir_input,
output_dir_input,
logging_dir_input,
)
# Captionning tab
gradio_caption_gui_tab()
button_run = gr.Button("Train model")
button_run = gr.Button('Train model')
button_open_config.click(
open_configuration,

View File

@ -3,69 +3,85 @@ from easygui import msgbox
import subprocess
from .common_gui import get_folder_path
def caption_images(caption_text_input, images_dir_input, overwrite_input, caption_file_ext):
def caption_images(
caption_text_input, images_dir_input, overwrite_input, caption_file_ext
):
# Check for caption_text_input
if caption_text_input == "":
msgbox("Caption text is missing...")
if caption_text_input == '':
msgbox('Caption text is missing...')
return
# Check for images_dir_input
if images_dir_input == "":
msgbox("Image folder is missing...")
if images_dir_input == '':
msgbox('Image folder is missing...')
return
print(f"Captionning files in {images_dir_input} with {caption_text_input}...")
print(
f'Captionning files in {images_dir_input} with {caption_text_input}...'
)
run_cmd = f'python "tools/caption.py"'
run_cmd += f' --caption_text="{caption_text_input}"'
if overwrite_input:
run_cmd += f' --overwrite'
if caption_file_ext != "":
if caption_file_ext != '':
run_cmd += f' --caption_file_ext="{caption_file_ext}"'
run_cmd += f' "{images_dir_input}"'
print(run_cmd)
# Run the command
subprocess.run(run_cmd)
print("...captionning done")
print('...captionning done')
###
# Gradio UI
###
def gradio_caption_gui_tab():
with gr.Tab("Captionning"):
with gr.Tab('Captionning'):
gr.Markdown(
"This utility will allow the creation of caption files for each images in a folder."
'This utility will allow the creation of caption files for each images in a folder.'
)
with gr.Row():
caption_text_input = gr.Textbox(
label="Caption text",
placeholder="Eg: , by some artist",
label='Caption text',
placeholder='Eg: , by some artist',
interactive=True,
)
)
overwrite_input = gr.Checkbox(
label="Overwrite existing captions in folder",
label='Overwrite existing captions in folder',
interactive=True,
value=False
value=False,
)
caption_file_ext = gr.Textbox(
label="Caption file extension",
placeholder="(Optional) Default: .caption",
label='Caption file extension',
placeholder='(Optional) Default: .caption',
interactive=True,
)
with gr.Row():
images_dir_input = gr.Textbox(
label="Image forder to caption",
placeholder="Directory containing the images to caption",
label='Image forder to caption',
placeholder='Directory containing the images to caption',
interactive=True,
)
button_images_dir_input = gr.Button(
"📂", elem_id="open_folder_small")
'📂', elem_id='open_folder_small'
)
button_images_dir_input.click(
get_folder_path, outputs=images_dir_input)
caption_button = gr.Button("Caption images")
caption_button.click(caption_images, inputs=[caption_text_input, images_dir_input, overwrite_input, caption_file_ext])
get_folder_path, outputs=images_dir_input
)
caption_button = gr.Button('Caption images')
caption_button.click(
caption_images,
inputs=[
caption_text_input,
images_dir_input,
overwrite_input,
caption_file_ext,
],
)

View File

@ -1,19 +1,22 @@
from easygui import diropenbox, fileopenbox
def get_folder_path():
folder_path = diropenbox("Select the directory to use")
folder_path = diropenbox('Select the directory to use')
return folder_path
def remove_doublequote(file_path):
if file_path != None:
file_path = file_path.replace('"', "")
def get_file_path(file_path):
file_path = fileopenbox(
'Select the config file to load', default=file_path, filetypes='*.json',
)
return file_path
def get_file_path(file_path):
file_path = fileopenbox("Select the config file to load",
default=file_path,
filetypes="*.json")
return file_path
def remove_doublequote(file_path):
if file_path != None:
file_path = file_path.replace('"', '')
return file_path

View File

@ -4,14 +4,15 @@ from .common_gui import get_folder_path
import shutil
import os
def copy_info_to_Directories_tab(training_folder):
img_folder = os.path.join(training_folder, "img")
if os.path.exists(os.path.join(training_folder, "reg")):
reg_folder = os.path.join(training_folder, "reg")
img_folder = os.path.join(training_folder, 'img')
if os.path.exists(os.path.join(training_folder, 'reg')):
reg_folder = os.path.join(training_folder, 'reg')
else:
reg_folder = ""
model_folder = os.path.join(training_folder, "model")
log_folder = os.path.join(training_folder, "log")
reg_folder = ''
model_folder = os.path.join(training_folder, 'model')
log_folder = os.path.join(training_folder, 'log')
return img_folder, reg_folder, model_folder, log_folder
@ -27,7 +28,7 @@ def dreambooth_folder_preparation(
):
# Check if the input variables are empty
if (not len(util_training_dir_output)):
if not len(util_training_dir_output):
print(
"Destination training directory is missing... can't perform the required task..."
)
@ -37,17 +38,17 @@ def dreambooth_folder_preparation(
os.makedirs(util_training_dir_output, exist_ok=True)
# Check for instance prompt
if util_instance_prompt_input == "":
msgbox("Instance prompt missing...")
if util_instance_prompt_input == '':
msgbox('Instance prompt missing...')
return
# Check for class prompt
if util_class_prompt_input == "":
msgbox("Class prompt missing...")
if util_class_prompt_input == '':
msgbox('Class prompt missing...')
return
# Create the training_dir path
if (util_training_images_dir_input == ""):
if util_training_images_dir_input == '':
print(
"Training images directory is missing... can't perform the required task..."
)
@ -55,106 +56,120 @@ def dreambooth_folder_preparation(
else:
training_dir = os.path.join(
util_training_dir_output,
f"img/{int(util_training_images_repeat_input)}_{util_instance_prompt_input} {util_class_prompt_input}",
f'img/{int(util_training_images_repeat_input)}_{util_instance_prompt_input} {util_class_prompt_input}',
)
# Remove folders if they exist
if os.path.exists(training_dir):
print(f"Removing existing directory {training_dir}...")
print(f'Removing existing directory {training_dir}...')
shutil.rmtree(training_dir)
# Copy the training images to their respective directories
print(f"Copy {util_training_images_dir_input} to {training_dir}...")
print(f'Copy {util_training_images_dir_input} to {training_dir}...')
shutil.copytree(util_training_images_dir_input, training_dir)
# Create the regularization_dir path
if (not (util_class_prompt_input == "")
or not util_regularization_images_repeat_input > 0):
if (
util_class_prompt_input == ''
or not util_regularization_images_repeat_input > 0
):
print(
"Regularization images directory or repeats is missing... not copying regularisation images..."
'Regularization images directory or repeats is missing... not copying regularisation images...'
)
else:
regularization_dir = os.path.join(
util_training_dir_output,
f"reg/{int(util_regularization_images_repeat_input)}_{util_class_prompt_input}",
f'reg/{int(util_regularization_images_repeat_input)}_{util_class_prompt_input}',
)
# Remove folders if they exist
if os.path.exists(regularization_dir):
print(f"Removing existing directory {regularization_dir}...")
print(f'Removing existing directory {regularization_dir}...')
shutil.rmtree(regularization_dir)
# Copy the regularisation images to their respective directories
print(
f"Copy {util_regularization_images_dir_input} to {regularization_dir}..."
f'Copy {util_regularization_images_dir_input} to {regularization_dir}...'
)
shutil.copytree(
util_regularization_images_dir_input, regularization_dir
)
shutil.copytree(util_regularization_images_dir_input,
regularization_dir)
print(
f"Done creating kohya_ss training folder structure at {util_training_dir_output}..."
f'Done creating kohya_ss training folder structure at {util_training_dir_output}...'
)
def gradio_dreambooth_folder_creation_tab(train_data_dir_input, reg_data_dir_input, output_dir_input, logging_dir_input):
with gr.Tab("Dreambooth folder preparation"):
def gradio_dreambooth_folder_creation_tab(
train_data_dir_input,
reg_data_dir_input,
output_dir_input,
logging_dir_input,
):
with gr.Tab('Dreambooth folder preparation'):
gr.Markdown(
"This utility will create the necessary folder structure for the training images and optional regularization images needed for the kohys_ss Dreambooth method to function correctly."
'This utility will create the necessary folder structure for the training images and optional regularization images needed for the kohys_ss Dreambooth method to function correctly.'
)
with gr.Row():
util_instance_prompt_input = gr.Textbox(
label="Instance prompt",
placeholder="Eg: asd",
label='Instance prompt',
placeholder='Eg: asd',
interactive=True,
)
util_class_prompt_input = gr.Textbox(
label="Class prompt",
placeholder="Eg: person",
label='Class prompt',
placeholder='Eg: person',
interactive=True,
)
with gr.Row():
util_training_images_dir_input = gr.Textbox(
label="Training images",
placeholder="Directory containing the training images",
label='Training images',
placeholder='Directory containing the training images',
interactive=True,
)
button_util_training_images_dir_input = gr.Button(
"📂", elem_id="open_folder_small")
'📂', elem_id='open_folder_small'
)
button_util_training_images_dir_input.click(
get_folder_path, outputs=util_training_images_dir_input)
get_folder_path, outputs=util_training_images_dir_input
)
util_training_images_repeat_input = gr.Number(
label="Repeats",
label='Repeats',
value=40,
interactive=True,
elem_id="number_input")
elem_id='number_input',
)
with gr.Row():
util_regularization_images_dir_input = gr.Textbox(
label="Regularisation images",
placeholder=
"(Optional) Directory containing the regularisation images",
label='Regularisation images',
placeholder='(Optional) Directory containing the regularisation images',
interactive=True,
)
button_util_regularization_images_dir_input = gr.Button(
"📂", elem_id="open_folder_small")
'📂', elem_id='open_folder_small'
)
button_util_regularization_images_dir_input.click(
get_folder_path,
outputs=util_regularization_images_dir_input)
get_folder_path, outputs=util_regularization_images_dir_input
)
util_regularization_images_repeat_input = gr.Number(
label="Repeats",
label='Repeats',
value=1,
interactive=True,
elem_id="number_input")
elem_id='number_input',
)
with gr.Row():
util_training_dir_output = gr.Textbox(
label="Destination training directory",
placeholder=
"Directory where formatted training and regularisation folders will be placed",
label='Destination training directory',
placeholder='Directory where formatted training and regularisation folders will be placed',
interactive=True,
)
button_util_training_dir_output = gr.Button(
"📂", elem_id="open_folder_small")
'📂', elem_id='open_folder_small'
)
button_util_training_dir_output.click(
get_folder_path, outputs=util_training_dir_output)
button_prepare_training_data = gr.Button("Prepare training data")
get_folder_path, outputs=util_training_dir_output
)
button_prepare_training_data = gr.Button('Prepare training data')
button_prepare_training_data.click(
dreambooth_folder_preparation,
inputs=[
@ -168,12 +183,15 @@ def gradio_dreambooth_folder_creation_tab(train_data_dir_input, reg_data_dir_inp
],
)
button_copy_info_to_Directories_tab = gr.Button(
"Copy info to Directories Tab")
button_copy_info_to_Directories_tab.click(copy_info_to_Directories_tab,
inputs=[util_training_dir_output],
outputs=[
train_data_dir_input,
reg_data_dir_input,
output_dir_input,
logging_dir_input
])
'Copy info to Directories Tab'
)
button_copy_info_to_Directories_tab.click(
copy_info_to_Directories_tab,
inputs=[util_training_dir_output],
outputs=[
train_data_dir_input,
reg_data_dir_input,
output_dir_input,
logging_dir_input,
],
)