Integrate new bucket parameters in GUI
This commit is contained in:
parent
2486af9903
commit
cbfc311687
@ -82,8 +82,12 @@ def save_configuration(
|
|||||||
max_data_loader_n_workers,
|
max_data_loader_n_workers,
|
||||||
mem_eff_attn,
|
mem_eff_attn,
|
||||||
gradient_accumulation_steps,
|
gradient_accumulation_steps,
|
||||||
model_list, keep_tokens,
|
model_list,
|
||||||
|
keep_tokens,
|
||||||
persistent_data_loader_workers,
|
persistent_data_loader_workers,
|
||||||
|
bucket_no_upscale,
|
||||||
|
random_crop,
|
||||||
|
bucket_reso_steps,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -167,8 +171,12 @@ def open_configuration(
|
|||||||
max_data_loader_n_workers,
|
max_data_loader_n_workers,
|
||||||
mem_eff_attn,
|
mem_eff_attn,
|
||||||
gradient_accumulation_steps,
|
gradient_accumulation_steps,
|
||||||
model_list, keep_tokens,
|
model_list,
|
||||||
|
keep_tokens,
|
||||||
persistent_data_loader_workers,
|
persistent_data_loader_workers,
|
||||||
|
bucket_no_upscale,
|
||||||
|
random_crop,
|
||||||
|
bucket_reso_steps,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -239,6 +247,9 @@ def train_model(
|
|||||||
model_list, # Keep this. Yes, it is unused here but required given the common list used
|
model_list, # Keep this. Yes, it is unused here but required given the common list used
|
||||||
keep_tokens,
|
keep_tokens,
|
||||||
persistent_data_loader_workers,
|
persistent_data_loader_workers,
|
||||||
|
bucket_no_upscale,
|
||||||
|
random_crop,
|
||||||
|
bucket_reso_steps,
|
||||||
):
|
):
|
||||||
if pretrained_model_name_or_path == '':
|
if pretrained_model_name_or_path == '':
|
||||||
msgbox('Source model information is missing')
|
msgbox('Source model information is missing')
|
||||||
@ -402,6 +413,9 @@ def train_model(
|
|||||||
use_8bit_adam=use_8bit_adam,
|
use_8bit_adam=use_8bit_adam,
|
||||||
keep_tokens=keep_tokens,
|
keep_tokens=keep_tokens,
|
||||||
persistent_data_loader_workers=persistent_data_loader_workers,
|
persistent_data_loader_workers=persistent_data_loader_workers,
|
||||||
|
bucket_no_upscale=bucket_no_upscale,
|
||||||
|
random_crop=random_crop,
|
||||||
|
bucket_reso_steps=bucket_reso_steps,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(run_cmd)
|
print(run_cmd)
|
||||||
@ -610,6 +624,9 @@ def dreambooth_tab(
|
|||||||
max_data_loader_n_workers,
|
max_data_loader_n_workers,
|
||||||
keep_tokens,
|
keep_tokens,
|
||||||
persistent_data_loader_workers,
|
persistent_data_loader_workers,
|
||||||
|
bucket_no_upscale,
|
||||||
|
random_crop,
|
||||||
|
bucket_reso_steps,
|
||||||
) = gradio_advanced_training()
|
) = gradio_advanced_training()
|
||||||
color_aug.change(
|
color_aug.change(
|
||||||
color_aug_changed,
|
color_aug_changed,
|
||||||
@ -675,6 +692,9 @@ def dreambooth_tab(
|
|||||||
model_list,
|
model_list,
|
||||||
keep_tokens,
|
keep_tokens,
|
||||||
persistent_data_loader_workers,
|
persistent_data_loader_workers,
|
||||||
|
bucket_no_upscale,
|
||||||
|
random_crop,
|
||||||
|
bucket_reso_steps,
|
||||||
]
|
]
|
||||||
|
|
||||||
button_open_config.click(
|
button_open_config.click(
|
||||||
|
@ -78,8 +78,12 @@ def save_configuration(
|
|||||||
color_aug,
|
color_aug,
|
||||||
model_list,
|
model_list,
|
||||||
cache_latents,
|
cache_latents,
|
||||||
use_latent_files, keep_tokens,
|
use_latent_files,
|
||||||
|
keep_tokens,
|
||||||
persistent_data_loader_workers,
|
persistent_data_loader_workers,
|
||||||
|
bucket_no_upscale,
|
||||||
|
random_crop,
|
||||||
|
bucket_reso_steps,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -169,8 +173,12 @@ def open_config_file(
|
|||||||
color_aug,
|
color_aug,
|
||||||
model_list,
|
model_list,
|
||||||
cache_latents,
|
cache_latents,
|
||||||
use_latent_files, keep_tokens,
|
use_latent_files,
|
||||||
|
keep_tokens,
|
||||||
persistent_data_loader_workers,
|
persistent_data_loader_workers,
|
||||||
|
bucket_no_upscale,
|
||||||
|
random_crop,
|
||||||
|
bucket_reso_steps,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -245,8 +253,12 @@ def train_model(
|
|||||||
color_aug,
|
color_aug,
|
||||||
model_list, # Keep this. Yes, it is unused here but required given the common list used
|
model_list, # Keep this. Yes, it is unused here but required given the common list used
|
||||||
cache_latents,
|
cache_latents,
|
||||||
use_latent_files, keep_tokens,
|
use_latent_files,
|
||||||
|
keep_tokens,
|
||||||
persistent_data_loader_workers,
|
persistent_data_loader_workers,
|
||||||
|
bucket_no_upscale,
|
||||||
|
random_crop,
|
||||||
|
bucket_reso_steps,
|
||||||
):
|
):
|
||||||
# create caption json file
|
# create caption json file
|
||||||
if generate_caption_database:
|
if generate_caption_database:
|
||||||
@ -295,7 +307,11 @@ def train_model(
|
|||||||
subprocess.run(run_cmd)
|
subprocess.run(run_cmd)
|
||||||
|
|
||||||
image_num = len(
|
image_num = len(
|
||||||
[f for f in os.listdir(image_folder) if f.endswith('.jpg') or f.endswith('.png') or f.endswith('.webp')]
|
[
|
||||||
|
f
|
||||||
|
for f in os.listdir(image_folder)
|
||||||
|
if f.endswith('.jpg') or f.endswith('.png') or f.endswith('.webp')
|
||||||
|
]
|
||||||
)
|
)
|
||||||
print(f'image_num = {image_num}')
|
print(f'image_num = {image_num}')
|
||||||
|
|
||||||
@ -386,6 +402,9 @@ def train_model(
|
|||||||
use_8bit_adam=use_8bit_adam,
|
use_8bit_adam=use_8bit_adam,
|
||||||
keep_tokens=keep_tokens,
|
keep_tokens=keep_tokens,
|
||||||
persistent_data_loader_workers=persistent_data_loader_workers,
|
persistent_data_loader_workers=persistent_data_loader_workers,
|
||||||
|
bucket_no_upscale=bucket_no_upscale,
|
||||||
|
random_crop=random_crop,
|
||||||
|
bucket_reso_steps=bucket_reso_steps,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(run_cmd)
|
print(run_cmd)
|
||||||
@ -592,6 +611,9 @@ def finetune_tab():
|
|||||||
max_data_loader_n_workers,
|
max_data_loader_n_workers,
|
||||||
keep_tokens,
|
keep_tokens,
|
||||||
persistent_data_loader_workers,
|
persistent_data_loader_workers,
|
||||||
|
bucket_no_upscale,
|
||||||
|
random_crop,
|
||||||
|
bucket_reso_steps,
|
||||||
) = gradio_advanced_training()
|
) = gradio_advanced_training()
|
||||||
color_aug.change(
|
color_aug.change(
|
||||||
color_aug_changed,
|
color_aug_changed,
|
||||||
@ -653,6 +675,9 @@ def finetune_tab():
|
|||||||
use_latent_files,
|
use_latent_files,
|
||||||
keep_tokens,
|
keep_tokens,
|
||||||
persistent_data_loader_workers,
|
persistent_data_loader_workers,
|
||||||
|
bucket_no_upscale,
|
||||||
|
random_crop,
|
||||||
|
bucket_reso_steps,
|
||||||
]
|
]
|
||||||
|
|
||||||
button_run.click(train_model, inputs=settings_list)
|
button_run.click(train_model, inputs=settings_list)
|
||||||
|
@ -19,7 +19,7 @@ def UI(username, password):
|
|||||||
print('Load CSS...')
|
print('Load CSS...')
|
||||||
css += file.read() + '\n'
|
css += file.read() + '\n'
|
||||||
|
|
||||||
interface = gr.Blocks(css=css, title="Kohya_ss GUI")
|
interface = gr.Blocks(css=css, title='Kohya_ss GUI')
|
||||||
|
|
||||||
with interface:
|
with interface:
|
||||||
with gr.Tab('Dreambooth'):
|
with gr.Tab('Dreambooth'):
|
||||||
|
@ -10,13 +10,15 @@ def caption_images(
|
|||||||
overwrite_input,
|
overwrite_input,
|
||||||
caption_file_ext,
|
caption_file_ext,
|
||||||
prefix,
|
prefix,
|
||||||
postfix, find, replace
|
postfix,
|
||||||
|
find,
|
||||||
|
replace,
|
||||||
):
|
):
|
||||||
# Check for images_dir_input
|
# Check for images_dir_input
|
||||||
if images_dir_input == '':
|
if images_dir_input == '':
|
||||||
msgbox('Image folder is missing...')
|
msgbox('Image folder is missing...')
|
||||||
return
|
return
|
||||||
|
|
||||||
if caption_file_ext == '':
|
if caption_file_ext == '':
|
||||||
msgbox('Please provide an extension for the caption files.')
|
msgbox('Please provide an extension for the caption files.')
|
||||||
return
|
return
|
||||||
@ -39,7 +41,7 @@ def caption_images(
|
|||||||
subprocess.run(run_cmd)
|
subprocess.run(run_cmd)
|
||||||
|
|
||||||
if overwrite_input:
|
if overwrite_input:
|
||||||
if not prefix=='' or not postfix=='':
|
if not prefix == '' or not postfix == '':
|
||||||
# Add prefix and postfix
|
# Add prefix and postfix
|
||||||
add_pre_postfix(
|
add_pre_postfix(
|
||||||
folder=images_dir_input,
|
folder=images_dir_input,
|
||||||
@ -47,7 +49,7 @@ def caption_images(
|
|||||||
prefix=prefix,
|
prefix=prefix,
|
||||||
postfix=postfix,
|
postfix=postfix,
|
||||||
)
|
)
|
||||||
if not find=='':
|
if not find == '':
|
||||||
find_replace(
|
find_replace(
|
||||||
folder=images_dir_input,
|
folder=images_dir_input,
|
||||||
caption_file_ext=caption_file_ext,
|
caption_file_ext=caption_file_ext,
|
||||||
@ -134,6 +136,7 @@ def gradio_basic_caption_gui_tab():
|
|||||||
caption_file_ext,
|
caption_file_ext,
|
||||||
prefix,
|
prefix,
|
||||||
postfix,
|
postfix,
|
||||||
find, replace
|
find,
|
||||||
|
replace,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -26,7 +26,7 @@ def caption_images(
|
|||||||
if train_data_dir == '':
|
if train_data_dir == '':
|
||||||
msgbox('Image folder is missing...')
|
msgbox('Image folder is missing...')
|
||||||
return
|
return
|
||||||
|
|
||||||
if caption_file_ext == '':
|
if caption_file_ext == '':
|
||||||
msgbox('Please provide an extension for the caption files.')
|
msgbox('Please provide an extension for the caption files.')
|
||||||
return
|
return
|
||||||
|
@ -9,6 +9,7 @@ refresh_symbol = '\U0001f504' # 🔄
|
|||||||
save_style_symbol = '\U0001f4be' # 💾
|
save_style_symbol = '\U0001f4be' # 💾
|
||||||
document_symbol = '\U0001F4C4' # 📄
|
document_symbol = '\U0001F4C4' # 📄
|
||||||
|
|
||||||
|
|
||||||
def get_dir_and_file(file_path):
|
def get_dir_and_file(file_path):
|
||||||
dir_path, file_name = os.path.split(file_path)
|
dir_path, file_name = os.path.split(file_path)
|
||||||
return (dir_path, file_name)
|
return (dir_path, file_name)
|
||||||
@ -200,7 +201,7 @@ def find_replace(folder='', caption_file_ext='.caption', find='', replace=''):
|
|||||||
|
|
||||||
files = [f for f in os.listdir(folder) if f.endswith(caption_file_ext)]
|
files = [f for f in os.listdir(folder) if f.endswith(caption_file_ext)]
|
||||||
for file in files:
|
for file in files:
|
||||||
with open(os.path.join(folder, file), 'r', errors="ignore") as f:
|
with open(os.path.join(folder, file), 'r', errors='ignore') as f:
|
||||||
content = f.read()
|
content = f.read()
|
||||||
f.close
|
f.close
|
||||||
content = content.replace(find, replace)
|
content = content.replace(find, replace)
|
||||||
@ -304,7 +305,8 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
|
|||||||
###
|
###
|
||||||
### Gradio common GUI section
|
### Gradio common GUI section
|
||||||
###
|
###
|
||||||
|
|
||||||
|
|
||||||
def gradio_config():
|
def gradio_config():
|
||||||
with gr.Accordion('Configuration file', open=False):
|
with gr.Accordion('Configuration file', open=False):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
@ -318,7 +320,13 @@ def gradio_config():
|
|||||||
placeholder="type the configuration file path or use the 'Open' button above to select it...",
|
placeholder="type the configuration file path or use the 'Open' button above to select it...",
|
||||||
interactive=True,
|
interactive=True,
|
||||||
)
|
)
|
||||||
return (button_open_config, button_save_config, button_save_as_config, config_file_name)
|
return (
|
||||||
|
button_open_config,
|
||||||
|
button_save_config,
|
||||||
|
button_save_as_config,
|
||||||
|
config_file_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def gradio_source_model():
|
def gradio_source_model():
|
||||||
with gr.Tab('Source model'):
|
with gr.Tab('Source model'):
|
||||||
@ -382,9 +390,20 @@ def gradio_source_model():
|
|||||||
v_parameterization,
|
v_parameterization,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
return (pretrained_model_name_or_path, v2, v_parameterization, save_model_as, model_list)
|
return (
|
||||||
|
pretrained_model_name_or_path,
|
||||||
|
v2,
|
||||||
|
v_parameterization,
|
||||||
|
save_model_as,
|
||||||
|
model_list,
|
||||||
|
)
|
||||||
|
|
||||||
def gradio_training(learning_rate_value='1e-6', lr_scheduler_value='constant', lr_warmup_value='0'):
|
|
||||||
|
def gradio_training(
|
||||||
|
learning_rate_value='1e-6',
|
||||||
|
lr_scheduler_value='constant',
|
||||||
|
lr_warmup_value='0',
|
||||||
|
):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
train_batch_size = gr.Slider(
|
train_batch_size = gr.Slider(
|
||||||
minimum=1,
|
minimum=1,
|
||||||
@ -394,9 +413,7 @@ def gradio_training(learning_rate_value='1e-6', lr_scheduler_value='constant', l
|
|||||||
step=1,
|
step=1,
|
||||||
)
|
)
|
||||||
epoch = gr.Textbox(label='Epoch', value=1)
|
epoch = gr.Textbox(label='Epoch', value=1)
|
||||||
save_every_n_epochs = gr.Textbox(
|
save_every_n_epochs = gr.Textbox(label='Save every N epochs', value=1)
|
||||||
label='Save every N epochs', value=1
|
|
||||||
)
|
|
||||||
caption_extension = gr.Textbox(
|
caption_extension = gr.Textbox(
|
||||||
label='Caption Extension',
|
label='Caption Extension',
|
||||||
placeholder='(Optional) Extension for caption files. default: .caption',
|
placeholder='(Optional) Extension for caption files. default: .caption',
|
||||||
@ -429,7 +446,9 @@ def gradio_training(learning_rate_value='1e-6', lr_scheduler_value='constant', l
|
|||||||
)
|
)
|
||||||
seed = gr.Textbox(label='Seed', value=1234)
|
seed = gr.Textbox(label='Seed', value=1234)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
learning_rate = gr.Textbox(label='Learning rate', value=learning_rate_value)
|
learning_rate = gr.Textbox(
|
||||||
|
label='Learning rate', value=learning_rate_value
|
||||||
|
)
|
||||||
lr_scheduler = gr.Dropdown(
|
lr_scheduler = gr.Dropdown(
|
||||||
label='LR Scheduler',
|
label='LR Scheduler',
|
||||||
choices=[
|
choices=[
|
||||||
@ -442,7 +461,9 @@ def gradio_training(learning_rate_value='1e-6', lr_scheduler_value='constant', l
|
|||||||
],
|
],
|
||||||
value=lr_scheduler_value,
|
value=lr_scheduler_value,
|
||||||
)
|
)
|
||||||
lr_warmup = gr.Textbox(label='LR warmup (% of steps)', value=lr_warmup_value)
|
lr_warmup = gr.Textbox(
|
||||||
|
label='LR warmup (% of steps)', value=lr_warmup_value
|
||||||
|
)
|
||||||
cache_latents = gr.Checkbox(label='Cache latent', value=True)
|
cache_latents = gr.Checkbox(label='Cache latent', value=True)
|
||||||
return (
|
return (
|
||||||
learning_rate,
|
learning_rate,
|
||||||
@ -459,50 +480,38 @@ def gradio_training(learning_rate_value='1e-6', lr_scheduler_value='constant', l
|
|||||||
cache_latents,
|
cache_latents,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def run_cmd_training(**kwargs):
|
def run_cmd_training(**kwargs):
|
||||||
options = [
|
options = [
|
||||||
f' --learning_rate="{kwargs.get("learning_rate", "")}"'
|
f' --learning_rate="{kwargs.get("learning_rate", "")}"'
|
||||||
if kwargs.get('learning_rate')
|
if kwargs.get('learning_rate')
|
||||||
else '',
|
else '',
|
||||||
|
|
||||||
f' --lr_scheduler="{kwargs.get("lr_scheduler", "")}"'
|
f' --lr_scheduler="{kwargs.get("lr_scheduler", "")}"'
|
||||||
if kwargs.get('lr_scheduler')
|
if kwargs.get('lr_scheduler')
|
||||||
else '',
|
else '',
|
||||||
|
|
||||||
f' --lr_warmup_steps="{kwargs.get("lr_warmup_steps", "")}"'
|
f' --lr_warmup_steps="{kwargs.get("lr_warmup_steps", "")}"'
|
||||||
if kwargs.get('lr_warmup_steps')
|
if kwargs.get('lr_warmup_steps')
|
||||||
else '',
|
else '',
|
||||||
|
|
||||||
f' --train_batch_size="{kwargs.get("train_batch_size", "")}"'
|
f' --train_batch_size="{kwargs.get("train_batch_size", "")}"'
|
||||||
if kwargs.get('train_batch_size')
|
if kwargs.get('train_batch_size')
|
||||||
else '',
|
else '',
|
||||||
|
|
||||||
f' --max_train_steps="{kwargs.get("max_train_steps", "")}"'
|
f' --max_train_steps="{kwargs.get("max_train_steps", "")}"'
|
||||||
if kwargs.get('max_train_steps')
|
if kwargs.get('max_train_steps')
|
||||||
else '',
|
else '',
|
||||||
|
|
||||||
f' --save_every_n_epochs="{kwargs.get("save_every_n_epochs", "")}"'
|
f' --save_every_n_epochs="{kwargs.get("save_every_n_epochs", "")}"'
|
||||||
if kwargs.get('save_every_n_epochs')
|
if kwargs.get('save_every_n_epochs')
|
||||||
else '',
|
else '',
|
||||||
|
|
||||||
f' --mixed_precision="{kwargs.get("mixed_precision", "")}"'
|
f' --mixed_precision="{kwargs.get("mixed_precision", "")}"'
|
||||||
if kwargs.get('mixed_precision')
|
if kwargs.get('mixed_precision')
|
||||||
else '',
|
else '',
|
||||||
|
|
||||||
f' --save_precision="{kwargs.get("save_precision", "")}"'
|
f' --save_precision="{kwargs.get("save_precision", "")}"'
|
||||||
if kwargs.get('save_precision')
|
if kwargs.get('save_precision')
|
||||||
else '',
|
else '',
|
||||||
|
f' --seed="{kwargs.get("seed", "")}"' if kwargs.get('seed') else '',
|
||||||
f' --seed="{kwargs.get("seed", "")}"'
|
|
||||||
if kwargs.get('seed')
|
|
||||||
else '',
|
|
||||||
|
|
||||||
f' --caption_extension="{kwargs.get("caption_extension", "")}"'
|
f' --caption_extension="{kwargs.get("caption_extension", "")}"'
|
||||||
if kwargs.get('caption_extension')
|
if kwargs.get('caption_extension')
|
||||||
else '',
|
else '',
|
||||||
|
|
||||||
' --cache_latents' if kwargs.get('cache_latents') else '',
|
' --cache_latents' if kwargs.get('cache_latents') else '',
|
||||||
|
|
||||||
]
|
]
|
||||||
run_cmd = ''.join(options)
|
run_cmd = ''.join(options)
|
||||||
return run_cmd
|
return run_cmd
|
||||||
@ -532,9 +541,7 @@ def gradio_advanced_training():
|
|||||||
gradient_checkpointing = gr.Checkbox(
|
gradient_checkpointing = gr.Checkbox(
|
||||||
label='Gradient checkpointing', value=False
|
label='Gradient checkpointing', value=False
|
||||||
)
|
)
|
||||||
shuffle_caption = gr.Checkbox(
|
shuffle_caption = gr.Checkbox(label='Shuffle caption', value=False)
|
||||||
label='Shuffle caption', value=False
|
|
||||||
)
|
|
||||||
persistent_data_loader_workers = gr.Checkbox(
|
persistent_data_loader_workers = gr.Checkbox(
|
||||||
label='Persistent data loader', value=False
|
label='Persistent data loader', value=False
|
||||||
)
|
)
|
||||||
@ -544,10 +551,18 @@ def gradio_advanced_training():
|
|||||||
with gr.Row():
|
with gr.Row():
|
||||||
use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=True)
|
use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=True)
|
||||||
xformers = gr.Checkbox(label='Use xformers', value=True)
|
xformers = gr.Checkbox(label='Use xformers', value=True)
|
||||||
color_aug = gr.Checkbox(
|
color_aug = gr.Checkbox(label='Color augmentation', value=False)
|
||||||
label='Color augmentation', value=False
|
|
||||||
)
|
|
||||||
flip_aug = gr.Checkbox(label='Flip augmentation', value=False)
|
flip_aug = gr.Checkbox(label='Flip augmentation', value=False)
|
||||||
|
with gr.Row():
|
||||||
|
bucket_no_upscale = gr.Checkbox(
|
||||||
|
label="Don't upscale bucket resolution", value=True
|
||||||
|
)
|
||||||
|
random_crop = gr.Checkbox(
|
||||||
|
label='Random crop instead of center crop', value=False
|
||||||
|
)
|
||||||
|
bucket_reso_steps = gr.Number(
|
||||||
|
label='Bucket resolution steps', value=64
|
||||||
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
save_state = gr.Checkbox(label='Save training state', value=False)
|
save_state = gr.Checkbox(label='Save training state', value=False)
|
||||||
resume = gr.Textbox(
|
resume = gr.Textbox(
|
||||||
@ -581,55 +596,53 @@ def gradio_advanced_training():
|
|||||||
max_data_loader_n_workers,
|
max_data_loader_n_workers,
|
||||||
keep_tokens,
|
keep_tokens,
|
||||||
persistent_data_loader_workers,
|
persistent_data_loader_workers,
|
||||||
|
bucket_no_upscale,
|
||||||
|
random_crop,
|
||||||
|
bucket_reso_steps,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def run_cmd_advanced_training(**kwargs):
|
def run_cmd_advanced_training(**kwargs):
|
||||||
options = [
|
options = [
|
||||||
f' --max_train_epochs="{kwargs.get("max_train_epochs", "")}"'
|
f' --max_train_epochs="{kwargs.get("max_train_epochs", "")}"'
|
||||||
if kwargs.get('max_train_epochs')
|
if kwargs.get('max_train_epochs')
|
||||||
else '',
|
else '',
|
||||||
|
|
||||||
f' --max_data_loader_n_workers="{kwargs.get("max_data_loader_n_workers", "")}"'
|
f' --max_data_loader_n_workers="{kwargs.get("max_data_loader_n_workers", "")}"'
|
||||||
if kwargs.get('max_data_loader_n_workers')
|
if kwargs.get('max_data_loader_n_workers')
|
||||||
else '',
|
else '',
|
||||||
|
|
||||||
f' --max_token_length={kwargs.get("max_token_length", "")}'
|
f' --max_token_length={kwargs.get("max_token_length", "")}'
|
||||||
if int(kwargs.get('max_token_length', 75)) > 75
|
if int(kwargs.get('max_token_length', 75)) > 75
|
||||||
else '',
|
else '',
|
||||||
|
|
||||||
f' --clip_skip={kwargs.get("clip_skip", "")}'
|
f' --clip_skip={kwargs.get("clip_skip", "")}'
|
||||||
if int(kwargs.get('clip_skip', 1)) > 1
|
if int(kwargs.get('clip_skip', 1)) > 1
|
||||||
else '',
|
else '',
|
||||||
|
|
||||||
f' --resume="{kwargs.get("resume", "")}"'
|
f' --resume="{kwargs.get("resume", "")}"'
|
||||||
if kwargs.get('resume')
|
if kwargs.get('resume')
|
||||||
else '',
|
else '',
|
||||||
|
|
||||||
f' --keep_tokens="{kwargs.get("keep_tokens", "")}"'
|
f' --keep_tokens="{kwargs.get("keep_tokens", "")}"'
|
||||||
if int(kwargs.get('keep_tokens', 0)) > 0
|
if int(kwargs.get('keep_tokens', 0)) > 0
|
||||||
else '',
|
else '',
|
||||||
|
|
||||||
|
f' --bucket_reso_steps={int(kwargs.get("bucket_reso_steps", 1))}'
|
||||||
|
if int(kwargs.get('bucket_reso_steps', 64)) >= 1
|
||||||
|
else '',
|
||||||
|
|
||||||
' --save_state' if kwargs.get('save_state') else '',
|
' --save_state' if kwargs.get('save_state') else '',
|
||||||
|
|
||||||
' --mem_eff_attn' if kwargs.get('mem_eff_attn') else '',
|
' --mem_eff_attn' if kwargs.get('mem_eff_attn') else '',
|
||||||
|
|
||||||
' --color_aug' if kwargs.get('color_aug') else '',
|
' --color_aug' if kwargs.get('color_aug') else '',
|
||||||
|
|
||||||
' --flip_aug' if kwargs.get('flip_aug') else '',
|
' --flip_aug' if kwargs.get('flip_aug') else '',
|
||||||
|
|
||||||
' --shuffle_caption' if kwargs.get('shuffle_caption') else '',
|
' --shuffle_caption' if kwargs.get('shuffle_caption') else '',
|
||||||
|
' --gradient_checkpointing'
|
||||||
' --gradient_checkpointing' if kwargs.get('gradient_checkpointing') else '',
|
if kwargs.get('gradient_checkpointing')
|
||||||
|
else '',
|
||||||
' --full_fp16' if kwargs.get('full_fp16') else '',
|
' --full_fp16' if kwargs.get('full_fp16') else '',
|
||||||
|
|
||||||
' --xformers' if kwargs.get('xformers') else '',
|
' --xformers' if kwargs.get('xformers') else '',
|
||||||
|
|
||||||
' --use_8bit_adam' if kwargs.get('use_8bit_adam') else '',
|
' --use_8bit_adam' if kwargs.get('use_8bit_adam') else '',
|
||||||
|
' --persistent_data_loader_workers'
|
||||||
' --persistent_data_loader_workers' if kwargs.get('persistent_data_loader_workers') else '',
|
if kwargs.get('persistent_data_loader_workers')
|
||||||
|
else '',
|
||||||
|
' --bucket_no_upscale' if kwargs.get('bucket_no_upscale') else '',
|
||||||
|
' --random_crop' if kwargs.get('random_crop') else '',
|
||||||
]
|
]
|
||||||
run_cmd = ''.join(options)
|
run_cmd = ''.join(options)
|
||||||
return run_cmd
|
return run_cmd
|
||||||
|
|
||||||
|
@ -191,9 +191,7 @@ def gradio_dreambooth_folder_creation_tab(
|
|||||||
util_training_dir_output,
|
util_training_dir_output,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
button_copy_info_to_Folders_tab = gr.Button(
|
button_copy_info_to_Folders_tab = gr.Button('Copy info to Folders Tab')
|
||||||
'Copy info to Folders Tab'
|
|
||||||
)
|
|
||||||
button_copy_info_to_Folders_tab.click(
|
button_copy_info_to_Folders_tab.click(
|
||||||
copy_info_to_Folders_tab,
|
copy_info_to_Folders_tab,
|
||||||
inputs=[util_training_dir_output],
|
inputs=[util_training_dir_output],
|
||||||
|
@ -2,7 +2,11 @@ import gradio as gr
|
|||||||
from easygui import msgbox
|
from easygui import msgbox
|
||||||
import subprocess
|
import subprocess
|
||||||
import os
|
import os
|
||||||
from .common_gui import get_saveasfilename_path, get_any_file_path, get_file_path
|
from .common_gui import (
|
||||||
|
get_saveasfilename_path,
|
||||||
|
get_any_file_path,
|
||||||
|
get_file_path,
|
||||||
|
)
|
||||||
|
|
||||||
folder_symbol = '\U0001f4c2' # 📂
|
folder_symbol = '\U0001f4c2' # 📂
|
||||||
refresh_symbol = '\U0001f504' # 🔄
|
refresh_symbol = '\U0001f504' # 🔄
|
||||||
@ -11,13 +15,18 @@ document_symbol = '\U0001F4C4' # 📄
|
|||||||
|
|
||||||
|
|
||||||
def extract_lora(
|
def extract_lora(
|
||||||
model_tuned, model_org, save_to, save_precision, dim, v2,
|
model_tuned,
|
||||||
|
model_org,
|
||||||
|
save_to,
|
||||||
|
save_precision,
|
||||||
|
dim,
|
||||||
|
v2,
|
||||||
):
|
):
|
||||||
# Check for caption_text_input
|
# Check for caption_text_input
|
||||||
if model_tuned == '':
|
if model_tuned == '':
|
||||||
msgbox('Invalid finetuned model file')
|
msgbox('Invalid finetuned model file')
|
||||||
return
|
return
|
||||||
|
|
||||||
if model_org == '':
|
if model_org == '':
|
||||||
msgbox('Invalid base model file')
|
msgbox('Invalid base model file')
|
||||||
return
|
return
|
||||||
@ -26,12 +35,14 @@ def extract_lora(
|
|||||||
if not os.path.isfile(model_tuned):
|
if not os.path.isfile(model_tuned):
|
||||||
msgbox('The provided finetuned model is not a file')
|
msgbox('The provided finetuned model is not a file')
|
||||||
return
|
return
|
||||||
|
|
||||||
if not os.path.isfile(model_org):
|
if not os.path.isfile(model_org):
|
||||||
msgbox('The provided base model is not a file')
|
msgbox('The provided base model is not a file')
|
||||||
return
|
return
|
||||||
|
|
||||||
run_cmd = f'.\\venv\Scripts\python.exe "networks\extract_lora_from_models.py"'
|
run_cmd = (
|
||||||
|
f'.\\venv\Scripts\python.exe "networks\extract_lora_from_models.py"'
|
||||||
|
)
|
||||||
run_cmd += f' --save_precision {save_precision}'
|
run_cmd += f' --save_precision {save_precision}'
|
||||||
run_cmd += f' --save_to "{save_to}"'
|
run_cmd += f' --save_to "{save_to}"'
|
||||||
run_cmd += f' --model_org "{model_org}"'
|
run_cmd += f' --model_org "{model_org}"'
|
||||||
@ -60,7 +71,7 @@ def gradio_extract_lora_tab():
|
|||||||
lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
|
lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
|
||||||
model_ext = gr.Textbox(value='*.ckpt *.safetensors', visible=False)
|
model_ext = gr.Textbox(value='*.ckpt *.safetensors', visible=False)
|
||||||
model_ext_name = gr.Textbox(value='Model types', visible=False)
|
model_ext_name = gr.Textbox(value='Model types', visible=False)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
model_tuned = gr.Textbox(
|
model_tuned = gr.Textbox(
|
||||||
label='Finetuned model',
|
label='Finetuned model',
|
||||||
@ -75,7 +86,7 @@ def gradio_extract_lora_tab():
|
|||||||
inputs=[model_tuned, model_ext, model_ext_name],
|
inputs=[model_tuned, model_ext, model_ext_name],
|
||||||
outputs=model_tuned,
|
outputs=model_tuned,
|
||||||
)
|
)
|
||||||
|
|
||||||
model_org = gr.Textbox(
|
model_org = gr.Textbox(
|
||||||
label='Stable Diffusion base model',
|
label='Stable Diffusion base model',
|
||||||
placeholder='Stable Diffusion original model: ckpt or safetensors file',
|
placeholder='Stable Diffusion original model: ckpt or safetensors file',
|
||||||
@ -99,7 +110,9 @@ def gradio_extract_lora_tab():
|
|||||||
folder_symbol, elem_id='open_folder_small'
|
folder_symbol, elem_id='open_folder_small'
|
||||||
)
|
)
|
||||||
button_save_to.click(
|
button_save_to.click(
|
||||||
get_saveasfilename_path, inputs=[save_to, lora_ext, lora_ext_name], outputs=save_to
|
get_saveasfilename_path,
|
||||||
|
inputs=[save_to, lora_ext, lora_ext_name],
|
||||||
|
outputs=save_to,
|
||||||
)
|
)
|
||||||
save_precision = gr.Dropdown(
|
save_precision = gr.Dropdown(
|
||||||
label='Save precison',
|
label='Save precison',
|
||||||
@ -122,6 +135,5 @@ def gradio_extract_lora_tab():
|
|||||||
|
|
||||||
extract_button.click(
|
extract_button.click(
|
||||||
extract_lora,
|
extract_lora,
|
||||||
inputs=[model_tuned, model_org, save_to, save_precision, dim, v2
|
inputs=[model_tuned, model_org, save_to, save_precision, dim, v2],
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
@ -15,11 +15,11 @@ def caption_images(
|
|||||||
prefix,
|
prefix,
|
||||||
postfix,
|
postfix,
|
||||||
):
|
):
|
||||||
# Check for images_dir_input
|
# Check for images_dir_input
|
||||||
if train_data_dir == '':
|
if train_data_dir == '':
|
||||||
msgbox('Image folder is missing...')
|
msgbox('Image folder is missing...')
|
||||||
return
|
return
|
||||||
|
|
||||||
if caption_ext == '':
|
if caption_ext == '':
|
||||||
msgbox('Please provide an extension for the caption files.')
|
msgbox('Please provide an extension for the caption files.')
|
||||||
return
|
return
|
||||||
@ -29,7 +29,9 @@ def caption_images(
|
|||||||
if not model_id == '':
|
if not model_id == '':
|
||||||
run_cmd += f' --model_id="{model_id}"'
|
run_cmd += f' --model_id="{model_id}"'
|
||||||
run_cmd += f' --batch_size="{int(batch_size)}"'
|
run_cmd += f' --batch_size="{int(batch_size)}"'
|
||||||
run_cmd += f' --max_data_loader_n_workers="{int(max_data_loader_n_workers)}"'
|
run_cmd += (
|
||||||
|
f' --max_data_loader_n_workers="{int(max_data_loader_n_workers)}"'
|
||||||
|
)
|
||||||
run_cmd += f' --max_length="{int(max_length)}"'
|
run_cmd += f' --max_length="{int(max_length)}"'
|
||||||
if caption_ext != '':
|
if caption_ext != '':
|
||||||
run_cmd += f' --caption_extension="{caption_ext}"'
|
run_cmd += f' --caption_extension="{caption_ext}"'
|
||||||
@ -105,8 +107,9 @@ def gradio_git_caption_gui_tab():
|
|||||||
value=75, label='Max length', interactive=True
|
value=75, label='Max length', interactive=True
|
||||||
)
|
)
|
||||||
model_id = gr.Textbox(
|
model_id = gr.Textbox(
|
||||||
label="Model",
|
label='Model',
|
||||||
placeholder="(Optional) model id for GIT in Hugging Face", interactive=True
|
placeholder='(Optional) model id for GIT in Hugging Face',
|
||||||
|
interactive=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
caption_button = gr.Button('Caption images')
|
caption_button = gr.Button('Caption images')
|
||||||
|
@ -2,7 +2,11 @@ import gradio as gr
|
|||||||
from easygui import msgbox
|
from easygui import msgbox
|
||||||
import subprocess
|
import subprocess
|
||||||
import os
|
import os
|
||||||
from .common_gui import get_saveasfilename_path, get_any_file_path, get_file_path
|
from .common_gui import (
|
||||||
|
get_saveasfilename_path,
|
||||||
|
get_any_file_path,
|
||||||
|
get_file_path,
|
||||||
|
)
|
||||||
|
|
||||||
folder_symbol = '\U0001f4c2' # 📂
|
folder_symbol = '\U0001f4c2' # 📂
|
||||||
refresh_symbol = '\U0001f504' # 🔄
|
refresh_symbol = '\U0001f504' # 🔄
|
||||||
@ -11,13 +15,18 @@ document_symbol = '\U0001F4C4' # 📄
|
|||||||
|
|
||||||
|
|
||||||
def merge_lora(
|
def merge_lora(
|
||||||
lora_a_model, lora_b_model, ratio, save_to, precision, save_precision,
|
lora_a_model,
|
||||||
|
lora_b_model,
|
||||||
|
ratio,
|
||||||
|
save_to,
|
||||||
|
precision,
|
||||||
|
save_precision,
|
||||||
):
|
):
|
||||||
# Check for caption_text_input
|
# Check for caption_text_input
|
||||||
if lora_a_model == '':
|
if lora_a_model == '':
|
||||||
msgbox('Invalid model A file')
|
msgbox('Invalid model A file')
|
||||||
return
|
return
|
||||||
|
|
||||||
if lora_b_model == '':
|
if lora_b_model == '':
|
||||||
msgbox('Invalid model B file')
|
msgbox('Invalid model B file')
|
||||||
return
|
return
|
||||||
@ -26,7 +35,7 @@ def merge_lora(
|
|||||||
if not os.path.isfile(lora_a_model):
|
if not os.path.isfile(lora_a_model):
|
||||||
msgbox('The provided model A is not a file')
|
msgbox('The provided model A is not a file')
|
||||||
return
|
return
|
||||||
|
|
||||||
if not os.path.isfile(lora_b_model):
|
if not os.path.isfile(lora_b_model):
|
||||||
msgbox('The provided model B is not a file')
|
msgbox('The provided model B is not a file')
|
||||||
return
|
return
|
||||||
@ -54,13 +63,11 @@ def merge_lora(
|
|||||||
|
|
||||||
def gradio_merge_lora_tab():
|
def gradio_merge_lora_tab():
|
||||||
with gr.Tab('Merge LoRA'):
|
with gr.Tab('Merge LoRA'):
|
||||||
gr.Markdown(
|
gr.Markdown('This utility can merge two LoRA networks together.')
|
||||||
'This utility can merge two LoRA networks together.'
|
|
||||||
)
|
|
||||||
|
|
||||||
lora_ext = gr.Textbox(value='*.pt *.safetensors', visible=False)
|
lora_ext = gr.Textbox(value='*.pt *.safetensors', visible=False)
|
||||||
lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
|
lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
lora_a_model = gr.Textbox(
|
lora_a_model = gr.Textbox(
|
||||||
label='LoRA model "A"',
|
label='LoRA model "A"',
|
||||||
@ -75,7 +82,7 @@ def gradio_merge_lora_tab():
|
|||||||
inputs=[lora_a_model, lora_ext, lora_ext_name],
|
inputs=[lora_a_model, lora_ext, lora_ext_name],
|
||||||
outputs=lora_a_model,
|
outputs=lora_a_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
lora_b_model = gr.Textbox(
|
lora_b_model = gr.Textbox(
|
||||||
label='LoRA model "B"',
|
label='LoRA model "B"',
|
||||||
placeholder='Path to the LoRA B model',
|
placeholder='Path to the LoRA B model',
|
||||||
@ -90,9 +97,15 @@ def gradio_merge_lora_tab():
|
|||||||
outputs=lora_b_model,
|
outputs=lora_b_model,
|
||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
ratio = gr.Slider(label="Merge ratio (eg: 0.7 mean 70% of model A and 30% of model B", minimum=0, maximum=1, step=0.01, value=0.5,
|
ratio = gr.Slider(
|
||||||
interactive=True,)
|
label='Merge ratio (eg: 0.7 mean 70% of model A and 30% of model B',
|
||||||
|
minimum=0,
|
||||||
|
maximum=1,
|
||||||
|
step=0.01,
|
||||||
|
value=0.5,
|
||||||
|
interactive=True,
|
||||||
|
)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
save_to = gr.Textbox(
|
save_to = gr.Textbox(
|
||||||
label='Save to',
|
label='Save to',
|
||||||
@ -103,7 +116,9 @@ def gradio_merge_lora_tab():
|
|||||||
folder_symbol, elem_id='open_folder_small'
|
folder_symbol, elem_id='open_folder_small'
|
||||||
)
|
)
|
||||||
button_save_to.click(
|
button_save_to.click(
|
||||||
get_saveasfilename_path, inputs=[save_to, lora_ext, lora_ext_name], outputs=save_to
|
get_saveasfilename_path,
|
||||||
|
inputs=[save_to, lora_ext, lora_ext_name],
|
||||||
|
outputs=save_to,
|
||||||
)
|
)
|
||||||
precision = gr.Dropdown(
|
precision = gr.Dropdown(
|
||||||
label='Merge precison',
|
label='Merge precison',
|
||||||
@ -122,6 +137,12 @@ def gradio_merge_lora_tab():
|
|||||||
|
|
||||||
convert_button.click(
|
convert_button.click(
|
||||||
merge_lora,
|
merge_lora,
|
||||||
inputs=[lora_a_model, lora_b_model, ratio, save_to, precision, save_precision,
|
inputs=[
|
||||||
|
lora_a_model,
|
||||||
|
lora_b_model,
|
||||||
|
ratio,
|
||||||
|
save_to,
|
||||||
|
precision,
|
||||||
|
save_precision,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -11,7 +11,11 @@ document_symbol = '\U0001F4C4' # 📄
|
|||||||
|
|
||||||
|
|
||||||
def resize_lora(
|
def resize_lora(
|
||||||
model, new_rank, save_to, save_precision, device,
|
model,
|
||||||
|
new_rank,
|
||||||
|
save_to,
|
||||||
|
save_precision,
|
||||||
|
device,
|
||||||
):
|
):
|
||||||
# Check for caption_text_input
|
# Check for caption_text_input
|
||||||
if model == '':
|
if model == '':
|
||||||
@ -22,7 +26,7 @@ def resize_lora(
|
|||||||
if not os.path.isfile(model):
|
if not os.path.isfile(model):
|
||||||
msgbox('The provided model is not a file')
|
msgbox('The provided model is not a file')
|
||||||
return
|
return
|
||||||
|
|
||||||
if device == '':
|
if device == '':
|
||||||
device = 'cuda'
|
device = 'cuda'
|
||||||
|
|
||||||
@ -46,13 +50,11 @@ def resize_lora(
|
|||||||
|
|
||||||
def gradio_resize_lora_tab():
|
def gradio_resize_lora_tab():
|
||||||
with gr.Tab('Resize LoRA'):
|
with gr.Tab('Resize LoRA'):
|
||||||
gr.Markdown(
|
gr.Markdown('This utility can resize a LoRA.')
|
||||||
'This utility can resize a LoRA.'
|
|
||||||
)
|
|
||||||
|
|
||||||
lora_ext = gr.Textbox(value='*.pt *.safetensors', visible=False)
|
lora_ext = gr.Textbox(value='*.pt *.safetensors', visible=False)
|
||||||
lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
|
lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
model = gr.Textbox(
|
model = gr.Textbox(
|
||||||
label='Source LoRA',
|
label='Source LoRA',
|
||||||
@ -68,9 +70,15 @@ def gradio_resize_lora_tab():
|
|||||||
outputs=model,
|
outputs=model,
|
||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
new_rank = gr.Slider(label="Desired LoRA rank", minimum=1, maximum=1024, step=1, value=4,
|
new_rank = gr.Slider(
|
||||||
interactive=True,)
|
label='Desired LoRA rank',
|
||||||
|
minimum=1,
|
||||||
|
maximum=1024,
|
||||||
|
step=1,
|
||||||
|
value=4,
|
||||||
|
interactive=True,
|
||||||
|
)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
save_to = gr.Textbox(
|
save_to = gr.Textbox(
|
||||||
label='Save to',
|
label='Save to',
|
||||||
@ -81,7 +89,9 @@ def gradio_resize_lora_tab():
|
|||||||
folder_symbol, elem_id='open_folder_small'
|
folder_symbol, elem_id='open_folder_small'
|
||||||
)
|
)
|
||||||
button_save_to.click(
|
button_save_to.click(
|
||||||
get_saveasfilename_path, inputs=[save_to, lora_ext, lora_ext_name], outputs=save_to
|
get_saveasfilename_path,
|
||||||
|
inputs=[save_to, lora_ext, lora_ext_name],
|
||||||
|
outputs=save_to,
|
||||||
)
|
)
|
||||||
save_precision = gr.Dropdown(
|
save_precision = gr.Dropdown(
|
||||||
label='Save precison',
|
label='Save precison',
|
||||||
@ -99,6 +109,11 @@ def gradio_resize_lora_tab():
|
|||||||
|
|
||||||
convert_button.click(
|
convert_button.click(
|
||||||
resize_lora,
|
resize_lora,
|
||||||
inputs=[model, new_rank, save_to, save_precision, device,
|
inputs=[
|
||||||
|
model,
|
||||||
|
new_rank,
|
||||||
|
save_to,
|
||||||
|
save_precision,
|
||||||
|
device,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -2,7 +2,11 @@ import gradio as gr
|
|||||||
from easygui import msgbox
|
from easygui import msgbox
|
||||||
import subprocess
|
import subprocess
|
||||||
import os
|
import os
|
||||||
from .common_gui import get_saveasfilename_path, get_any_file_path, get_file_path
|
from .common_gui import (
|
||||||
|
get_saveasfilename_path,
|
||||||
|
get_any_file_path,
|
||||||
|
get_file_path,
|
||||||
|
)
|
||||||
|
|
||||||
folder_symbol = '\U0001f4c2' # 📂
|
folder_symbol = '\U0001f4c2' # 📂
|
||||||
refresh_symbol = '\U0001f504' # 🔄
|
refresh_symbol = '\U0001f504' # 🔄
|
||||||
@ -30,9 +34,11 @@ def verify_lora(
|
|||||||
|
|
||||||
# Run the command
|
# Run the command
|
||||||
subprocess.run(run_cmd)
|
subprocess.run(run_cmd)
|
||||||
process = subprocess.Popen(run_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
process = subprocess.Popen(
|
||||||
|
run_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
||||||
|
)
|
||||||
output, error = process.communicate()
|
output, error = process.communicate()
|
||||||
|
|
||||||
return (output.decode(), error.decode())
|
return (output.decode(), error.decode())
|
||||||
|
|
||||||
|
|
||||||
@ -46,10 +52,10 @@ def gradio_verify_lora_tab():
|
|||||||
gr.Markdown(
|
gr.Markdown(
|
||||||
'This utility can verify a LoRA network to make sure it is properly trained.'
|
'This utility can verify a LoRA network to make sure it is properly trained.'
|
||||||
)
|
)
|
||||||
|
|
||||||
lora_ext = gr.Textbox(value='*.pt *.safetensors', visible=False)
|
lora_ext = gr.Textbox(value='*.pt *.safetensors', visible=False)
|
||||||
lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
|
lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
lora_model = gr.Textbox(
|
lora_model = gr.Textbox(
|
||||||
label='LoRA model',
|
label='LoRA model',
|
||||||
@ -64,7 +70,7 @@ def gradio_verify_lora_tab():
|
|||||||
inputs=[lora_model, lora_ext, lora_ext_name],
|
inputs=[lora_model, lora_ext, lora_ext_name],
|
||||||
outputs=lora_model,
|
outputs=lora_model,
|
||||||
)
|
)
|
||||||
verify_button = gr.Button('Verify', variant="primary")
|
verify_button = gr.Button('Verify', variant='primary')
|
||||||
|
|
||||||
lora_model_verif_output = gr.Textbox(
|
lora_model_verif_output = gr.Textbox(
|
||||||
label='Output',
|
label='Output',
|
||||||
@ -73,7 +79,7 @@ def gradio_verify_lora_tab():
|
|||||||
lines=1,
|
lines=1,
|
||||||
max_lines=10,
|
max_lines=10,
|
||||||
)
|
)
|
||||||
|
|
||||||
lora_model_verif_error = gr.Textbox(
|
lora_model_verif_error = gr.Textbox(
|
||||||
label='Error',
|
label='Error',
|
||||||
placeholder='Verification error',
|
placeholder='Verification error',
|
||||||
@ -87,5 +93,5 @@ def gradio_verify_lora_tab():
|
|||||||
inputs=[
|
inputs=[
|
||||||
lora_model,
|
lora_model,
|
||||||
],
|
],
|
||||||
outputs=[lora_model_verif_output, lora_model_verif_error]
|
outputs=[lora_model_verif_output, lora_model_verif_error],
|
||||||
)
|
)
|
||||||
|
@ -14,7 +14,7 @@ def caption_images(train_data_dir, caption_extension, batch_size, thresh):
|
|||||||
if train_data_dir == '':
|
if train_data_dir == '':
|
||||||
msgbox('Image folder is missing...')
|
msgbox('Image folder is missing...')
|
||||||
return
|
return
|
||||||
|
|
||||||
if caption_extension == '':
|
if caption_extension == '':
|
||||||
msgbox('Please provide an extension for the caption files.')
|
msgbox('Please provide an extension for the caption files.')
|
||||||
return
|
return
|
||||||
|
78
lora_gui.py
78
lora_gui.py
@ -91,9 +91,14 @@ def save_configuration(
|
|||||||
max_train_epochs,
|
max_train_epochs,
|
||||||
max_data_loader_n_workers,
|
max_data_loader_n_workers,
|
||||||
network_alpha,
|
network_alpha,
|
||||||
training_comment, keep_tokens,
|
training_comment,
|
||||||
lr_scheduler_num_cycles, lr_scheduler_power,
|
keep_tokens,
|
||||||
|
lr_scheduler_num_cycles,
|
||||||
|
lr_scheduler_power,
|
||||||
persistent_data_loader_workers,
|
persistent_data_loader_workers,
|
||||||
|
bucket_no_upscale,
|
||||||
|
random_crop,
|
||||||
|
bucket_reso_steps,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -182,9 +187,14 @@ def open_configuration(
|
|||||||
max_train_epochs,
|
max_train_epochs,
|
||||||
max_data_loader_n_workers,
|
max_data_loader_n_workers,
|
||||||
network_alpha,
|
network_alpha,
|
||||||
training_comment, keep_tokens,
|
training_comment,
|
||||||
lr_scheduler_num_cycles, lr_scheduler_power,
|
keep_tokens,
|
||||||
|
lr_scheduler_num_cycles,
|
||||||
|
lr_scheduler_power,
|
||||||
persistent_data_loader_workers,
|
persistent_data_loader_workers,
|
||||||
|
bucket_no_upscale,
|
||||||
|
random_crop,
|
||||||
|
bucket_reso_steps,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -257,9 +267,14 @@ def train_model(
|
|||||||
max_train_epochs,
|
max_train_epochs,
|
||||||
max_data_loader_n_workers,
|
max_data_loader_n_workers,
|
||||||
network_alpha,
|
network_alpha,
|
||||||
training_comment, keep_tokens,
|
training_comment,
|
||||||
lr_scheduler_num_cycles, lr_scheduler_power,
|
keep_tokens,
|
||||||
|
lr_scheduler_num_cycles,
|
||||||
|
lr_scheduler_power,
|
||||||
persistent_data_loader_workers,
|
persistent_data_loader_workers,
|
||||||
|
bucket_no_upscale,
|
||||||
|
random_crop,
|
||||||
|
bucket_reso_steps,
|
||||||
):
|
):
|
||||||
if pretrained_model_name_or_path == '':
|
if pretrained_model_name_or_path == '':
|
||||||
msgbox('Source model information is missing')
|
msgbox('Source model information is missing')
|
||||||
@ -281,12 +296,18 @@ def train_model(
|
|||||||
if output_dir == '':
|
if output_dir == '':
|
||||||
msgbox('Output folder path is missing')
|
msgbox('Output folder path is missing')
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if int(bucket_reso_steps) < 1:
|
||||||
|
msgbox('Bucket resolution steps need to be greater than 0')
|
||||||
|
return
|
||||||
|
|
||||||
if not os.path.exists(output_dir):
|
if not os.path.exists(output_dir):
|
||||||
os.makedirs(output_dir)
|
os.makedirs(output_dir)
|
||||||
|
|
||||||
if stop_text_encoder_training_pct > 0:
|
if stop_text_encoder_training_pct > 0:
|
||||||
msgbox('Output "stop text encoder training" is not yet supported. Ignoring')
|
msgbox(
|
||||||
|
'Output "stop text encoder training" is not yet supported. Ignoring'
|
||||||
|
)
|
||||||
stop_text_encoder_training_pct = 0
|
stop_text_encoder_training_pct = 0
|
||||||
|
|
||||||
# If string is empty set string to 0.
|
# If string is empty set string to 0.
|
||||||
@ -358,9 +379,9 @@ def train_model(
|
|||||||
print(f'lr_warmup_steps = {lr_warmup_steps}')
|
print(f'lr_warmup_steps = {lr_warmup_steps}')
|
||||||
|
|
||||||
run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "train_network.py"'
|
run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "train_network.py"'
|
||||||
|
|
||||||
run_cmd += f' --bucket_reso_steps=1 --bucket_no_upscale' # --random_crop'
|
run_cmd += f' --bucket_reso_steps=1 --bucket_no_upscale' # --random_crop'
|
||||||
|
|
||||||
if v2:
|
if v2:
|
||||||
run_cmd += ' --v2'
|
run_cmd += ' --v2'
|
||||||
if v_parameterization:
|
if v_parameterization:
|
||||||
@ -390,7 +411,7 @@ def train_model(
|
|||||||
if not float(prior_loss_weight) == 1.0:
|
if not float(prior_loss_weight) == 1.0:
|
||||||
run_cmd += f' --prior_loss_weight={prior_loss_weight}'
|
run_cmd += f' --prior_loss_weight={prior_loss_weight}'
|
||||||
run_cmd += f' --network_module=networks.lora'
|
run_cmd += f' --network_module=networks.lora'
|
||||||
|
|
||||||
if not (float(text_encoder_lr) == 0) or not (float(unet_lr) == 0):
|
if not (float(text_encoder_lr) == 0) or not (float(unet_lr) == 0):
|
||||||
if not (float(text_encoder_lr) == 0) and not (float(unet_lr) == 0):
|
if not (float(text_encoder_lr) == 0) and not (float(unet_lr) == 0):
|
||||||
run_cmd += f' --text_encoder_lr={text_encoder_lr}'
|
run_cmd += f' --text_encoder_lr={text_encoder_lr}'
|
||||||
@ -402,14 +423,12 @@ def train_model(
|
|||||||
run_cmd += f' --unet_lr={unet_lr}'
|
run_cmd += f' --unet_lr={unet_lr}'
|
||||||
run_cmd += f' --network_train_unet_only'
|
run_cmd += f' --network_train_unet_only'
|
||||||
else:
|
else:
|
||||||
if float(text_encoder_lr) == 0:
|
if float(text_encoder_lr) == 0:
|
||||||
msgbox(
|
msgbox('Please input learning rate values.')
|
||||||
'Please input learning rate values.'
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
run_cmd += f' --network_dim={network_dim}'
|
run_cmd += f' --network_dim={network_dim}'
|
||||||
|
|
||||||
if not lora_network_weights == '':
|
if not lora_network_weights == '':
|
||||||
run_cmd += f' --network_weights="{lora_network_weights}"'
|
run_cmd += f' --network_weights="{lora_network_weights}"'
|
||||||
if int(gradient_accumulation_steps) > 1:
|
if int(gradient_accumulation_steps) > 1:
|
||||||
@ -454,6 +473,9 @@ def train_model(
|
|||||||
use_8bit_adam=use_8bit_adam,
|
use_8bit_adam=use_8bit_adam,
|
||||||
keep_tokens=keep_tokens,
|
keep_tokens=keep_tokens,
|
||||||
persistent_data_loader_workers=persistent_data_loader_workers,
|
persistent_data_loader_workers=persistent_data_loader_workers,
|
||||||
|
bucket_no_upscale=bucket_no_upscale,
|
||||||
|
random_crop=random_crop,
|
||||||
|
bucket_reso_steps=bucket_reso_steps,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(run_cmd)
|
print(run_cmd)
|
||||||
@ -675,11 +697,13 @@ def lora_tab(
|
|||||||
label='Prior loss weight', value=1.0
|
label='Prior loss weight', value=1.0
|
||||||
)
|
)
|
||||||
lr_scheduler_num_cycles = gr.Textbox(
|
lr_scheduler_num_cycles = gr.Textbox(
|
||||||
label='LR number of cycles', placeholder='(Optional) For Cosine with restart and polynomial only'
|
label='LR number of cycles',
|
||||||
|
placeholder='(Optional) For Cosine with restart and polynomial only',
|
||||||
)
|
)
|
||||||
|
|
||||||
lr_scheduler_power = gr.Textbox(
|
lr_scheduler_power = gr.Textbox(
|
||||||
label='LR power', placeholder='(Optional) For Cosine with restart and polynomial only'
|
label='LR power',
|
||||||
|
placeholder='(Optional) For Cosine with restart and polynomial only',
|
||||||
)
|
)
|
||||||
(
|
(
|
||||||
use_8bit_adam,
|
use_8bit_adam,
|
||||||
@ -698,6 +722,9 @@ def lora_tab(
|
|||||||
max_data_loader_n_workers,
|
max_data_loader_n_workers,
|
||||||
keep_tokens,
|
keep_tokens,
|
||||||
persistent_data_loader_workers,
|
persistent_data_loader_workers,
|
||||||
|
bucket_no_upscale,
|
||||||
|
random_crop,
|
||||||
|
bucket_reso_steps,
|
||||||
) = gradio_advanced_training()
|
) = gradio_advanced_training()
|
||||||
color_aug.change(
|
color_aug.change(
|
||||||
color_aug_changed,
|
color_aug_changed,
|
||||||
@ -719,7 +746,6 @@ def lora_tab(
|
|||||||
gradio_merge_lora_tab()
|
gradio_merge_lora_tab()
|
||||||
gradio_resize_lora_tab()
|
gradio_resize_lora_tab()
|
||||||
gradio_verify_lora_tab()
|
gradio_verify_lora_tab()
|
||||||
|
|
||||||
|
|
||||||
button_run = gr.Button('Train model')
|
button_run = gr.Button('Train model')
|
||||||
|
|
||||||
@ -773,8 +799,12 @@ def lora_tab(
|
|||||||
network_alpha,
|
network_alpha,
|
||||||
training_comment,
|
training_comment,
|
||||||
keep_tokens,
|
keep_tokens,
|
||||||
lr_scheduler_num_cycles, lr_scheduler_power,
|
lr_scheduler_num_cycles,
|
||||||
|
lr_scheduler_power,
|
||||||
persistent_data_loader_workers,
|
persistent_data_loader_workers,
|
||||||
|
bucket_no_upscale,
|
||||||
|
random_crop,
|
||||||
|
bucket_reso_steps,
|
||||||
]
|
]
|
||||||
|
|
||||||
button_open_config.click(
|
button_open_config.click(
|
||||||
|
@ -82,8 +82,18 @@ def save_configuration(
|
|||||||
max_data_loader_n_workers,
|
max_data_loader_n_workers,
|
||||||
mem_eff_attn,
|
mem_eff_attn,
|
||||||
gradient_accumulation_steps,
|
gradient_accumulation_steps,
|
||||||
model_list, token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, keep_tokens,
|
model_list,
|
||||||
|
token_string,
|
||||||
|
init_word,
|
||||||
|
num_vectors_per_token,
|
||||||
|
max_train_steps,
|
||||||
|
weights,
|
||||||
|
template,
|
||||||
|
keep_tokens,
|
||||||
persistent_data_loader_workers,
|
persistent_data_loader_workers,
|
||||||
|
bucket_no_upscale,
|
||||||
|
random_crop,
|
||||||
|
bucket_reso_steps,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -171,8 +181,18 @@ def open_configuration(
|
|||||||
max_data_loader_n_workers,
|
max_data_loader_n_workers,
|
||||||
mem_eff_attn,
|
mem_eff_attn,
|
||||||
gradient_accumulation_steps,
|
gradient_accumulation_steps,
|
||||||
model_list, token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, keep_tokens,
|
model_list,
|
||||||
|
token_string,
|
||||||
|
init_word,
|
||||||
|
num_vectors_per_token,
|
||||||
|
max_train_steps,
|
||||||
|
weights,
|
||||||
|
template,
|
||||||
|
keep_tokens,
|
||||||
persistent_data_loader_workers,
|
persistent_data_loader_workers,
|
||||||
|
bucket_no_upscale,
|
||||||
|
random_crop,
|
||||||
|
bucket_reso_steps,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -241,8 +261,17 @@ def train_model(
|
|||||||
mem_eff_attn,
|
mem_eff_attn,
|
||||||
gradient_accumulation_steps,
|
gradient_accumulation_steps,
|
||||||
model_list, # Keep this. Yes, it is unused here but required given the common list used
|
model_list, # Keep this. Yes, it is unused here but required given the common list used
|
||||||
token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, keep_tokens,
|
token_string,
|
||||||
|
init_word,
|
||||||
|
num_vectors_per_token,
|
||||||
|
max_train_steps,
|
||||||
|
weights,
|
||||||
|
template,
|
||||||
|
keep_tokens,
|
||||||
persistent_data_loader_workers,
|
persistent_data_loader_workers,
|
||||||
|
bucket_no_upscale,
|
||||||
|
random_crop,
|
||||||
|
bucket_reso_steps,
|
||||||
):
|
):
|
||||||
if pretrained_model_name_or_path == '':
|
if pretrained_model_name_or_path == '':
|
||||||
msgbox('Source model information is missing')
|
msgbox('Source model information is missing')
|
||||||
@ -264,15 +293,15 @@ def train_model(
|
|||||||
if output_dir == '':
|
if output_dir == '':
|
||||||
msgbox('Output folder path is missing')
|
msgbox('Output folder path is missing')
|
||||||
return
|
return
|
||||||
|
|
||||||
if token_string == '':
|
if token_string == '':
|
||||||
msgbox('Token string is missing')
|
msgbox('Token string is missing')
|
||||||
return
|
return
|
||||||
|
|
||||||
if init_word == '':
|
if init_word == '':
|
||||||
msgbox('Init word is missing')
|
msgbox('Init word is missing')
|
||||||
return
|
return
|
||||||
|
|
||||||
if not os.path.exists(output_dir):
|
if not os.path.exists(output_dir):
|
||||||
os.makedirs(output_dir)
|
os.makedirs(output_dir)
|
||||||
|
|
||||||
@ -332,7 +361,7 @@ def train_model(
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
max_train_steps = int(max_train_steps)
|
max_train_steps = int(max_train_steps)
|
||||||
|
|
||||||
print(f'max_train_steps = {max_train_steps}')
|
print(f'max_train_steps = {max_train_steps}')
|
||||||
|
|
||||||
# calculate stop encoder training
|
# calculate stop encoder training
|
||||||
@ -421,6 +450,9 @@ def train_model(
|
|||||||
use_8bit_adam=use_8bit_adam,
|
use_8bit_adam=use_8bit_adam,
|
||||||
keep_tokens=keep_tokens,
|
keep_tokens=keep_tokens,
|
||||||
persistent_data_loader_workers=persistent_data_loader_workers,
|
persistent_data_loader_workers=persistent_data_loader_workers,
|
||||||
|
bucket_no_upscale=bucket_no_upscale,
|
||||||
|
random_crop=random_crop,
|
||||||
|
bucket_reso_steps=bucket_reso_steps,
|
||||||
)
|
)
|
||||||
run_cmd += f' --token_string="{token_string}"'
|
run_cmd += f' --token_string="{token_string}"'
|
||||||
run_cmd += f' --init_word="{init_word}"'
|
run_cmd += f' --init_word="{init_word}"'
|
||||||
@ -431,7 +463,7 @@ def train_model(
|
|||||||
run_cmd += f' --use_object_template'
|
run_cmd += f' --use_object_template'
|
||||||
elif template == 'style template':
|
elif template == 'style template':
|
||||||
run_cmd += f' --use_style_template'
|
run_cmd += f' --use_style_template'
|
||||||
|
|
||||||
print(run_cmd)
|
print(run_cmd)
|
||||||
# Run the command
|
# Run the command
|
||||||
subprocess.run(run_cmd)
|
subprocess.run(run_cmd)
|
||||||
@ -576,9 +608,7 @@ def ti_tab(
|
|||||||
label='Resume TI training',
|
label='Resume TI training',
|
||||||
placeholder='(Optional) Path to existing TI embeding file to keep training',
|
placeholder='(Optional) Path to existing TI embeding file to keep training',
|
||||||
)
|
)
|
||||||
weights_file_input = gr.Button(
|
weights_file_input = gr.Button('📂', elem_id='open_folder_small')
|
||||||
'📂', elem_id='open_folder_small'
|
|
||||||
)
|
|
||||||
weights_file_input.click(get_file_path, outputs=weights)
|
weights_file_input.click(get_file_path, outputs=weights)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
token_string = gr.Textbox(
|
token_string = gr.Textbox(
|
||||||
@ -676,6 +706,9 @@ def ti_tab(
|
|||||||
max_data_loader_n_workers,
|
max_data_loader_n_workers,
|
||||||
keep_tokens,
|
keep_tokens,
|
||||||
persistent_data_loader_workers,
|
persistent_data_loader_workers,
|
||||||
|
bucket_no_upscale,
|
||||||
|
random_crop,
|
||||||
|
bucket_reso_steps,
|
||||||
) = gradio_advanced_training()
|
) = gradio_advanced_training()
|
||||||
color_aug.change(
|
color_aug.change(
|
||||||
color_aug_changed,
|
color_aug_changed,
|
||||||
@ -739,9 +772,17 @@ def ti_tab(
|
|||||||
mem_eff_attn,
|
mem_eff_attn,
|
||||||
gradient_accumulation_steps,
|
gradient_accumulation_steps,
|
||||||
model_list,
|
model_list,
|
||||||
token_string, init_word, num_vectors_per_token, max_train_steps, weights, template,
|
token_string,
|
||||||
|
init_word,
|
||||||
|
num_vectors_per_token,
|
||||||
|
max_train_steps,
|
||||||
|
weights,
|
||||||
|
template,
|
||||||
keep_tokens,
|
keep_tokens,
|
||||||
persistent_data_loader_workers,
|
persistent_data_loader_workers,
|
||||||
|
bucket_no_upscale,
|
||||||
|
random_crop,
|
||||||
|
bucket_reso_steps,
|
||||||
]
|
]
|
||||||
|
|
||||||
button_open_config.click(
|
button_open_config.click(
|
||||||
|
@ -1,66 +0,0 @@
|
|||||||
import os
|
|
||||||
import cv2
|
|
||||||
import argparse
|
|
||||||
import shutil
|
|
||||||
import math
|
|
||||||
|
|
||||||
def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2):
|
|
||||||
# Calculate max_pixels from max_resolution string
|
|
||||||
max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1])
|
|
||||||
|
|
||||||
# Create destination folder if it does not exist
|
|
||||||
if not os.path.exists(dst_img_folder):
|
|
||||||
os.makedirs(dst_img_folder)
|
|
||||||
|
|
||||||
# Iterate through all files in src_img_folder
|
|
||||||
for filename in os.listdir(src_img_folder):
|
|
||||||
# Check if the image is png, jpg or webp
|
|
||||||
if not filename.endswith(('.png', '.jpg', '.webp')):
|
|
||||||
# Copy the file to the destination folder if not png, jpg or webp
|
|
||||||
shutil.copy(os.path.join(src_img_folder, filename), os.path.join(dst_img_folder, filename))
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Load image
|
|
||||||
img = cv2.imread(os.path.join(src_img_folder, filename))
|
|
||||||
|
|
||||||
# Calculate current number of pixels
|
|
||||||
current_pixels = img.shape[0] * img.shape[1]
|
|
||||||
|
|
||||||
# Check if the image needs resizing
|
|
||||||
if current_pixels > max_pixels:
|
|
||||||
# Calculate scaling factor
|
|
||||||
scale_factor = max_pixels / current_pixels
|
|
||||||
|
|
||||||
# Calculate new dimensions
|
|
||||||
new_height = int(img.shape[0] * math.sqrt(scale_factor))
|
|
||||||
new_width = int(img.shape[1] * math.sqrt(scale_factor))
|
|
||||||
|
|
||||||
# Resize image
|
|
||||||
img = cv2.resize(img, (new_width, new_height))
|
|
||||||
|
|
||||||
# Calculate the new height and width that are divisible by divisible_by
|
|
||||||
new_height = new_height if new_height % divisible_by == 0 else new_height - new_height % divisible_by
|
|
||||||
new_width = new_width if new_width % divisible_by == 0 else new_width - new_width % divisible_by
|
|
||||||
|
|
||||||
# Center crop the image to the calculated dimensions
|
|
||||||
y = int((img.shape[0] - new_height) / 2)
|
|
||||||
x = int((img.shape[1] - new_width) / 2)
|
|
||||||
img = img[y:y + new_height, x:x + new_width]
|
|
||||||
|
|
||||||
# Save resized image in dst_img_folder
|
|
||||||
cv2.imwrite(os.path.join(dst_img_folder, filename), img, [cv2.IMWRITE_JPEG_QUALITY, 100])
|
|
||||||
|
|
||||||
print(f"Resized image: {filename} with size {img.shape[0]}x{img.shape[1]}")
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(description='Resize images in a folder to a specified max resolution')
|
|
||||||
parser.add_argument('src_img_folder', type=str, help='Source folder containing the images')
|
|
||||||
parser.add_argument('dst_img_folder', type=str, help='Destination folder to save the resized images')
|
|
||||||
parser.add_argument('--max_resolution', type=str, help='Maximum resolution in the format "512x512"', default="512x512")
|
|
||||||
parser.add_argument('--divisible_by', type=int, help='Ensure new dimensions are divisible by this value', default=2)
|
|
||||||
args = parser.parse_args()
|
|
||||||
resize_images(args.src_img_folder, args.dst_img_folder, args.max_resolution)
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
76
tools/resize_images_to_resolutions.py
Normal file
76
tools/resize_images_to_resolutions.py
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
import os
|
||||||
|
import cv2
|
||||||
|
import argparse
|
||||||
|
import shutil
|
||||||
|
import math
|
||||||
|
|
||||||
|
def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2):
|
||||||
|
# Split the max_resolution string by "," and strip any whitespaces
|
||||||
|
max_resolutions = [res.strip() for res in max_resolution.split(',')]
|
||||||
|
|
||||||
|
# # Calculate max_pixels from max_resolution string
|
||||||
|
# max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1])
|
||||||
|
|
||||||
|
# Create destination folder if it does not exist
|
||||||
|
if not os.path.exists(dst_img_folder):
|
||||||
|
os.makedirs(dst_img_folder)
|
||||||
|
|
||||||
|
# Iterate through all files in src_img_folder
|
||||||
|
for filename in os.listdir(src_img_folder):
|
||||||
|
# Check if the image is png, jpg or webp
|
||||||
|
if not filename.endswith(('.png', '.jpg', '.webp')):
|
||||||
|
# Copy the file to the destination folder if not png, jpg or webp
|
||||||
|
shutil.copy(os.path.join(src_img_folder, filename), os.path.join(dst_img_folder, filename))
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Load image
|
||||||
|
img = cv2.imread(os.path.join(src_img_folder, filename))
|
||||||
|
|
||||||
|
for max_resolution in max_resolutions:
|
||||||
|
# Calculate max_pixels from max_resolution string
|
||||||
|
max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1])
|
||||||
|
|
||||||
|
# Calculate current number of pixels
|
||||||
|
current_pixels = img.shape[0] * img.shape[1]
|
||||||
|
|
||||||
|
# Check if the image needs resizing
|
||||||
|
if current_pixels > max_pixels:
|
||||||
|
# Calculate scaling factor
|
||||||
|
scale_factor = max_pixels / current_pixels
|
||||||
|
|
||||||
|
# Calculate new dimensions
|
||||||
|
new_height = int(img.shape[0] * math.sqrt(scale_factor))
|
||||||
|
new_width = int(img.shape[1] * math.sqrt(scale_factor))
|
||||||
|
|
||||||
|
# Resize image
|
||||||
|
img = cv2.resize(img, (new_width, new_height))
|
||||||
|
|
||||||
|
# Calculate the new height and width that are divisible by divisible_by
|
||||||
|
new_height = new_height if new_height % divisible_by == 0 else new_height - new_height % divisible_by
|
||||||
|
new_width = new_width if new_width % divisible_by == 0 else new_width - new_width % divisible_by
|
||||||
|
|
||||||
|
# Center crop the image to the calculated dimensions
|
||||||
|
y = int((img.shape[0] - new_height) / 2)
|
||||||
|
x = int((img.shape[1] - new_width) / 2)
|
||||||
|
img = img[y:y + new_height, x:x + new_width]
|
||||||
|
|
||||||
|
# Split filename into base and extension
|
||||||
|
base, ext = os.path.splitext(filename)
|
||||||
|
new_filename = base + '+' + max_resolution + '.jpg'
|
||||||
|
|
||||||
|
# Save resized image in dst_img_folder
|
||||||
|
cv2.imwrite(os.path.join(dst_img_folder, new_filename), img, [cv2.IMWRITE_JPEG_QUALITY, 100])
|
||||||
|
print(f"Resized image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description='Resize images in a folder to a specified max resolution(s)')
|
||||||
|
parser.add_argument('src_img_folder', type=str, help='Source folder containing the images')
|
||||||
|
parser.add_argument('dst_img_folder', type=str, help='Destination folder to save the resized images')
|
||||||
|
parser.add_argument('--max_resolution', type=str, help='Maximum resolution(s) in the format "512x512,384x384, etc, etc"', default="512x512,384x384,256x256,128x128")
|
||||||
|
parser.add_argument('--divisible_by', type=int, help='Ensure new dimensions are divisible by this value', default=1)
|
||||||
|
args = parser.parse_args()
|
||||||
|
resize_images(args.src_img_folder, args.dst_img_folder, args.max_resolution)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
Loading…
Reference in New Issue
Block a user