Merge pull request #403 from bmaltais/dev

v21.2.5
This commit is contained in:
bmaltais 2023-03-19 20:07:41 -04:00 committed by GitHub
commit 1ac6892a2b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 567 additions and 133 deletions

View File

@ -189,6 +189,10 @@ This will store your a backup file with your current locally installed pip packa
## Change History
* 2023/03/19 (v21.2.5):
- Fix basic captioning logic
- Add possibility to not train TE in Dreamboot by setting `Step text encoder training` to -1.
- Update linux scripts
* 2023/03/12 (v21.2.4):
- Fix issue with kohya locon not training the convolution layers
- Update LyCORIS module version

View File

@ -377,7 +377,9 @@ def train_model(
print(f'max_train_steps = {max_train_steps}')
# calculate stop encoder training
if stop_text_encoder_training_pct == None:
if int(stop_text_encoder_training_pct) == -1:
stop_text_encoder_training = -1
elif stop_text_encoder_training_pct == None:
stop_text_encoder_training = 0
else:
stop_text_encoder_training = math.ceil(
@ -624,7 +626,7 @@ def dreambooth_tab(
placeholder='512,512',
)
stop_text_encoder_training = gr.Slider(
minimum=0,
minimum=-1,
maximum=100,
value=0,
step=1,

12
gui.sh
View File

@ -1,3 +1,13 @@
#!/bin/bash
# Activate the virtual environment
source venv/bin/activate
python kohya_gui.py
# Validate the requirements and store the exit code
python tools/validate_requirements.py
exit_code=$?
# If the exit code is 0, run the kohya_gui.py script with the command-line arguments
if [ $exit_code -eq 0 ]; then
python kohya_gui.py "$@"
fi

View File

@ -6,7 +6,9 @@ from finetune_gui import finetune_tab
from textual_inversion_gui import ti_tab
from library.utilities import utilities_tab
from library.extract_lora_gui import gradio_extract_lora_tab
from library.extract_lycoris_locon_gui import gradio_extract_lycoris_locon_tab
from library.merge_lora_gui import gradio_merge_lora_tab
from library.resize_lora_gui import gradio_resize_lora_tab
from lora_gui import lora_tab
@ -43,7 +45,9 @@ def UI(**kwargs):
enable_copy_info_button=True,
)
gradio_extract_lora_tab()
gradio_extract_lycoris_locon_tab()
gradio_merge_lora_tab()
gradio_resize_lora_tab()
# Show the interface
launch_kwargs = {}

View File

@ -6,35 +6,33 @@ import os
def caption_images(
caption_text_input,
images_dir_input,
overwrite_input,
caption_file_ext,
caption_text,
images_dir,
overwrite,
caption_ext,
prefix,
postfix,
find,
replace,
find_text,
replace_text,
):
# Check for images_dir_input
if images_dir_input == '':
# Check for images_dir
if not images_dir:
msgbox('Image folder is missing...')
return
if caption_file_ext == '':
if not caption_ext:
msgbox('Please provide an extension for the caption files.')
return
if not caption_text_input == '':
print(
f'Captioning files in {images_dir_input} with {caption_text_input}...'
)
if caption_text:
print(f'Captioning files in {images_dir} with {caption_text}...')
run_cmd = f'python "tools/caption.py"'
run_cmd += f' --caption_text="{caption_text_input}"'
if overwrite_input:
run_cmd += f' --caption_text="{caption_text}"'
if overwrite:
run_cmd += f' --overwrite'
if caption_file_ext != '':
run_cmd += f' --caption_file_ext="{caption_file_ext}"'
run_cmd += f' "{images_dir_input}"'
if caption_ext:
run_cmd += f' --caption_file_ext="{caption_ext}"'
run_cmd += f' "{images_dir}"'
print(run_cmd)
@ -44,24 +42,24 @@ def caption_images(
else:
subprocess.run(run_cmd)
if overwrite_input:
if not prefix == '' or not postfix == '':
if overwrite:
if prefix or postfix:
# Add prefix and postfix
add_pre_postfix(
folder=images_dir_input,
caption_file_ext=caption_file_ext,
folder=images_dir,
caption_file_ext=caption_ext,
prefix=prefix,
postfix=postfix,
)
if not find == '':
if find_text:
find_replace(
folder=images_dir_input,
caption_file_ext=caption_file_ext,
find=find,
replace=replace,
folder=images_dir,
caption_file_ext=caption_ext,
find=find_text,
replace=replace_text,
)
else:
if not prefix == '' or not postfix == '':
if prefix or postfix:
msgbox(
'Could not modify caption files with requested change because the "Overwrite existing captions in folder" option is not selected...'
)
@ -69,37 +67,31 @@ def caption_images(
print('...captioning done')
###
# Gradio UI
###
def gradio_basic_caption_gui_tab():
with gr.Tab('Basic Captioning'):
gr.Markdown(
'This utility will allow the creation of simple caption files for each images in a folder.'
'This utility will allow the creation of simple caption files for each image in a folder.'
)
with gr.Row():
images_dir_input = gr.Textbox(
images_dir = gr.Textbox(
label='Image folder to caption',
placeholder='Directory containing the images to caption',
interactive=True,
)
button_images_dir_input = gr.Button(
'📂', elem_id='open_folder_small'
)
button_images_dir_input.click(
folder_button = gr.Button('📂', elem_id='open_folder_small')
folder_button.click(
get_folder_path,
outputs=images_dir_input,
outputs=images_dir,
show_progress=False,
)
caption_file_ext = gr.Textbox(
caption_ext = gr.Textbox(
label='Caption file extension',
placeholder='Extention for caption file. eg: .caption, .txt',
placeholder='Extension for caption file. eg: .caption, .txt',
value='.txt',
interactive=True,
)
overwrite_input = gr.Checkbox(
overwrite = gr.Checkbox(
label='Overwrite existing captions in folder',
interactive=True,
value=False,
@ -110,7 +102,7 @@ def gradio_basic_caption_gui_tab():
placeholder='(Optional)',
interactive=True,
)
caption_text_input = gr.Textbox(
caption_text = gr.Textbox(
label='Caption text',
placeholder='Eg: , by some artist. Leave empty if you just want to add pre or postfix',
interactive=True,
@ -121,29 +113,28 @@ def gradio_basic_caption_gui_tab():
interactive=True,
)
with gr.Row():
find = gr.Textbox(
find_text = gr.Textbox(
label='Find text',
placeholder='Eg: , by some artist. Leave empty if you just want to add pre or postfix',
interactive=True,
)
replace = gr.Textbox(
replace_text = gr.Textbox(
label='Replacement text',
placeholder='Eg: , by some artist. Leave empty if you just want to replace with nothing',
interactive=True,
)
caption_button = gr.Button('Caption images')
caption_button.click(
caption_images,
inputs=[
caption_text,
images_dir,
overwrite,
caption_ext,
prefix,
postfix,
find_text,
replace_text,
],
show_progress=False,
)
caption_button = gr.Button('Caption images')
caption_button.click(
caption_images,
inputs=[
caption_text_input,
images_dir_input,
overwrite_input,
caption_file_ext,
prefix,
postfix,
find,
replace,
],
show_progress=False,
)

View File

@ -32,77 +32,112 @@ ALL_PRESET_MODELS = V2_BASE_MODELS + V_PARAMETERIZATION_MODELS + V1_MODELS
def update_my_data(my_data):
if my_data.get('use_8bit_adam', False) == True:
# Update optimizer based on use_8bit_adam flag
use_8bit_adam = my_data.get('use_8bit_adam', False)
if use_8bit_adam:
my_data['optimizer'] = 'AdamW8bit'
# my_data['use_8bit_adam'] = False
if (
my_data.get('optimizer', 'missing') == 'missing'
and my_data.get('use_8bit_adam', False) == False
):
elif 'optimizer' not in my_data:
my_data['optimizer'] = 'AdamW'
if my_data.get('model_list', 'custom') == []:
print('Old config with empty model list. Setting to custom...')
my_data['model_list'] = 'custom'
# If Pretrained model name or path is not one of the preset models then set the preset_model to custom
if not my_data.get('pretrained_model_name_or_path', '') in ALL_PRESET_MODELS:
my_data['model_list'] = 'custom'
# Fix old config files that contain epoch as str instead of int
# Update model_list to custom if empty or pretrained_model_name_or_path is not a preset model
model_list = my_data.get('model_list', [])
pretrained_model_name_or_path = my_data.get('pretrained_model_name_or_path', '')
if not model_list or pretrained_model_name_or_path not in ALL_PRESET_MODELS:
my_data['model_list'] = 'custom'
# Convert epoch and save_every_n_epochs values to int if they are strings
for key in ['epoch', 'save_every_n_epochs']:
value = my_data.get(key, -1)
if type(value) == str:
if value != '':
my_data[key] = int(value)
else:
my_data[key] = -1
if isinstance(value, str) and value:
my_data[key] = int(value)
elif not value:
my_data[key] = -1
# Update LoRA_type if it is set to LoCon
if my_data.get('LoRA_type', 'Standard') == 'LoCon':
my_data['LoRA_type'] = 'LyCORIS/LoCon'
return my_data
# def update_my_data(my_data):
# if my_data.get('use_8bit_adam', False) == True:
# my_data['optimizer'] = 'AdamW8bit'
# # my_data['use_8bit_adam'] = False
# if (
# my_data.get('optimizer', 'missing') == 'missing'
# and my_data.get('use_8bit_adam', False) == False
# ):
# my_data['optimizer'] = 'AdamW'
# if my_data.get('model_list', 'custom') == []:
# print('Old config with empty model list. Setting to custom...')
# my_data['model_list'] = 'custom'
# # If Pretrained model name or path is not one of the preset models then set the preset_model to custom
# if not my_data.get('pretrained_model_name_or_path', '') in ALL_PRESET_MODELS:
# my_data['model_list'] = 'custom'
# # Fix old config files that contain epoch as str instead of int
# for key in ['epoch', 'save_every_n_epochs']:
# value = my_data.get(key, -1)
# if type(value) == str:
# if value != '':
# my_data[key] = int(value)
# else:
# my_data[key] = -1
# if my_data.get('LoRA_type', 'Standard') == 'LoCon':
# my_data['LoRA_type'] = 'LyCORIS/LoCon'
# return my_data
def get_dir_and_file(file_path):
dir_path, file_name = os.path.split(file_path)
return (dir_path, file_name)
def has_ext_files(directory, extension):
# Iterate through all the files in the directory
for file in os.listdir(directory):
# If the file name ends with extension, return True
if file.endswith(extension):
return True
# If no extension files were found, return False
return False
# def has_ext_files(directory, extension):
# # Iterate through all the files in the directory
# for file in os.listdir(directory):
# # If the file name ends with extension, return True
# if file.endswith(extension):
# return True
# # If no extension files were found, return False
# return False
def get_file_path(
file_path='', defaultextension='.json', extension_name='Config files'
file_path='', default_extension='.json', extension_name='Config files'
):
current_file_path = file_path
# print(f'current file path: {current_file_path}')
initial_dir, initial_file = get_dir_and_file(file_path)
# Create a hidden Tkinter root window
root = Tk()
root.wm_attributes('-topmost', 1)
root.withdraw()
# Show the open file dialog and get the selected file path
file_path = filedialog.askopenfilename(
filetypes=(
(f'{extension_name}', f'{defaultextension}'),
('All files', '*'),
(extension_name, f'*{default_extension}'),
('All files', '*.*'),
),
defaultextension=defaultextension,
defaultextension=default_extension,
initialfile=initial_file,
initialdir=initial_dir,
)
# Destroy the hidden root window
root.destroy()
if file_path == '':
# If no file is selected, use the current file path
if not file_path:
file_path = current_file_path
return file_path
@ -230,52 +265,146 @@ def get_saveasfilename_path(
def add_pre_postfix(
folder='', prefix='', postfix='', caption_file_ext='.caption'
):
if not has_ext_files(folder, caption_file_ext):
msgbox(
f'No files with extension {caption_file_ext} were found in {folder}...'
)
return
folder: str = '',
prefix: str = '',
postfix: str = '',
caption_file_ext: str = '.caption'
) -> None:
"""
Add prefix and/or postfix to the content of caption files within a folder.
If no caption files are found, create one with the requested prefix and/or postfix.
Args:
folder (str): Path to the folder containing caption files.
prefix (str, optional): Prefix to add to the content of the caption files.
postfix (str, optional): Postfix to add to the content of the caption files.
caption_file_ext (str, optional): Extension of the caption files.
"""
if prefix == '' and postfix == '':
return
files = [f for f in os.listdir(folder) if f.endswith(caption_file_ext)]
if not prefix == '':
prefix = f'{prefix} '
if not postfix == '':
postfix = f' {postfix}'
image_extensions = ('.jpg', '.jpeg', '.png', '.webp')
image_files = [f for f in os.listdir(folder) if f.lower().endswith(image_extensions)]
for file in files:
with open(os.path.join(folder, file), 'r+') as f:
content = f.read()
content = content.rstrip()
f.seek(0, 0)
f.write(f'{prefix}{content}{postfix}')
f.close()
for image_file in image_files:
caption_file_name = os.path.splitext(image_file)[0] + caption_file_ext
caption_file_path = os.path.join(folder, caption_file_name)
if not os.path.exists(caption_file_path):
with open(caption_file_path, 'w') as f:
separator = ' ' if prefix and postfix else ''
f.write(f'{prefix}{separator}{postfix}')
else:
with open(caption_file_path, 'r+') as f:
content = f.read()
content = content.rstrip()
f.seek(0, 0)
prefix_separator = ' ' if prefix else ''
postfix_separator = ' ' if postfix else ''
f.write(f'{prefix}{prefix_separator}{content}{postfix_separator}{postfix}')
# def add_pre_postfix(
# folder='', prefix='', postfix='', caption_file_ext='.caption'
# ):
# if not has_ext_files(folder, caption_file_ext):
# msgbox(
# f'No files with extension {caption_file_ext} were found in {folder}...'
# )
# return
# if prefix == '' and postfix == '':
# return
# files = [f for f in os.listdir(folder) if f.endswith(caption_file_ext)]
# if not prefix == '':
# prefix = f'{prefix} '
# if not postfix == '':
# postfix = f' {postfix}'
# for file in files:
# with open(os.path.join(folder, file), 'r+') as f:
# content = f.read()
# content = content.rstrip()
# f.seek(0, 0)
# f.write(f'{prefix} {content} {postfix}')
# f.close()
def find_replace(folder='', caption_file_ext='.caption', find='', replace=''):
def has_ext_files(folder_path: str, file_extension: str) -> bool:
"""
Check if there are any files with the specified extension in the given folder.
Args:
folder_path (str): Path to the folder containing files.
file_extension (str): Extension of the files to look for.
Returns:
bool: True if files with the specified extension are found, False otherwise.
"""
for file in os.listdir(folder_path):
if file.endswith(file_extension):
return True
return False
def find_replace(
folder_path: str = '',
caption_file_ext: str = '.caption',
search_text: str = '',
replace_text: str = ''
) -> None:
"""
Find and replace text in caption files within a folder.
Args:
folder_path (str, optional): Path to the folder containing caption files.
caption_file_ext (str, optional): Extension of the caption files.
search_text (str, optional): Text to search for in the caption files.
replace_text (str, optional): Text to replace the search text with.
"""
print('Running caption find/replace')
if not has_ext_files(folder, caption_file_ext):
if not has_ext_files(folder_path, caption_file_ext):
msgbox(
f'No files with extension {caption_file_ext} were found in {folder}...'
f'No files with extension {caption_file_ext} were found in {folder_path}...'
)
return
if find == '':
if search_text == '':
return
files = [f for f in os.listdir(folder) if f.endswith(caption_file_ext)]
for file in files:
with open(os.path.join(folder, file), 'r', errors='ignore') as f:
caption_files = [f for f in os.listdir(folder_path) if f.endswith(caption_file_ext)]
for caption_file in caption_files:
with open(os.path.join(folder_path, caption_file), 'r', errors='ignore') as f:
content = f.read()
f.close
content = content.replace(find, replace)
with open(os.path.join(folder, file), 'w') as f:
content = content.replace(search_text, replace_text)
with open(os.path.join(folder_path, caption_file), 'w') as f:
f.write(content)
f.close()
# def find_replace(folder='', caption_file_ext='.caption', find='', replace=''):
# print('Running caption find/replace')
# if not has_ext_files(folder, caption_file_ext):
# msgbox(
# f'No files with extension {caption_file_ext} were found in {folder}...'
# )
# return
# if find == '':
# return
# files = [f for f in os.listdir(folder) if f.endswith(caption_file_ext)]
# for file in files:
# with open(os.path.join(folder, file), 'r', errors='ignore') as f:
# content = f.read()
# f.close
# content = content.replace(find, replace)
# with open(os.path.join(folder, file), 'w') as f:
# f.write(content)
# f.close()
def color_aug_changed(color_aug):

View File

@ -0,0 +1,273 @@
import gradio as gr
from easygui import msgbox
import subprocess
import os
from .common_gui import (
get_saveasfilename_path,
get_any_file_path,
get_file_path,
)
folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
save_style_symbol = '\U0001f4be' # 💾
document_symbol = '\U0001F4C4' # 📄
PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
def extract_lycoris_locon(
db_model, base_model, output_name, device,
is_v2, mode, linear_dim, conv_dim,
linear_threshold, conv_threshold,
linear_ratio, conv_ratio,
linear_quantile, conv_quantile,
use_sparse_bias, sparsity, disable_cp
):
# Check for caption_text_input
if db_model == '':
msgbox('Invalid finetuned model file')
return
if base_model == '':
msgbox('Invalid base model file')
return
# Check if source model exist
if not os.path.isfile(db_model):
msgbox('The provided finetuned model is not a file')
return
if not os.path.isfile(base_model):
msgbox('The provided base model is not a file')
return
run_cmd = (
f'{PYTHON} "{os.path.join("tools","lycoris_locon_extract.py")}"'
)
if is_v2:
run_cmd += f' --is_v2'
run_cmd += f' --device {device}'
run_cmd += f' --mode {mode}'
run_cmd += f' --safetensors'
run_cmd += f' --linear_dim {linear_dim}'
run_cmd += f' --conv_dim {conv_dim}'
run_cmd += f' --linear_threshold {linear_threshold}'
run_cmd += f' --conv_threshold {conv_threshold}'
run_cmd += f' --linear_ratio {linear_ratio}'
run_cmd += f' --conv_ratio {conv_ratio}'
run_cmd += f' --linear_quantile {linear_quantile}'
run_cmd += f' --conv_quantile {conv_quantile}'
if use_sparse_bias:
run_cmd += f' --use_sparse_bias'
run_cmd += f' --sparsity {sparsity}'
if disable_cp:
run_cmd += f' --disable_cp'
run_cmd += f' "{base_model}"'
run_cmd += f' "{db_model}"'
run_cmd += f' "{output_name}"'
print(run_cmd)
# Run the command
if os.name == 'posix':
os.system(run_cmd)
else:
subprocess.run(run_cmd)
###
# Gradio UI
###
# def update_mode(mode):
# # 'fixed', 'threshold','ratio','quantile'
# if mode == 'fixed':
# return gr.Row.update(visible=True), gr.Row.update(visible=False), gr.Row.update(visible=False), gr.Row.update(visible=False)
# if mode == 'threshold':
# return gr.Row.update(visible=False), gr.Row.update(visible=True), gr.Row.update(visible=False), gr.Row.update(visible=False)
# if mode == 'ratio':
# return gr.Row.update(visible=False), gr.Row.update(visible=False), gr.Row.update(visible=True), gr.Row.update(visible=False)
# if mode == 'threshold':
# return gr.Row.update(visible=False), gr.Row.update(visible=False), gr.Row.update(visible=False), gr.Row.update(visible=True)
def update_mode(mode):
# Create a list of possible mode values
modes = ['fixed', 'threshold', 'ratio', 'quantile']
# Initialize an empty list to store visibility updates
updates = []
# Iterate through the possible modes
for m in modes:
# Add a visibility update for each mode, setting it to True if the input mode matches the current mode in the loop
updates.append(gr.Row.update(visible=(mode == m)))
# Return the visibility updates as a tuple
return tuple(updates)
def gradio_extract_lycoris_locon_tab():
with gr.Tab('Extract LyCORIS LoCON'):
gr.Markdown(
'This utility can extract a LyCORIS LoCon network from a finetuned model.'
)
lora_ext = gr.Textbox(value='*.safetensors', visible=False) # lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False)
lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
model_ext = gr.Textbox(value='*.safetensors *.ckpt', visible=False)
model_ext_name = gr.Textbox(value='Model types', visible=False)
with gr.Row():
db_model = gr.Textbox(
label='Finetuned model',
placeholder='Path to the finetuned model to extract',
interactive=True,
)
button_db_model_file = gr.Button(
folder_symbol, elem_id='open_folder_small'
)
button_db_model_file.click(
get_file_path,
inputs=[db_model, model_ext, model_ext_name],
outputs=db_model,
show_progress=False,
)
base_model = gr.Textbox(
label='Stable Diffusion base model',
placeholder='Stable Diffusion original model: ckpt or safetensors file',
interactive=True,
)
button_base_model_file = gr.Button(
folder_symbol, elem_id='open_folder_small'
)
button_base_model_file.click(
get_file_path,
inputs=[base_model, model_ext, model_ext_name],
outputs=base_model,
show_progress=False,
)
with gr.Row():
output_name = gr.Textbox(
label='Save to',
placeholder='path where to save the extracted LoRA model...',
interactive=True,
)
button_output_name = gr.Button(
folder_symbol, elem_id='open_folder_small'
)
button_output_name.click(
get_saveasfilename_path,
inputs=[output_name, lora_ext, lora_ext_name],
outputs=output_name,
show_progress=False,
)
device = gr.Dropdown(
label='Device',
choices=['cpu', 'cuda',],
value='cuda',
interactive=True,
)
is_v2 = gr.Checkbox(label='is v2', value=False, interactive=True)
mode = gr.Dropdown(
label='Mode',
choices=['fixed', 'threshold','ratio','quantile'],
value='fixed',
interactive=True,
)
with gr.Row(visible=True) as fixed:
linear_dim = gr.Slider(
minimum=1,
maximum=1024,
label='Network Dimension',
value=1,
step=1,
interactive=True,
)
conv_dim = gr.Slider(
minimum=1,
maximum=1024,
label='Conv Dimension',
value=1,
step=1,
interactive=True,
)
with gr.Row(visible=False) as threshold:
linear_threshold = gr.Slider(
minimum=0,
maximum=1,
label='Linear threshold',
value=0,
step=0.01,
interactive=True,
)
conv_threshold = gr.Slider(
minimum=0,
maximum=1,
label='Conv threshold',
value=0,
step=0.01,
interactive=True,
)
with gr.Row(visible=False) as ratio:
linear_ratio = gr.Slider(
minimum=0,
maximum=1,
label='Linear ratio',
value=0,
step=0.01,
interactive=True,
)
conv_ratio = gr.Slider(
minimum=0,
maximum=1,
label='Conv ratio',
value=0,
step=0.01,
interactive=True,
)
with gr.Row(visible=False) as quantile:
linear_quantile = gr.Slider(
minimum=0,
maximum=1,
label='Linear quantile',
value=0.75,
step=0.01,
interactive=True,
)
conv_quantile = gr.Slider(
minimum=0,
maximum=1,
label='Conv quantile',
value=0.75,
step=0.01,
interactive=True,
)
with gr.Row():
use_sparse_bias = gr.Checkbox(label='Use sparse biais', value=False, interactive=True)
sparsity = gr.Slider(
minimum=0,
maximum=1,
label='Sparsity',
value=0.98,
step=0.01,
interactive=True,
)
disable_cp = gr.Checkbox(label='Disable CP decomposition', value=False, interactive=True)
mode.change(
update_mode,
inputs=[mode],
outputs=[
fixed, threshold, ratio, quantile,
]
)
extract_button = gr.Button('Extract LyCORIS LoCon')
extract_button.click(
extract_lycoris_locon,
inputs=[db_model, base_model, output_name, device,
is_v2, mode, linear_dim, conv_dim,
linear_threshold, conv_threshold,
linear_ratio, conv_ratio,
linear_quantile, conv_quantile,
use_sparse_bias, sparsity, disable_cp],
show_progress=False,
)

View File

@ -417,13 +417,16 @@ def train_model(
or f.endswith('.webp')
]
)
print(f'Folder {folder}: {num_images} images found')
# Calculate the total number of steps for this folder
steps = repeats * num_images
total_steps += steps
# Print the result
print(f'Folder {folder}: {steps} steps')
total_steps += steps
# calculate max_train_steps
max_train_steps = int(

View File

@ -116,7 +116,7 @@ def main():
linear_mode_param, conv_mode_param,
args.device,
args.use_sparse_bias, args.sparsity,
# not args.disable_small_conv
not args.disable_cp
)
if args.safetensors:

View File

@ -1,3 +1,4 @@
import os
import sys
import pkg_resources
@ -32,7 +33,8 @@ if missing_requirements or wrong_version_requirements:
print("Error: The following packages have the wrong version:")
for requirement, expected_version, actual_version in wrong_version_requirements:
print(f" - {requirement} (expected version {expected_version}, found version {actual_version})")
print('\nRun \033[33mupgrade.ps1\033[0m or \033[33mpip install -U -r requirements.txt\033[0m to resolve the missing requirements listed above...')
upgrade_script = "upgrade.ps1" if os.name == "nt" else "upgrade.sh"
print(f"\nRun \033[33m{upgrade_script}\033[0m or \033[33mpip install -U -r requirements.txt\033[0m to resolve the missing requirements listed above...")
sys.exit(1)

16
upgrade.sh Executable file
View File

@ -0,0 +1,16 @@
#!/bin/bash
# Check if there are any changes that need to be committed
if [[ -n $(git status --short) ]]; then
echo "There are changes that need to be committed. Please stash or undo your changes before running this script." >&2
exit 1
fi
# Pull the latest changes from the remote repository
git pull
# Activate the virtual environment
source venv/bin/activate
# Upgrade the required packages
pip install --upgrade -r requirements.txt