Integrate new bucket parameters in GUI
This commit is contained in:
parent 2486af9903
commit cbfc311687
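The diff below threads three new aspect-ratio-bucketing options (bucket_no_upscale, random_crop, bucket_reso_steps) from gradio_advanced_training() in common_gui.py through the Dreambooth, Finetune, LoRA and Textual Inversion tabs, persists them with each tab's configuration, and appends the matching flags to the generated training command. A condensed sketch of that flag mapping is shown here; the helper name bucket_flags is illustrative, the committed logic lives in run_cmd_advanced_training() further down.

# Sketch only: how the three new bucket options become command-line flags.
def bucket_flags(bucket_no_upscale, random_crop, bucket_reso_steps=64):
    flags = ''
    if int(bucket_reso_steps) >= 1:
        flags += f' --bucket_reso_steps={int(bucket_reso_steps)}'
    if bucket_no_upscale:
        flags += ' --bucket_no_upscale'
    if random_crop:
        flags += ' --random_crop'
    return flags
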
@@ -82,8 +82,12 @@ def save_configuration(
max_data_loader_n_workers,
mem_eff_attn,
gradient_accumulation_steps,
model_list, keep_tokens,
model_list,
keep_tokens,
persistent_data_loader_workers,
bucket_no_upscale,
random_crop,
bucket_reso_steps,
):
# Get list of function parameters and values
parameters = list(locals().items())
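Each save_configuration()/open_configuration() pair captures its entire argument list with parameters = list(locals().items()) as its first statement, so a newly added option such as bucket_reso_steps only has to be appended to the function signature (and to the tab's settings_list) to be persisted. A minimal sketch of the idea, with a simplified signature and illustrative JSON handling rather than the repository's exact implementation:

import json

def save_configuration(file_path, learning_rate, bucket_no_upscale, random_crop, bucket_reso_steps):
    # Capture every argument name/value before any other local is created.
    parameters = list(locals().items())
    variables = {name: value for name, value in parameters if name != 'file_path'}
    with open(file_path, 'w') as f:
        json.dump(variables, f, indent=2)
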
@@ -167,8 +171,12 @@ def open_configuration(
max_data_loader_n_workers,
mem_eff_attn,
gradient_accumulation_steps,
model_list, keep_tokens,
model_list,
keep_tokens,
persistent_data_loader_workers,
bucket_no_upscale,
random_crop,
bucket_reso_steps,
):
# Get list of function parameters and values
parameters = list(locals().items())
@@ -239,6 +247,9 @@ def train_model(
model_list, # Keep this. Yes, it is unused here but required given the common list used
keep_tokens,
persistent_data_loader_workers,
bucket_no_upscale,
random_crop,
bucket_reso_steps,
):
if pretrained_model_name_or_path == '':
msgbox('Source model information is missing')
@@ -402,6 +413,9 @@ def train_model(
use_8bit_adam=use_8bit_adam,
keep_tokens=keep_tokens,
persistent_data_loader_workers=persistent_data_loader_workers,
bucket_no_upscale=bucket_no_upscale,
random_crop=random_crop,
bucket_reso_steps=bucket_reso_steps,
)

print(run_cmd)
@@ -610,6 +624,9 @@ def dreambooth_tab(
max_data_loader_n_workers,
keep_tokens,
persistent_data_loader_workers,
bucket_no_upscale,
random_crop,
bucket_reso_steps,
) = gradio_advanced_training()
color_aug.change(
color_aug_changed,
@@ -675,6 +692,9 @@ def dreambooth_tab(
model_list,
keep_tokens,
persistent_data_loader_workers,
bucket_no_upscale,
random_crop,
bucket_reso_steps,
]

button_open_config.click(

@@ -78,8 +78,12 @@ def save_configuration(
color_aug,
model_list,
cache_latents,
use_latent_files, keep_tokens,
use_latent_files,
keep_tokens,
persistent_data_loader_workers,
bucket_no_upscale,
random_crop,
bucket_reso_steps,
):
# Get list of function parameters and values
parameters = list(locals().items())
@@ -169,8 +173,12 @@ def open_config_file(
color_aug,
model_list,
cache_latents,
use_latent_files, keep_tokens,
use_latent_files,
keep_tokens,
persistent_data_loader_workers,
bucket_no_upscale,
random_crop,
bucket_reso_steps,
):
# Get list of function parameters and values
parameters = list(locals().items())
@@ -245,8 +253,12 @@ def train_model(
color_aug,
model_list, # Keep this. Yes, it is unused here but required given the common list used
cache_latents,
use_latent_files, keep_tokens,
use_latent_files,
keep_tokens,
persistent_data_loader_workers,
bucket_no_upscale,
random_crop,
bucket_reso_steps,
):
# create caption json file
if generate_caption_database:
@@ -295,7 +307,11 @@ def train_model(
subprocess.run(run_cmd)

image_num = len(
[f for f in os.listdir(image_folder) if f.endswith('.jpg') or f.endswith('.png') or f.endswith('.webp')]
[
f
for f in os.listdir(image_folder)
if f.endswith('.jpg') or f.endswith('.png') or f.endswith('.webp')
]
)
print(f'image_num = {image_num}')
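The reflowed comprehension above counts the training images by extension. Since str.endswith() also accepts a tuple of suffixes, the chained or tests can be collapsed; this is an equivalent alternative (a sketch, not what the commit does):

import os

def count_images(image_folder):
    # endswith() accepts a tuple, so one test covers all three extensions.
    return len(
        [f for f in os.listdir(image_folder) if f.endswith(('.jpg', '.png', '.webp'))]
    )
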
@@ -386,6 +402,9 @@ def train_model(
use_8bit_adam=use_8bit_adam,
keep_tokens=keep_tokens,
persistent_data_loader_workers=persistent_data_loader_workers,
bucket_no_upscale=bucket_no_upscale,
random_crop=random_crop,
bucket_reso_steps=bucket_reso_steps,
)

print(run_cmd)
@@ -592,6 +611,9 @@ def finetune_tab():
max_data_loader_n_workers,
keep_tokens,
persistent_data_loader_workers,
bucket_no_upscale,
random_crop,
bucket_reso_steps,
) = gradio_advanced_training()
color_aug.change(
color_aug_changed,
@@ -653,6 +675,9 @@ def finetune_tab():
use_latent_files,
keep_tokens,
persistent_data_loader_workers,
bucket_no_upscale,
random_crop,
bucket_reso_steps,
]

button_run.click(train_model, inputs=settings_list)

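Every tab collects its widgets into a shared settings_list and passes that same list, in the same order, to save_configuration(), open_configuration()/open_config_file() and train_model(); that is why the three new bucket components have to be appended consistently in each of these places (and why model_list stays in the signatures even where it is unused). A minimal sketch of the wiring pattern, assuming the Gradio Blocks API used throughout this repository; component names here are illustrative:

import gradio as gr

with gr.Blocks() as demo:
    bucket_no_upscale = gr.Checkbox(label="Don't upscale bucket resolution", value=True)
    random_crop = gr.Checkbox(label='Random crop instead of center crop', value=False)
    bucket_reso_steps = gr.Number(label='Bucket resolution steps', value=64)
    button_run = gr.Button('Train model')

    settings_list = [bucket_no_upscale, random_crop, bucket_reso_steps]

    # The callback receives the component values in the same order as settings_list.
    def train_model(bucket_no_upscale, random_crop, bucket_reso_steps):
        print(bucket_no_upscale, random_crop, bucket_reso_steps)

    button_run.click(train_model, inputs=settings_list)
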
@@ -19,7 +19,7 @@ def UI(username, password):
print('Load CSS...')
css += file.read() + '\n'

interface = gr.Blocks(css=css, title="Kohya_ss GUI")
interface = gr.Blocks(css=css, title='Kohya_ss GUI')

with interface:
with gr.Tab('Dreambooth'):

@@ -10,13 +10,15 @@ def caption_images(
overwrite_input,
caption_file_ext,
prefix,
postfix, find, replace
postfix,
find,
replace,
):
# Check for images_dir_input
if images_dir_input == '':
msgbox('Image folder is missing...')
return

if caption_file_ext == '':
msgbox('Please provide an extension for the caption files.')
return
@@ -39,7 +41,7 @@ def caption_images(
subprocess.run(run_cmd)

if overwrite_input:
if not prefix=='' or not postfix=='':
if not prefix == '' or not postfix == '':
# Add prefix and postfix
add_pre_postfix(
folder=images_dir_input,
@@ -47,7 +49,7 @@ def caption_images(
prefix=prefix,
postfix=postfix,
)
if not find=='':
if not find == '':
find_replace(
folder=images_dir_input,
caption_file_ext=caption_file_ext,
@@ -134,6 +136,7 @@ def gradio_basic_caption_gui_tab():
caption_file_ext,
prefix,
postfix,
find, replace
find,
replace,
],
)

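After the captioning subprocess finishes, this tab optionally rewrites the generated caption files: add_pre_postfix() adds the prefix/postfix and find_replace() applies a plain string substitution. The find_replace() helper appears further down in the common_gui.py hunk; a compact sketch of what that substitution step has to do, assuming plain-text caption files (the write-back shown here is illustrative, not the repository's exact lines):

import os

def find_replace(folder='', caption_file_ext='.caption', find='', replace=''):
    # Rewrite every caption file in place, substituting `find` with `replace`.
    files = [f for f in os.listdir(folder) if f.endswith(caption_file_ext)]
    for file in files:
        path = os.path.join(folder, file)
        with open(path, 'r', errors='ignore') as f:
            content = f.read()
        with open(path, 'w', errors='ignore') as f:
            f.write(content.replace(find, replace))
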
@@ -26,7 +26,7 @@ def caption_images(
if train_data_dir == '':
msgbox('Image folder is missing...')
return

if caption_file_ext == '':
msgbox('Please provide an extension for the caption files.')
return

@@ -9,6 +9,7 @@ refresh_symbol = '\U0001f504' # 🔄
save_style_symbol = '\U0001f4be' # 💾
document_symbol = '\U0001F4C4' # 📄

def get_dir_and_file(file_path):
dir_path, file_name = os.path.split(file_path)
return (dir_path, file_name)
@@ -200,7 +201,7 @@ def find_replace(folder='', caption_file_ext='.caption', find='', replace=''):

files = [f for f in os.listdir(folder) if f.endswith(caption_file_ext)]
for file in files:
with open(os.path.join(folder, file), 'r', errors="ignore") as f:
with open(os.path.join(folder, file), 'r', errors='ignore') as f:
content = f.read()
f.close
content = content.replace(find, replace)
@@ -304,7 +305,8 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
###
### Gradio common GUI section
###

def gradio_config():
with gr.Accordion('Configuration file', open=False):
with gr.Row():
@@ -318,7 +320,13 @@ def gradio_config():
placeholder="type the configuration file path or use the 'Open' button above to select it...",
interactive=True,
)
return (button_open_config, button_save_config, button_save_as_config, config_file_name)
return (
button_open_config,
button_save_config,
button_save_as_config,
config_file_name,
)

def gradio_source_model():
with gr.Tab('Source model'):
@@ -382,9 +390,20 @@ def gradio_source_model():
v_parameterization,
],
)
return (pretrained_model_name_or_path, v2, v_parameterization, save_model_as, model_list)
return (
pretrained_model_name_or_path,
v2,
v_parameterization,
save_model_as,
model_list,
)

def gradio_training(learning_rate_value='1e-6', lr_scheduler_value='constant', lr_warmup_value='0'):

def gradio_training(
learning_rate_value='1e-6',
lr_scheduler_value='constant',
lr_warmup_value='0',
):
with gr.Row():
train_batch_size = gr.Slider(
minimum=1,
@@ -394,9 +413,7 @@ def gradio_training(learning_rate_value='1e-6', lr_scheduler_value='constant', l
step=1,
)
epoch = gr.Textbox(label='Epoch', value=1)
save_every_n_epochs = gr.Textbox(
label='Save every N epochs', value=1
)
save_every_n_epochs = gr.Textbox(label='Save every N epochs', value=1)
caption_extension = gr.Textbox(
label='Caption Extension',
placeholder='(Optional) Extension for caption files. default: .caption',
@@ -429,7 +446,9 @@ def gradio_training(learning_rate_value='1e-6', lr_scheduler_value='constant', l
)
seed = gr.Textbox(label='Seed', value=1234)
with gr.Row():
learning_rate = gr.Textbox(label='Learning rate', value=learning_rate_value)
learning_rate = gr.Textbox(
label='Learning rate', value=learning_rate_value
)
lr_scheduler = gr.Dropdown(
label='LR Scheduler',
choices=[
@@ -442,7 +461,9 @@ def gradio_training(learning_rate_value='1e-6', lr_scheduler_value='constant', l
],
value=lr_scheduler_value,
)
lr_warmup = gr.Textbox(label='LR warmup (% of steps)', value=lr_warmup_value)
lr_warmup = gr.Textbox(
label='LR warmup (% of steps)', value=lr_warmup_value
)
cache_latents = gr.Checkbox(label='Cache latent', value=True)
return (
learning_rate,
@@ -459,50 +480,38 @@ def gradio_training(learning_rate_value='1e-6', lr_scheduler_value='constant', l
cache_latents,
)

def run_cmd_training(**kwargs):
options = [
f' --learning_rate="{kwargs.get("learning_rate", "")}"'
if kwargs.get('learning_rate')
else '',

f' --lr_scheduler="{kwargs.get("lr_scheduler", "")}"'
if kwargs.get('lr_scheduler')
else '',

f' --lr_warmup_steps="{kwargs.get("lr_warmup_steps", "")}"'
if kwargs.get('lr_warmup_steps')
else '',

f' --train_batch_size="{kwargs.get("train_batch_size", "")}"'
if kwargs.get('train_batch_size')
else '',

f' --max_train_steps="{kwargs.get("max_train_steps", "")}"'
if kwargs.get('max_train_steps')
else '',

f' --save_every_n_epochs="{kwargs.get("save_every_n_epochs", "")}"'
if kwargs.get('save_every_n_epochs')
else '',

f' --mixed_precision="{kwargs.get("mixed_precision", "")}"'
if kwargs.get('mixed_precision')
else '',

f' --save_precision="{kwargs.get("save_precision", "")}"'
if kwargs.get('save_precision')
else '',

f' --seed="{kwargs.get("seed", "")}"'
if kwargs.get('seed')
else '',

f' --seed="{kwargs.get("seed", "")}"' if kwargs.get('seed') else '',
f' --caption_extension="{kwargs.get("caption_extension", "")}"'
if kwargs.get('caption_extension')
else '',

' --cache_latents' if kwargs.get('cache_latents') else '',

]
run_cmd = ''.join(options)
return run_cmd
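run_cmd_training() assembles the training command by joining conditional f-string fragments, emitting a flag only when the corresponding keyword argument is truthy. An illustrative call and the kind of string it returns (values are sample inputs, not defaults):

# Illustrative only: sample inputs for run_cmd_training() as defined above.
cmd = run_cmd_training(
    learning_rate='1e-6',
    lr_scheduler='constant',
    train_batch_size=1,
    mixed_precision='fp16',
    cache_latents=True,
)
# cmd == ' --learning_rate="1e-6" --lr_scheduler="constant"'
#        ' --train_batch_size="1" --mixed_precision="fp16" --cache_latents'
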
@@ -532,9 +541,7 @@ def gradio_advanced_training():
gradient_checkpointing = gr.Checkbox(
label='Gradient checkpointing', value=False
)
shuffle_caption = gr.Checkbox(
label='Shuffle caption', value=False
)
shuffle_caption = gr.Checkbox(label='Shuffle caption', value=False)
persistent_data_loader_workers = gr.Checkbox(
label='Persistent data loader', value=False
)
@@ -544,10 +551,18 @@ def gradio_advanced_training():
with gr.Row():
use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=True)
xformers = gr.Checkbox(label='Use xformers', value=True)
color_aug = gr.Checkbox(
label='Color augmentation', value=False
)
color_aug = gr.Checkbox(label='Color augmentation', value=False)
flip_aug = gr.Checkbox(label='Flip augmentation', value=False)
with gr.Row():
bucket_no_upscale = gr.Checkbox(
label="Don't upscale bucket resolution", value=True
)
random_crop = gr.Checkbox(
label='Random crop instead of center crop', value=False
)
bucket_reso_steps = gr.Number(
label='Bucket resolution steps', value=64
)
with gr.Row():
save_state = gr.Checkbox(label='Save training state', value=False)
resume = gr.Textbox(
@@ -581,55 +596,53 @@ def gradio_advanced_training():
max_data_loader_n_workers,
keep_tokens,
persistent_data_loader_workers,
bucket_no_upscale,
random_crop,
bucket_reso_steps,
)

def run_cmd_advanced_training(**kwargs):
options = [
f' --max_train_epochs="{kwargs.get("max_train_epochs", "")}"'
if kwargs.get('max_train_epochs')
else '',

f' --max_data_loader_n_workers="{kwargs.get("max_data_loader_n_workers", "")}"'
if kwargs.get('max_data_loader_n_workers')
else '',

f' --max_token_length={kwargs.get("max_token_length", "")}'
if int(kwargs.get('max_token_length', 75)) > 75
else '',

f' --clip_skip={kwargs.get("clip_skip", "")}'
if int(kwargs.get('clip_skip', 1)) > 1
else '',

f' --resume="{kwargs.get("resume", "")}"'
if kwargs.get('resume')
else '',

f' --keep_tokens="{kwargs.get("keep_tokens", "")}"'
if int(kwargs.get('keep_tokens', 0)) > 0
else '',

f' --bucket_reso_steps={int(kwargs.get("bucket_reso_steps", 1))}'
if int(kwargs.get('bucket_reso_steps', 64)) >= 1
else '',

' --save_state' if kwargs.get('save_state') else '',

' --mem_eff_attn' if kwargs.get('mem_eff_attn') else '',

' --color_aug' if kwargs.get('color_aug') else '',

' --flip_aug' if kwargs.get('flip_aug') else '',

' --shuffle_caption' if kwargs.get('shuffle_caption') else '',

' --gradient_checkpointing' if kwargs.get('gradient_checkpointing') else '',

' --gradient_checkpointing'
if kwargs.get('gradient_checkpointing')
else '',
' --full_fp16' if kwargs.get('full_fp16') else '',

' --xformers' if kwargs.get('xformers') else '',

' --use_8bit_adam' if kwargs.get('use_8bit_adam') else '',

' --persistent_data_loader_workers' if kwargs.get('persistent_data_loader_workers') else '',

' --persistent_data_loader_workers'
if kwargs.get('persistent_data_loader_workers')
else '',
' --bucket_no_upscale' if kwargs.get('bucket_no_upscale') else '',
' --random_crop' if kwargs.get('random_crop') else '',
]
run_cmd = ''.join(options)
return run_cmd
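The three new GUI components feed run_cmd_advanced_training(), which appends --bucket_reso_steps=N, --bucket_no_upscale and --random_crop to the command passed to the kohya_ss training scripts. An illustrative call showing how the new options come out of the helper defined above:

# Illustrative only: the new bucket options as emitted by run_cmd_advanced_training().
cmd = run_cmd_advanced_training(
    bucket_no_upscale=True,
    random_crop=False,
    bucket_reso_steps=64,
    xformers=True,
    use_8bit_adam=True,
)
# cmd == ' --bucket_reso_steps=64 --xformers --use_8bit_adam --bucket_no_upscale'
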
@@ -191,9 +191,7 @@ def gradio_dreambooth_folder_creation_tab(
util_training_dir_output,
],
)
button_copy_info_to_Folders_tab = gr.Button(
'Copy info to Folders Tab'
)
button_copy_info_to_Folders_tab = gr.Button('Copy info to Folders Tab')
button_copy_info_to_Folders_tab.click(
copy_info_to_Folders_tab,
inputs=[util_training_dir_output],

@@ -2,7 +2,11 @@ import gradio as gr
from easygui import msgbox
import subprocess
import os
from .common_gui import get_saveasfilename_path, get_any_file_path, get_file_path
from .common_gui import (
get_saveasfilename_path,
get_any_file_path,
get_file_path,
)

folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
@@ -11,13 +15,18 @@ document_symbol = '\U0001F4C4' # 📄

def extract_lora(
model_tuned, model_org, save_to, save_precision, dim, v2,
model_tuned,
model_org,
save_to,
save_precision,
dim,
v2,
):
# Check for caption_text_input
if model_tuned == '':
msgbox('Invalid finetuned model file')
return

if model_org == '':
msgbox('Invalid base model file')
return
@@ -26,12 +35,14 @@ def extract_lora(
if not os.path.isfile(model_tuned):
msgbox('The provided finetuned model is not a file')
return

if not os.path.isfile(model_org):
msgbox('The provided base model is not a file')
return

run_cmd = f'.\\venv\Scripts\python.exe "networks\extract_lora_from_models.py"'
run_cmd = (
f'.\\venv\Scripts\python.exe "networks\extract_lora_from_models.py"'
)
run_cmd += f' --save_precision {save_precision}'
run_cmd += f' --save_to "{save_to}"'
run_cmd += f' --model_org "{model_org}"'
@@ -60,7 +71,7 @@ def gradio_extract_lora_tab():
lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
model_ext = gr.Textbox(value='*.ckpt *.safetensors', visible=False)
model_ext_name = gr.Textbox(value='Model types', visible=False)

with gr.Row():
model_tuned = gr.Textbox(
label='Finetuned model',
@@ -75,7 +86,7 @@ def gradio_extract_lora_tab():
inputs=[model_tuned, model_ext, model_ext_name],
outputs=model_tuned,
)

model_org = gr.Textbox(
label='Stable Diffusion base model',
placeholder='Stable Diffusion original model: ckpt or safetensors file',
@@ -99,7 +110,9 @@ def gradio_extract_lora_tab():
folder_symbol, elem_id='open_folder_small'
)
button_save_to.click(
get_saveasfilename_path, inputs=[save_to, lora_ext, lora_ext_name], outputs=save_to
get_saveasfilename_path,
inputs=[save_to, lora_ext, lora_ext_name],
outputs=save_to,
)
save_precision = gr.Dropdown(
label='Save precison',
@@ -122,6 +135,5 @@ def gradio_extract_lora_tab():

extract_button.click(
extract_lora,
inputs=[model_tuned, model_org, save_to, save_precision, dim, v2
],
inputs=[model_tuned, model_org, save_to, save_precision, dim, v2],
)
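The command string '.\\venv\Scripts\python.exe "networks\extract_lora_from_models.py"' only works because \S, \p and \e are not valid Python escape sequences, so those backslashes pass through literally (newer Python versions warn about such invalid escapes). A raw string makes the intent explicit; this is an equivalent sketch, not the committed code:

# Sketch: a raw string keeps the Windows-style backslashes unambiguous.
script = r'networks\extract_lora_from_models.py'
run_cmd = rf'.\venv\Scripts\python.exe "{script}"'
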
@@ -15,11 +15,11 @@ def caption_images(
prefix,
postfix,
):
# Check for images_dir_input
# Check for images_dir_input
if train_data_dir == '':
msgbox('Image folder is missing...')
return

if caption_ext == '':
msgbox('Please provide an extension for the caption files.')
return
@@ -29,7 +29,9 @@ def caption_images(
if not model_id == '':
run_cmd += f' --model_id="{model_id}"'
run_cmd += f' --batch_size="{int(batch_size)}"'
run_cmd += f' --max_data_loader_n_workers="{int(max_data_loader_n_workers)}"'
run_cmd += (
f' --max_data_loader_n_workers="{int(max_data_loader_n_workers)}"'
)
run_cmd += f' --max_length="{int(max_length)}"'
if caption_ext != '':
run_cmd += f' --caption_extension="{caption_ext}"'
@@ -105,8 +107,9 @@ def gradio_git_caption_gui_tab():
value=75, label='Max length', interactive=True
)
model_id = gr.Textbox(
label="Model",
placeholder="(Optional) model id for GIT in Hugging Face", interactive=True
label='Model',
placeholder='(Optional) model id for GIT in Hugging Face',
interactive=True,
)

caption_button = gr.Button('Caption images')

@@ -2,7 +2,11 @@ import gradio as gr
from easygui import msgbox
import subprocess
import os
from .common_gui import get_saveasfilename_path, get_any_file_path, get_file_path
from .common_gui import (
get_saveasfilename_path,
get_any_file_path,
get_file_path,
)

folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
@@ -11,13 +15,18 @@ document_symbol = '\U0001F4C4' # 📄

def merge_lora(
lora_a_model, lora_b_model, ratio, save_to, precision, save_precision,
lora_a_model,
lora_b_model,
ratio,
save_to,
precision,
save_precision,
):
# Check for caption_text_input
if lora_a_model == '':
msgbox('Invalid model A file')
return

if lora_b_model == '':
msgbox('Invalid model B file')
return
@@ -26,7 +35,7 @@ def merge_lora(
if not os.path.isfile(lora_a_model):
msgbox('The provided model A is not a file')
return

if not os.path.isfile(lora_b_model):
msgbox('The provided model B is not a file')
return
@@ -54,13 +63,11 @@ def merge_lora(

def gradio_merge_lora_tab():
with gr.Tab('Merge LoRA'):
gr.Markdown(
'This utility can merge two LoRA networks together.'
)

gr.Markdown('This utility can merge two LoRA networks together.')

lora_ext = gr.Textbox(value='*.pt *.safetensors', visible=False)
lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)

with gr.Row():
lora_a_model = gr.Textbox(
label='LoRA model "A"',
@@ -75,7 +82,7 @@ def gradio_merge_lora_tab():
inputs=[lora_a_model, lora_ext, lora_ext_name],
outputs=lora_a_model,
)

lora_b_model = gr.Textbox(
label='LoRA model "B"',
placeholder='Path to the LoRA B model',
@@ -90,9 +97,15 @@ def gradio_merge_lora_tab():
outputs=lora_b_model,
)
with gr.Row():
ratio = gr.Slider(label="Merge ratio (eg: 0.7 mean 70% of model A and 30% of model B", minimum=0, maximum=1, step=0.01, value=0.5,
interactive=True,)

ratio = gr.Slider(
label='Merge ratio (eg: 0.7 mean 70% of model A and 30% of model B',
minimum=0,
maximum=1,
step=0.01,
value=0.5,
interactive=True,
)

with gr.Row():
save_to = gr.Textbox(
label='Save to',
@@ -103,7 +116,9 @@ def gradio_merge_lora_tab():
folder_symbol, elem_id='open_folder_small'
)
button_save_to.click(
get_saveasfilename_path, inputs=[save_to, lora_ext, lora_ext_name], outputs=save_to
get_saveasfilename_path,
inputs=[save_to, lora_ext, lora_ext_name],
outputs=save_to,
)
precision = gr.Dropdown(
label='Merge precison',
@@ -122,6 +137,12 @@ def gradio_merge_lora_tab():

convert_button.click(
merge_lora,
inputs=[lora_a_model, lora_b_model, ratio, save_to, precision, save_precision,
inputs=[
lora_a_model,
lora_b_model,
ratio,
save_to,
precision,
save_precision,
],
)

@@ -11,7 +11,11 @@ document_symbol = '\U0001F4C4' # 📄

def resize_lora(
model, new_rank, save_to, save_precision, device,
model,
new_rank,
save_to,
save_precision,
device,
):
# Check for caption_text_input
if model == '':
@@ -22,7 +26,7 @@ def resize_lora(
if not os.path.isfile(model):
msgbox('The provided model is not a file')
return

if device == '':
device = 'cuda'

@@ -46,13 +50,11 @@ def resize_lora(

def gradio_resize_lora_tab():
with gr.Tab('Resize LoRA'):
gr.Markdown(
'This utility can resize a LoRA.'
)

gr.Markdown('This utility can resize a LoRA.')

lora_ext = gr.Textbox(value='*.pt *.safetensors', visible=False)
lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)

with gr.Row():
model = gr.Textbox(
label='Source LoRA',
@@ -68,9 +70,15 @@ def gradio_resize_lora_tab():
outputs=model,
)
with gr.Row():
new_rank = gr.Slider(label="Desired LoRA rank", minimum=1, maximum=1024, step=1, value=4,
interactive=True,)

new_rank = gr.Slider(
label='Desired LoRA rank',
minimum=1,
maximum=1024,
step=1,
value=4,
interactive=True,
)

with gr.Row():
save_to = gr.Textbox(
label='Save to',
@@ -81,7 +89,9 @@ def gradio_resize_lora_tab():
folder_symbol, elem_id='open_folder_small'
)
button_save_to.click(
get_saveasfilename_path, inputs=[save_to, lora_ext, lora_ext_name], outputs=save_to
get_saveasfilename_path,
inputs=[save_to, lora_ext, lora_ext_name],
outputs=save_to,
)
save_precision = gr.Dropdown(
label='Save precison',
@@ -99,6 +109,11 @@ def gradio_resize_lora_tab():

convert_button.click(
resize_lora,
inputs=[model, new_rank, save_to, save_precision, device,
inputs=[
model,
new_rank,
save_to,
save_precision,
device,
],
)

@@ -2,7 +2,11 @@ import gradio as gr
from easygui import msgbox
import subprocess
import os
from .common_gui import get_saveasfilename_path, get_any_file_path, get_file_path
from .common_gui import (
get_saveasfilename_path,
get_any_file_path,
get_file_path,
)

folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
@@ -30,9 +34,11 @@ def verify_lora(

# Run the command
subprocess.run(run_cmd)
process = subprocess.Popen(run_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
process = subprocess.Popen(
run_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
output, error = process.communicate()

return (output.decode(), error.decode())

@@ -46,10 +52,10 @@ def gradio_verify_lora_tab():
gr.Markdown(
'This utility can verify a LoRA network to make sure it is properly trained.'
)

lora_ext = gr.Textbox(value='*.pt *.safetensors', visible=False)
lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)

with gr.Row():
lora_model = gr.Textbox(
label='LoRA model',
@@ -64,7 +70,7 @@ def gradio_verify_lora_tab():
inputs=[lora_model, lora_ext, lora_ext_name],
outputs=lora_model,
)
verify_button = gr.Button('Verify', variant="primary")
verify_button = gr.Button('Verify', variant='primary')

lora_model_verif_output = gr.Textbox(
label='Output',
@@ -73,7 +79,7 @@ def gradio_verify_lora_tab():
lines=1,
max_lines=10,
)

lora_model_verif_error = gr.Textbox(
label='Error',
placeholder='Verification error',
@@ -87,5 +93,5 @@ def gradio_verify_lora_tab():
inputs=[
lora_model,
],
outputs=[lora_model_verif_output, lora_model_verif_error]
outputs=[lora_model_verif_output, lora_model_verif_error],
)
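verify_lora() switches from subprocess.run() to subprocess.Popen() with piped stdout/stderr so the tab can show the verification tool's output and errors in its two textboxes; communicate() waits for the process and returns both streams, which are then decoded. A behaviour-equivalent sketch using subprocess.run() with output capture (an alternative, not the committed code; the command here is a placeholder):

import subprocess

def run_and_capture(cmd):
    # cmd is the command to execute, e.g. a list of program + arguments.
    completed = subprocess.run(cmd, capture_output=True, text=True)
    return completed.stdout, completed.stderr
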
@@ -14,7 +14,7 @@ def caption_images(train_data_dir, caption_extension, batch_size, thresh):
if train_data_dir == '':
msgbox('Image folder is missing...')
return

if caption_extension == '':
msgbox('Please provide an extension for the caption files.')
return

lora_gui.py (78 changed lines)
@@ -91,9 +91,14 @@ def save_configuration(
max_train_epochs,
max_data_loader_n_workers,
network_alpha,
training_comment, keep_tokens,
lr_scheduler_num_cycles, lr_scheduler_power,
training_comment,
keep_tokens,
lr_scheduler_num_cycles,
lr_scheduler_power,
persistent_data_loader_workers,
bucket_no_upscale,
random_crop,
bucket_reso_steps,
):
# Get list of function parameters and values
parameters = list(locals().items())
@@ -182,9 +187,14 @@ def open_configuration(
max_train_epochs,
max_data_loader_n_workers,
network_alpha,
training_comment, keep_tokens,
lr_scheduler_num_cycles, lr_scheduler_power,
training_comment,
keep_tokens,
lr_scheduler_num_cycles,
lr_scheduler_power,
persistent_data_loader_workers,
bucket_no_upscale,
random_crop,
bucket_reso_steps,
):
# Get list of function parameters and values
parameters = list(locals().items())
@@ -257,9 +267,14 @@ def train_model(
max_train_epochs,
max_data_loader_n_workers,
network_alpha,
training_comment, keep_tokens,
lr_scheduler_num_cycles, lr_scheduler_power,
training_comment,
keep_tokens,
lr_scheduler_num_cycles,
lr_scheduler_power,
persistent_data_loader_workers,
bucket_no_upscale,
random_crop,
bucket_reso_steps,
):
if pretrained_model_name_or_path == '':
msgbox('Source model information is missing')
@@ -281,12 +296,18 @@ def train_model(
if output_dir == '':
msgbox('Output folder path is missing')
return

if int(bucket_reso_steps) < 1:
msgbox('Bucket resolution steps need to be greater than 0')
return

if not os.path.exists(output_dir):
os.makedirs(output_dir)

if stop_text_encoder_training_pct > 0:
msgbox('Output "stop text encoder training" is not yet supported. Ignoring')
msgbox(
'Output "stop text encoder training" is not yet supported. Ignoring'
)
stop_text_encoder_training_pct = 0

# If string is empty set string to 0.
@@ -358,9 +379,9 @@ def train_model(
print(f'lr_warmup_steps = {lr_warmup_steps}')

run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "train_network.py"'

run_cmd += f' --bucket_reso_steps=1 --bucket_no_upscale' # --random_crop'

run_cmd += f' --bucket_reso_steps=1 --bucket_no_upscale' # --random_crop'

if v2:
run_cmd += ' --v2'
if v_parameterization:
@@ -390,7 +411,7 @@ def train_model(
if not float(prior_loss_weight) == 1.0:
run_cmd += f' --prior_loss_weight={prior_loss_weight}'
run_cmd += f' --network_module=networks.lora'

if not (float(text_encoder_lr) == 0) or not (float(unet_lr) == 0):
if not (float(text_encoder_lr) == 0) and not (float(unet_lr) == 0):
run_cmd += f' --text_encoder_lr={text_encoder_lr}'
@@ -402,14 +423,12 @@ def train_model(
run_cmd += f' --unet_lr={unet_lr}'
run_cmd += f' --network_train_unet_only'
else:
if float(text_encoder_lr) == 0:
msgbox(
'Please input learning rate values.'
)
if float(text_encoder_lr) == 0:
msgbox('Please input learning rate values.')
return

run_cmd += f' --network_dim={network_dim}'

if not lora_network_weights == '':
run_cmd += f' --network_weights="{lora_network_weights}"'
if int(gradient_accumulation_steps) > 1:
@@ -454,6 +473,9 @@ def train_model(
use_8bit_adam=use_8bit_adam,
keep_tokens=keep_tokens,
persistent_data_loader_workers=persistent_data_loader_workers,
bucket_no_upscale=bucket_no_upscale,
random_crop=random_crop,
bucket_reso_steps=bucket_reso_steps,
)

print(run_cmd)
@@ -675,11 +697,13 @@ def lora_tab(
label='Prior loss weight', value=1.0
)
lr_scheduler_num_cycles = gr.Textbox(
label='LR number of cycles', placeholder='(Optional) For Cosine with restart and polynomial only'
label='LR number of cycles',
placeholder='(Optional) For Cosine with restart and polynomial only',
)

lr_scheduler_power = gr.Textbox(
label='LR power', placeholder='(Optional) For Cosine with restart and polynomial only'
label='LR power',
placeholder='(Optional) For Cosine with restart and polynomial only',
)
(
use_8bit_adam,
@@ -698,6 +722,9 @@ def lora_tab(
max_data_loader_n_workers,
keep_tokens,
persistent_data_loader_workers,
bucket_no_upscale,
random_crop,
bucket_reso_steps,
) = gradio_advanced_training()
color_aug.change(
color_aug_changed,
@@ -719,7 +746,6 @@ def lora_tab(
gradio_merge_lora_tab()
gradio_resize_lora_tab()
gradio_verify_lora_tab()

button_run = gr.Button('Train model')

@@ -773,8 +799,12 @@ def lora_tab(
network_alpha,
training_comment,
keep_tokens,
lr_scheduler_num_cycles, lr_scheduler_power,
lr_scheduler_num_cycles,
lr_scheduler_power,
persistent_data_loader_workers,
bucket_no_upscale,
random_crop,
bucket_reso_steps,
]

button_open_config.click(

@@ -82,8 +82,18 @@ def save_configuration(
max_data_loader_n_workers,
mem_eff_attn,
gradient_accumulation_steps,
model_list, token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, keep_tokens,
model_list,
token_string,
init_word,
num_vectors_per_token,
max_train_steps,
weights,
template,
keep_tokens,
persistent_data_loader_workers,
bucket_no_upscale,
random_crop,
bucket_reso_steps,
):
# Get list of function parameters and values
parameters = list(locals().items())
@@ -171,8 +181,18 @@ def open_configuration(
max_data_loader_n_workers,
mem_eff_attn,
gradient_accumulation_steps,
model_list, token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, keep_tokens,
model_list,
token_string,
init_word,
num_vectors_per_token,
max_train_steps,
weights,
template,
keep_tokens,
persistent_data_loader_workers,
bucket_no_upscale,
random_crop,
bucket_reso_steps,
):
# Get list of function parameters and values
parameters = list(locals().items())
@@ -241,8 +261,17 @@ def train_model(
mem_eff_attn,
gradient_accumulation_steps,
model_list, # Keep this. Yes, it is unused here but required given the common list used
token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, keep_tokens,
token_string,
init_word,
num_vectors_per_token,
max_train_steps,
weights,
template,
keep_tokens,
persistent_data_loader_workers,
bucket_no_upscale,
random_crop,
bucket_reso_steps,
):
if pretrained_model_name_or_path == '':
msgbox('Source model information is missing')
@@ -264,15 +293,15 @@ def train_model(
if output_dir == '':
msgbox('Output folder path is missing')
return

if token_string == '':
msgbox('Token string is missing')
return

if init_word == '':
msgbox('Init word is missing')
return

if not os.path.exists(output_dir):
os.makedirs(output_dir)

@@ -332,7 +361,7 @@ def train_model(
)
else:
max_train_steps = int(max_train_steps)

print(f'max_train_steps = {max_train_steps}')

# calculate stop encoder training
@@ -421,6 +450,9 @@ def train_model(
use_8bit_adam=use_8bit_adam,
keep_tokens=keep_tokens,
persistent_data_loader_workers=persistent_data_loader_workers,
bucket_no_upscale=bucket_no_upscale,
random_crop=random_crop,
bucket_reso_steps=bucket_reso_steps,
)
run_cmd += f' --token_string="{token_string}"'
run_cmd += f' --init_word="{init_word}"'
@@ -431,7 +463,7 @@ def train_model(
run_cmd += f' --use_object_template'
elif template == 'style template':
run_cmd += f' --use_style_template'

print(run_cmd)
# Run the command
subprocess.run(run_cmd)
@@ -576,9 +608,7 @@ def ti_tab(
label='Resume TI training',
placeholder='(Optional) Path to existing TI embeding file to keep training',
)
weights_file_input = gr.Button(
'📂', elem_id='open_folder_small'
)
weights_file_input = gr.Button('📂', elem_id='open_folder_small')
weights_file_input.click(get_file_path, outputs=weights)
with gr.Row():
token_string = gr.Textbox(
@@ -676,6 +706,9 @@ def ti_tab(
max_data_loader_n_workers,
keep_tokens,
persistent_data_loader_workers,
bucket_no_upscale,
random_crop,
bucket_reso_steps,
) = gradio_advanced_training()
color_aug.change(
color_aug_changed,
@@ -739,9 +772,17 @@ def ti_tab(
mem_eff_attn,
gradient_accumulation_steps,
model_list,
token_string, init_word, num_vectors_per_token, max_train_steps, weights, template,
token_string,
init_word,
num_vectors_per_token,
max_train_steps,
weights,
template,
keep_tokens,
persistent_data_loader_workers,
bucket_no_upscale,
random_crop,
bucket_reso_steps,
]

button_open_config.click(

@@ -1,66 +0,0 @@
import os
import cv2
import argparse
import shutil
import math

def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2):
# Calculate max_pixels from max_resolution string
max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1])

# Create destination folder if it does not exist
if not os.path.exists(dst_img_folder):
os.makedirs(dst_img_folder)

# Iterate through all files in src_img_folder
for filename in os.listdir(src_img_folder):
# Check if the image is png, jpg or webp
if not filename.endswith(('.png', '.jpg', '.webp')):
# Copy the file to the destination folder if not png, jpg or webp
shutil.copy(os.path.join(src_img_folder, filename), os.path.join(dst_img_folder, filename))
continue

# Load image
img = cv2.imread(os.path.join(src_img_folder, filename))

# Calculate current number of pixels
current_pixels = img.shape[0] * img.shape[1]

# Check if the image needs resizing
if current_pixels > max_pixels:
# Calculate scaling factor
scale_factor = max_pixels / current_pixels

# Calculate new dimensions
new_height = int(img.shape[0] * math.sqrt(scale_factor))
new_width = int(img.shape[1] * math.sqrt(scale_factor))

# Resize image
img = cv2.resize(img, (new_width, new_height))

# Calculate the new height and width that are divisible by divisible_by
new_height = new_height if new_height % divisible_by == 0 else new_height - new_height % divisible_by
new_width = new_width if new_width % divisible_by == 0 else new_width - new_width % divisible_by

# Center crop the image to the calculated dimensions
y = int((img.shape[0] - new_height) / 2)
x = int((img.shape[1] - new_width) / 2)
img = img[y:y + new_height, x:x + new_width]

# Save resized image in dst_img_folder
cv2.imwrite(os.path.join(dst_img_folder, filename), img, [cv2.IMWRITE_JPEG_QUALITY, 100])

print(f"Resized image: {filename} with size {img.shape[0]}x{img.shape[1]}")

def main():
parser = argparse.ArgumentParser(description='Resize images in a folder to a specified max resolution')
parser.add_argument('src_img_folder', type=str, help='Source folder containing the images')
parser.add_argument('dst_img_folder', type=str, help='Destination folder to save the resized images')
parser.add_argument('--max_resolution', type=str, help='Maximum resolution in the format "512x512"', default="512x512")
parser.add_argument('--divisible_by', type=int, help='Ensure new dimensions are divisible by this value', default=2)
args = parser.parse_args()
resize_images(args.src_img_folder, args.dst_img_folder, args.max_resolution)

if __name__ == '__main__':
main()

tools/resize_images_to_resolutions.py (new file, 76 lines)
@@ -0,0 +1,76 @@
import os
import cv2
import argparse
import shutil
import math

def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2):
# Split the max_resolution string by "," and strip any whitespaces
max_resolutions = [res.strip() for res in max_resolution.split(',')]

# # Calculate max_pixels from max_resolution string
# max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1])

# Create destination folder if it does not exist
if not os.path.exists(dst_img_folder):
os.makedirs(dst_img_folder)

# Iterate through all files in src_img_folder
for filename in os.listdir(src_img_folder):
# Check if the image is png, jpg or webp
if not filename.endswith(('.png', '.jpg', '.webp')):
# Copy the file to the destination folder if not png, jpg or webp
shutil.copy(os.path.join(src_img_folder, filename), os.path.join(dst_img_folder, filename))
continue

# Load image
img = cv2.imread(os.path.join(src_img_folder, filename))

for max_resolution in max_resolutions:
# Calculate max_pixels from max_resolution string
max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1])

# Calculate current number of pixels
current_pixels = img.shape[0] * img.shape[1]

# Check if the image needs resizing
if current_pixels > max_pixels:
# Calculate scaling factor
scale_factor = max_pixels / current_pixels

# Calculate new dimensions
new_height = int(img.shape[0] * math.sqrt(scale_factor))
new_width = int(img.shape[1] * math.sqrt(scale_factor))

# Resize image
img = cv2.resize(img, (new_width, new_height))

# Calculate the new height and width that are divisible by divisible_by
new_height = new_height if new_height % divisible_by == 0 else new_height - new_height % divisible_by
new_width = new_width if new_width % divisible_by == 0 else new_width - new_width % divisible_by

# Center crop the image to the calculated dimensions
y = int((img.shape[0] - new_height) / 2)
x = int((img.shape[1] - new_width) / 2)
img = img[y:y + new_height, x:x + new_width]

# Split filename into base and extension
base, ext = os.path.splitext(filename)
new_filename = base + '+' + max_resolution + '.jpg'

# Save resized image in dst_img_folder
cv2.imwrite(os.path.join(dst_img_folder, new_filename), img, [cv2.IMWRITE_JPEG_QUALITY, 100])
print(f"Resized image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}")

def main():
parser = argparse.ArgumentParser(description='Resize images in a folder to a specified max resolution(s)')
parser.add_argument('src_img_folder', type=str, help='Source folder containing the images')
parser.add_argument('dst_img_folder', type=str, help='Destination folder to save the resized images')
parser.add_argument('--max_resolution', type=str, help='Maximum resolution(s) in the format "512x512,384x384, etc, etc"', default="512x512,384x384,256x256,128x128")
parser.add_argument('--divisible_by', type=int, help='Ensure new dimensions are divisible by this value', default=1)
args = parser.parse_args()
resize_images(args.src_img_folder, args.dst_img_folder, args.max_resolution)

if __name__ == '__main__':
main()
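The replacement script accepts a comma-separated list of target resolutions and writes one JPEG per resolution, appending '+<resolution>' to each output filename. An illustrative call (folder paths are placeholders, and the import assumes the repository root is on sys.path); note that main() as committed does not forward --divisible_by to resize_images(), so the function's default of 2 applies regardless of the flag unless the function is called directly:

from tools.resize_images_to_resolutions import resize_images

# Placeholders for the source and destination folders.
resize_images('./train_src', './train_resized', max_resolution='512x512,384x384', divisible_by=2)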