commit
17960c880f
@ -176,6 +176,11 @@ This will store your a backup file with your current locally installed pip packa
|
|||||||
|
|
||||||
## Change History
|
## Change History
|
||||||
|
|
||||||
|
* 2023/03/05 (v21.1.5):
|
||||||
|
- Add replace underscore with space option to WD14 captioning. Thanks @sALTaccount!
|
||||||
|
- Improve how custom preset is set and handles.
|
||||||
|
- Add support for `--listen` argument. This allow gradio to listen for connections from other devices on the network (or internet). For example: `gui.ps1 --listen "0.0.0.0"` will allow anyone to connect to the gradio webui.
|
||||||
|
- Updated `Resize LoRA` tab to support LoCon resizing. Added new resize
|
||||||
* 2023/03/05 (v21.1.4):
|
* 2023/03/05 (v21.1.4):
|
||||||
- Removing legacy and confusing use 8bit adam chackbox. It is now configured using the Optimiser drop down list. It will be set properly based on legacy config files.
|
- Removing legacy and confusing use 8bit adam chackbox. It is now configured using the Optimiser drop down list. It will be set properly based on legacy config files.
|
||||||
* 2023/03/04 (v21.1.3):
|
* 2023/03/04 (v21.1.3):
|
||||||
|
@ -125,7 +125,7 @@ def main(args):
|
|||||||
tag_text = ""
|
tag_text = ""
|
||||||
for i, p in enumerate(prob[4:]): # numpyとか使うのが良いけど、まあそれほど数も多くないのでループで
|
for i, p in enumerate(prob[4:]): # numpyとか使うのが良いけど、まあそれほど数も多くないのでループで
|
||||||
if p >= args.thresh and i < len(tags):
|
if p >= args.thresh and i < len(tags):
|
||||||
tag_text += ", " + tags[i]
|
tag_text += ", " + (tags[i].replace("_", " ") if args.replace_underscores else tags[i])
|
||||||
|
|
||||||
if len(tag_text) > 0:
|
if len(tag_text) > 0:
|
||||||
tag_text = tag_text[2:] # 最初の ", " を消す
|
tag_text = tag_text[2:] # 最初の ", " を消す
|
||||||
@ -190,6 +190,7 @@ if __name__ == '__main__':
|
|||||||
help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
|
help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
|
||||||
parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子")
|
parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子")
|
||||||
parser.add_argument("--debug", action="store_true", help="debug mode")
|
parser.add_argument("--debug", action="store_true", help="debug mode")
|
||||||
|
parser.add_argument("--replace_underscores", action="store_true", help="replace underscores in tags with spaces / タグのアンダースコアをスペースに置き換える")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
@ -52,6 +52,9 @@ def UI(**kwargs):
|
|||||||
server_port = kwargs.get('server_port', 0)
|
server_port = kwargs.get('server_port', 0)
|
||||||
inbrowser = kwargs.get('inbrowser', False)
|
inbrowser = kwargs.get('inbrowser', False)
|
||||||
share = kwargs.get('share', False)
|
share = kwargs.get('share', False)
|
||||||
|
server_name = kwargs.get('listen')
|
||||||
|
|
||||||
|
launch_kwargs['server_name'] = server_name
|
||||||
if username and password:
|
if username and password:
|
||||||
launch_kwargs['auth'] = (username, password)
|
launch_kwargs['auth'] = (username, password)
|
||||||
if server_port > 0:
|
if server_port > 0:
|
||||||
@ -66,6 +69,9 @@ def UI(**kwargs):
|
|||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# torch.cuda.set_per_process_memory_fraction(0.48)
|
# torch.cuda.set_per_process_memory_fraction(0.48)
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
'--listen', type=str, default='127.0.0.1', help='IP to listen on for connections to Gradio'
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--username', type=str, default='', help='Username for authentication'
|
'--username', type=str, default='', help='Username for authentication'
|
||||||
)
|
)
|
||||||
@ -93,4 +99,5 @@ if __name__ == '__main__':
|
|||||||
inbrowser=args.inbrowser,
|
inbrowser=args.inbrowser,
|
||||||
server_port=args.server_port,
|
server_port=args.server_port,
|
||||||
share=args.share,
|
share=args.share,
|
||||||
|
listen=args.listen,
|
||||||
)
|
)
|
||||||
|
@ -9,18 +9,47 @@ refresh_symbol = '\U0001f504' # 🔄
|
|||||||
save_style_symbol = '\U0001f4be' # 💾
|
save_style_symbol = '\U0001f4be' # 💾
|
||||||
document_symbol = '\U0001F4C4' # 📄
|
document_symbol = '\U0001F4C4' # 📄
|
||||||
|
|
||||||
|
# define a list of substrings to search for v2 base models
|
||||||
|
V2_BASE_MODELS = [
|
||||||
|
'stabilityai/stable-diffusion-2-1-base',
|
||||||
|
'stabilityai/stable-diffusion-2-base',
|
||||||
|
]
|
||||||
|
|
||||||
|
# define a list of substrings to search for v_parameterization models
|
||||||
|
V_PARAMETERIZATION_MODELS = [
|
||||||
|
'stabilityai/stable-diffusion-2-1',
|
||||||
|
'stabilityai/stable-diffusion-2',
|
||||||
|
]
|
||||||
|
|
||||||
|
# define a list of substrings to v1.x models
|
||||||
|
V1_MODELS = [
|
||||||
|
'CompVis/stable-diffusion-v1-4',
|
||||||
|
'runwayml/stable-diffusion-v1-5',
|
||||||
|
]
|
||||||
|
|
||||||
|
# define a list of substrings to search for
|
||||||
|
ALL_PRESET_MODELS = V2_BASE_MODELS + V_PARAMETERIZATION_MODELS + V1_MODELS
|
||||||
|
|
||||||
|
|
||||||
def update_my_data(my_data):
|
def update_my_data(my_data):
|
||||||
if my_data.get('use_8bit_adam', False) == True:
|
if my_data.get('use_8bit_adam', False) == True:
|
||||||
my_data['optimizer'] = 'AdamW8bit'
|
my_data['optimizer'] = 'AdamW8bit'
|
||||||
# my_data['use_8bit_adam'] = False
|
# my_data['use_8bit_adam'] = False
|
||||||
|
|
||||||
if my_data.get('optimizer', 'missing') == 'missing' and my_data.get('use_8bit_adam', False) == False:
|
if (
|
||||||
|
my_data.get('optimizer', 'missing') == 'missing'
|
||||||
|
and my_data.get('use_8bit_adam', False) == False
|
||||||
|
):
|
||||||
my_data['optimizer'] = 'AdamW'
|
my_data['optimizer'] = 'AdamW'
|
||||||
|
|
||||||
if my_data.get('model_list', 'custom') == []:
|
if my_data.get('model_list', 'custom') == []:
|
||||||
print('Old config with empty model list. Setting to custom...')
|
print('Old config with empty model list. Setting to custom...')
|
||||||
my_data['model_list'] = 'custom'
|
my_data['model_list'] = 'custom'
|
||||||
|
|
||||||
|
# If Pretrained model name or path is not one of the preset models then set the preset_model to custom
|
||||||
|
if not my_data.get('pretrained_model_name_or_path', '') in ALL_PRESET_MODELS:
|
||||||
|
my_data['model_list'] = 'custom'
|
||||||
|
|
||||||
return my_data
|
return my_data
|
||||||
|
|
||||||
|
|
||||||
@ -282,63 +311,76 @@ def save_inference_file(output_dir, v2, v_parameterization, output_name):
|
|||||||
def set_pretrained_model_name_or_path_input(
|
def set_pretrained_model_name_or_path_input(
|
||||||
model_list, pretrained_model_name_or_path, v2, v_parameterization
|
model_list, pretrained_model_name_or_path, v2, v_parameterization
|
||||||
):
|
):
|
||||||
# define a list of substrings to search for
|
|
||||||
substrings_v2 = [
|
|
||||||
'stabilityai/stable-diffusion-2-1-base',
|
|
||||||
'stabilityai/stable-diffusion-2-base',
|
|
||||||
]
|
|
||||||
|
|
||||||
# check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list
|
# check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list
|
||||||
if str(model_list) in substrings_v2:
|
if str(model_list) in V2_BASE_MODELS:
|
||||||
print('SD v2 model detected. Setting --v2 parameter')
|
print('SD v2 model detected. Setting --v2 parameter')
|
||||||
v2 = True
|
v2 = True
|
||||||
v_parameterization = False
|
v_parameterization = False
|
||||||
|
pretrained_model_name_or_path = str(model_list)
|
||||||
return model_list, v2, v_parameterization
|
|
||||||
|
|
||||||
# define a list of substrings to search for v-objective
|
|
||||||
substrings_v_parameterization = [
|
|
||||||
'stabilityai/stable-diffusion-2-1',
|
|
||||||
'stabilityai/stable-diffusion-2',
|
|
||||||
]
|
|
||||||
|
|
||||||
# check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v_parameterization list
|
# check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v_parameterization list
|
||||||
if str(model_list) in substrings_v_parameterization:
|
if str(model_list) in V_PARAMETERIZATION_MODELS:
|
||||||
print(
|
print(
|
||||||
'SD v2 v_parameterization detected. Setting --v2 parameter and --v_parameterization'
|
'SD v2 v_parameterization detected. Setting --v2 parameter and --v_parameterization'
|
||||||
)
|
)
|
||||||
v2 = True
|
v2 = True
|
||||||
v_parameterization = True
|
v_parameterization = True
|
||||||
|
pretrained_model_name_or_path = str(model_list)
|
||||||
|
|
||||||
return model_list, v2, v_parameterization
|
if str(model_list) in V1_MODELS:
|
||||||
|
|
||||||
# define a list of substrings to v1.x
|
|
||||||
substrings_v1_model = [
|
|
||||||
'CompVis/stable-diffusion-v1-4',
|
|
||||||
'runwayml/stable-diffusion-v1-5',
|
|
||||||
]
|
|
||||||
|
|
||||||
if str(model_list) in substrings_v1_model:
|
|
||||||
v2 = False
|
v2 = False
|
||||||
v_parameterization = False
|
v_parameterization = False
|
||||||
|
pretrained_model_name_or_path = str(model_list)
|
||||||
return model_list, v2, v_parameterization
|
|
||||||
|
|
||||||
if model_list == 'custom':
|
if model_list == 'custom':
|
||||||
if (
|
if (
|
||||||
str(pretrained_model_name_or_path) in substrings_v1_model
|
str(pretrained_model_name_or_path) in V1_MODELS
|
||||||
or str(pretrained_model_name_or_path) in substrings_v2
|
or str(pretrained_model_name_or_path) in V2_BASE_MODELS
|
||||||
or str(pretrained_model_name_or_path)
|
or str(pretrained_model_name_or_path)
|
||||||
in substrings_v_parameterization
|
in V_PARAMETERIZATION_MODELS
|
||||||
):
|
):
|
||||||
pretrained_model_name_or_path = ''
|
pretrained_model_name_or_path = ''
|
||||||
v2 = False
|
v2 = False
|
||||||
v_parameterization = False
|
v_parameterization = False
|
||||||
return pretrained_model_name_or_path, v2, v_parameterization
|
return model_list, pretrained_model_name_or_path, v2, v_parameterization
|
||||||
|
|
||||||
###
|
def set_v2_checkbox(
|
||||||
### Gradio common GUI section
|
model_list, v2, v_parameterization
|
||||||
###
|
):
|
||||||
|
# check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list
|
||||||
|
if str(model_list) in V2_BASE_MODELS:
|
||||||
|
v2 = True
|
||||||
|
v_parameterization = False
|
||||||
|
|
||||||
|
# check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v_parameterization list
|
||||||
|
if str(model_list) in V_PARAMETERIZATION_MODELS:
|
||||||
|
v2 = True
|
||||||
|
v_parameterization = True
|
||||||
|
|
||||||
|
if str(model_list) in V1_MODELS:
|
||||||
|
v2 = False
|
||||||
|
v_parameterization = False
|
||||||
|
|
||||||
|
return v2, v_parameterization
|
||||||
|
|
||||||
|
def set_model_list(
|
||||||
|
model_list,
|
||||||
|
pretrained_model_name_or_path,
|
||||||
|
v2,
|
||||||
|
v_parameterization,
|
||||||
|
):
|
||||||
|
|
||||||
|
if not pretrained_model_name_or_path in ALL_PRESET_MODELS:
|
||||||
|
model_list = 'custom'
|
||||||
|
else:
|
||||||
|
model_list = pretrained_model_name_or_path
|
||||||
|
|
||||||
|
return model_list, v2, v_parameterization
|
||||||
|
|
||||||
|
|
||||||
|
###
|
||||||
|
### Gradio common GUI section
|
||||||
|
###
|
||||||
|
|
||||||
|
|
||||||
def gradio_config():
|
def gradio_config():
|
||||||
@ -362,6 +404,15 @@ def gradio_config():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_pretrained_model_name_or_path_file(
|
||||||
|
model_list, pretrained_model_name_or_path
|
||||||
|
):
|
||||||
|
pretrained_model_name_or_path = get_any_file_path(
|
||||||
|
pretrained_model_name_or_path
|
||||||
|
)
|
||||||
|
set_model_list(model_list, pretrained_model_name_or_path)
|
||||||
|
|
||||||
|
|
||||||
def gradio_source_model():
|
def gradio_source_model():
|
||||||
with gr.Tab('Source model'):
|
with gr.Tab('Source model'):
|
||||||
# Define the input elements
|
# Define the input elements
|
||||||
@ -419,6 +470,8 @@ def gradio_source_model():
|
|||||||
v_parameterization = gr.Checkbox(
|
v_parameterization = gr.Checkbox(
|
||||||
label='v_parameterization', value=False
|
label='v_parameterization', value=False
|
||||||
)
|
)
|
||||||
|
v2.change(set_v2_checkbox, inputs=[model_list, v2, v_parameterization], outputs=[v2, v_parameterization],show_progress=False)
|
||||||
|
v_parameterization.change(set_v2_checkbox, inputs=[model_list, v2, v_parameterization], outputs=[v2, v_parameterization],show_progress=False)
|
||||||
model_list.change(
|
model_list.change(
|
||||||
set_pretrained_model_name_or_path_input,
|
set_pretrained_model_name_or_path_input,
|
||||||
inputs=[
|
inputs=[
|
||||||
@ -428,10 +481,28 @@ def gradio_source_model():
|
|||||||
v_parameterization,
|
v_parameterization,
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
|
model_list,
|
||||||
pretrained_model_name_or_path,
|
pretrained_model_name_or_path,
|
||||||
v2,
|
v2,
|
||||||
v_parameterization,
|
v_parameterization,
|
||||||
],
|
],
|
||||||
|
show_progress=False,
|
||||||
|
)
|
||||||
|
# Update the model list and parameters when user click outside the button or field
|
||||||
|
pretrained_model_name_or_path.change(
|
||||||
|
set_model_list,
|
||||||
|
inputs=[
|
||||||
|
model_list,
|
||||||
|
pretrained_model_name_or_path,
|
||||||
|
v2,
|
||||||
|
v_parameterization,
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
model_list,
|
||||||
|
v2,
|
||||||
|
v_parameterization,
|
||||||
|
],
|
||||||
|
show_progress=False,
|
||||||
)
|
)
|
||||||
return (
|
return (
|
||||||
pretrained_model_name_or_path,
|
pretrained_model_name_or_path,
|
||||||
|
@ -17,6 +17,9 @@ def resize_lora(
|
|||||||
save_to,
|
save_to,
|
||||||
save_precision,
|
save_precision,
|
||||||
device,
|
device,
|
||||||
|
dynamic_method,
|
||||||
|
dynamic_param,
|
||||||
|
verbose,
|
||||||
):
|
):
|
||||||
# Check for caption_text_input
|
# Check for caption_text_input
|
||||||
if model == '':
|
if model == '':
|
||||||
@ -27,16 +30,35 @@ def resize_lora(
|
|||||||
if not os.path.isfile(model):
|
if not os.path.isfile(model):
|
||||||
msgbox('The provided model is not a file')
|
msgbox('The provided model is not a file')
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if dynamic_method == 'sv_ratio':
|
||||||
|
if float(dynamic_param) < 2:
|
||||||
|
msgbox(f'Dynamic parameter for {dynamic_method} need to be 2 or greater...')
|
||||||
|
return
|
||||||
|
|
||||||
|
if dynamic_method == 'sv_fro' or dynamic_method == 'sv_cumulative':
|
||||||
|
if float(dynamic_param) < 0 or float(dynamic_param) > 1:
|
||||||
|
msgbox(f'Dynamic parameter for {dynamic_method} need to be between 0 and 1...')
|
||||||
|
return
|
||||||
|
|
||||||
|
# Check if save_to end with one of the defines extension. If not add .safetensors.
|
||||||
|
if not save_to.endswith(('.pt', '.safetensors')):
|
||||||
|
save_to += '.safetensors'
|
||||||
|
|
||||||
if device == '':
|
if device == '':
|
||||||
device = 'cuda'
|
device = 'cuda'
|
||||||
|
|
||||||
run_cmd = f'{PYTHON} "{os.path.join("networks","resize_lora.py")}"'
|
run_cmd = f'{PYTHON} "{os.path.join("tools","resize_lora.py")}"'
|
||||||
run_cmd += f' --save_precision {save_precision}'
|
run_cmd += f' --save_precision {save_precision}'
|
||||||
run_cmd += f' --save_to {save_to}'
|
run_cmd += f' --save_to {save_to}'
|
||||||
run_cmd += f' --model {model}'
|
run_cmd += f' --model {model}'
|
||||||
run_cmd += f' --new_rank {new_rank}'
|
run_cmd += f' --new_rank {new_rank}'
|
||||||
run_cmd += f' --device {device}'
|
run_cmd += f' --device {device}'
|
||||||
|
if not dynamic_method == 'None':
|
||||||
|
run_cmd += f' --dynamic_method {dynamic_method}'
|
||||||
|
run_cmd += f' --dynamic_param {dynamic_param}'
|
||||||
|
if verbose:
|
||||||
|
run_cmd += f' --verbose'
|
||||||
|
|
||||||
print(run_cmd)
|
print(run_cmd)
|
||||||
|
|
||||||
@ -56,7 +78,7 @@ def gradio_resize_lora_tab():
|
|||||||
with gr.Tab('Resize LoRA'):
|
with gr.Tab('Resize LoRA'):
|
||||||
gr.Markdown('This utility can resize a LoRA.')
|
gr.Markdown('This utility can resize a LoRA.')
|
||||||
|
|
||||||
lora_ext = gr.Textbox(value='*.pt *.safetensors', visible=False)
|
lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False)
|
||||||
lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
|
lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
@ -84,6 +106,27 @@ def gradio_resize_lora_tab():
|
|||||||
interactive=True,
|
interactive=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
dynamic_method = gr.Dropdown(
|
||||||
|
choices=['None',
|
||||||
|
'sv_ratio',
|
||||||
|
'sv_fro',
|
||||||
|
'sv_cumulative'
|
||||||
|
],
|
||||||
|
value='sv_fro',
|
||||||
|
label='Dynamic method',
|
||||||
|
interactive=True
|
||||||
|
)
|
||||||
|
dynamic_param = gr.Textbox(
|
||||||
|
label='Dynamic parameter',
|
||||||
|
value='0.9',
|
||||||
|
interactive=True,
|
||||||
|
placeholder='Value for the dynamic method selected.'
|
||||||
|
)
|
||||||
|
verbose = gr.Checkbox(
|
||||||
|
label='Verbose',
|
||||||
|
value=False
|
||||||
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
save_to = gr.Textbox(
|
save_to = gr.Textbox(
|
||||||
label='Save to',
|
label='Save to',
|
||||||
@ -109,6 +152,7 @@ def gradio_resize_lora_tab():
|
|||||||
label='Device',
|
label='Device',
|
||||||
placeholder='{Optional) device to use, cuda for GPU. Default: cuda',
|
placeholder='{Optional) device to use, cuda for GPU. Default: cuda',
|
||||||
interactive=True,
|
interactive=True,
|
||||||
|
value='cuda',
|
||||||
)
|
)
|
||||||
|
|
||||||
convert_button = gr.Button('Resize model')
|
convert_button = gr.Button('Resize model')
|
||||||
@ -121,6 +165,9 @@ def gradio_resize_lora_tab():
|
|||||||
save_to,
|
save_to,
|
||||||
save_precision,
|
save_precision,
|
||||||
device,
|
device,
|
||||||
|
dynamic_method,
|
||||||
|
dynamic_param,
|
||||||
|
verbose,
|
||||||
],
|
],
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
)
|
)
|
||||||
|
@ -5,7 +5,7 @@ from .common_gui import get_folder_path
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
||||||
def caption_images(train_data_dir, caption_extension, batch_size, thresh):
|
def caption_images(train_data_dir, caption_extension, batch_size, thresh, replace_underscores):
|
||||||
# Check for caption_text_input
|
# Check for caption_text_input
|
||||||
# if caption_text_input == "":
|
# if caption_text_input == "":
|
||||||
# msgbox("Caption text is missing...")
|
# msgbox("Caption text is missing...")
|
||||||
@ -24,6 +24,7 @@ def caption_images(train_data_dir, caption_extension, batch_size, thresh):
|
|||||||
run_cmd = f'accelerate launch "./finetune/tag_images_by_wd14_tagger.py"'
|
run_cmd = f'accelerate launch "./finetune/tag_images_by_wd14_tagger.py"'
|
||||||
run_cmd += f' --batch_size="{int(batch_size)}"'
|
run_cmd += f' --batch_size="{int(batch_size)}"'
|
||||||
run_cmd += f' --thresh="{thresh}"'
|
run_cmd += f' --thresh="{thresh}"'
|
||||||
|
run_cmd += f' --replace_underscores' if replace_underscores else ''
|
||||||
if caption_extension != '':
|
if caption_extension != '':
|
||||||
run_cmd += f' --caption_extension="{caption_extension}"'
|
run_cmd += f' --caption_extension="{caption_extension}"'
|
||||||
run_cmd += f' "{train_data_dir}"'
|
run_cmd += f' "{train_data_dir}"'
|
||||||
@ -75,11 +76,17 @@ def gradio_wd14_caption_gui_tab():
|
|||||||
batch_size = gr.Number(
|
batch_size = gr.Number(
|
||||||
value=1, label='Batch size', interactive=True
|
value=1, label='Batch size', interactive=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
replace_underscores = gr.Checkbox(
|
||||||
|
label='Replace underscores in filenames with spaces',
|
||||||
|
value=False,
|
||||||
|
interactive=True,
|
||||||
|
)
|
||||||
|
|
||||||
caption_button = gr.Button('Caption images')
|
caption_button = gr.Button('Caption images')
|
||||||
|
|
||||||
caption_button.click(
|
caption_button.click(
|
||||||
caption_images,
|
caption_images,
|
||||||
inputs=[train_data_dir, caption_extension, batch_size, thresh],
|
inputs=[train_data_dir, caption_extension, batch_size, thresh, replace_underscores],
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user