Merge pull request #63 from bmaltais/dev

Quick fix for captioninf extension
This commit is contained in:
bmaltais 2023-01-20 19:08:28 -05:00 committed by GitHub
commit f7e8a807a4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 23 additions and 10 deletions

View File

@ -333,10 +333,6 @@ def train_model(
run_cmd += ' --enable_bucket' run_cmd += ' --enable_bucket'
if no_token_padding: if no_token_padding:
run_cmd += ' --no_token_padding' run_cmd += ' --no_token_padding'
if use_8bit_adam:
run_cmd += ' --use_8bit_adam'
if xformers:
run_cmd += ' --xformers'
run_cmd += ( run_cmd += (
f' --pretrained_model_name_or_path="{pretrained_model_name_or_path}"' f' --pretrained_model_name_or_path="{pretrained_model_name_or_path}"'
) )

View File

@ -16,6 +16,10 @@ def caption_images(
if images_dir_input == '': if images_dir_input == '':
msgbox('Image folder is missing...') msgbox('Image folder is missing...')
return return
if caption_file_ext == '':
msgbox('Please provide an extension for the caption files.')
return
if not caption_text_input == '': if not caption_text_input == '':
print( print(
@ -83,7 +87,8 @@ def gradio_basic_caption_gui_tab():
) )
caption_file_ext = gr.Textbox( caption_file_ext = gr.Textbox(
label='Caption file extension', label='Caption file extension',
placeholder='(Optional) Default: .caption', placeholder='Extention for caption file. eg: .caption, .txt',
value='.txt',
interactive=True, interactive=True,
) )
overwrite_input = gr.Checkbox( overwrite_input = gr.Checkbox(

View File

@ -26,6 +26,10 @@ def caption_images(
if train_data_dir == '': if train_data_dir == '':
msgbox('Image folder is missing...') msgbox('Image folder is missing...')
return return
if caption_file_ext == '':
msgbox('Please provide an extension for the caption files.')
return
print(f'Captioning files in {train_data_dir}...') print(f'Captioning files in {train_data_dir}...')
run_cmd = f'.\\venv\\Scripts\\python.exe "finetune/make_captions.py"' run_cmd = f'.\\venv\\Scripts\\python.exe "finetune/make_captions.py"'
@ -82,7 +86,8 @@ def gradio_blip_caption_gui_tab():
with gr.Row(): with gr.Row():
caption_file_ext = gr.Textbox( caption_file_ext = gr.Textbox(
label='Caption file extension', label='Caption file extension',
placeholder='(Optional) Default: .caption', placeholder='Extention for caption file. eg: .caption, .txt',
value='.txt',
interactive=True, interactive=True,
) )

View File

@ -200,7 +200,7 @@ def find_replace(folder='', caption_file_ext='.caption', find='', replace=''):
files = [f for f in os.listdir(folder) if f.endswith(caption_file_ext)] files = [f for f in os.listdir(folder) if f.endswith(caption_file_ext)]
for file in files: for file in files:
with open(os.path.join(folder, file), 'r') as f: with open(os.path.join(folder, file), 'r', errors="ignore") as f:
content = f.read() content = f.read()
f.close f.close
content = content.replace(find, replace) content = content.replace(find, replace)

View File

@ -14,6 +14,10 @@ def caption_images(train_data_dir, caption_extension, batch_size, thresh):
if train_data_dir == '': if train_data_dir == '':
msgbox('Image folder is missing...') msgbox('Image folder is missing...')
return return
if caption_extension == '':
msgbox('Please provide an extension for the caption files.')
return
print(f'Captioning files in {train_data_dir}...') print(f'Captioning files in {train_data_dir}...')
run_cmd = f'accelerate launch "./finetune/tag_images_by_wd14_tagger.py"' run_cmd = f'accelerate launch "./finetune/tag_images_by_wd14_tagger.py"'
@ -56,7 +60,8 @@ def gradio_wd14_caption_gui_tab():
caption_extension = gr.Textbox( caption_extension = gr.Textbox(
label='Caption file extension', label='Caption file extension',
placeholder='(Optional) Default: .caption', placeholder='Extention for caption file. eg: .caption, .txt',
value='.txt',
interactive=True, interactive=True,
) )
thresh = gr.Number(value=0.35, label='Threshold') thresh = gr.Number(value=0.35, label='Threshold')

View File

@ -267,6 +267,10 @@ def train_model(
if output_dir == '': if output_dir == '':
msgbox('Output folder path is missing') msgbox('Output folder path is missing')
return return
if stop_text_encoder_training_pct > 0:
msgbox('Output "stop text encoder training" is not yet supported. Ignoring')
stop_text_encoder_training_pct = 0
# If string is empty set string to 0. # If string is empty set string to 0.
if text_encoder_lr == '': if text_encoder_lr == '':
@ -353,8 +357,6 @@ def train_model(
run_cmd += f' --reg_data_dir="{reg_data_dir}"' run_cmd += f' --reg_data_dir="{reg_data_dir}"'
run_cmd += f' --resolution={max_resolution}' run_cmd += f' --resolution={max_resolution}'
run_cmd += f' --output_dir="{output_dir}"' run_cmd += f' --output_dir="{output_dir}"'
run_cmd += f' --use_8bit_adam'
run_cmd += f' --xformers'
run_cmd += f' --logging_dir="{logging_dir}"' run_cmd += f' --logging_dir="{logging_dir}"'
if not stop_text_encoder_training == 0: if not stop_text_encoder_training == 0:
run_cmd += ( run_cmd += (