Merge branch 'dev' into min-sr

This commit is contained in:
bmaltais 2023-03-27 13:25:29 -04:00
commit 7b228b7f51
43 changed files with 732 additions and 439 deletions

5
.gitignore vendored
View File

@ -1,10 +1,11 @@
venv venv
__pycache__ __pycache__
*.txt
cudnn_windows cudnn_windows
.vscode .vscode
*.egg-info *.egg-info
build build
wd14_tagger_model wd14_tagger_model
.DS_Store .DS_Store
locon locon
gui-user.bat
gui-user.ps1

173
README.md
View File

@ -44,6 +44,17 @@ If you run on Linux and would like to use the GUI, there is now a port of it as
### Runpod ### Runpod
Follow the instructions found in this discussion: https://github.com/bmaltais/kohya_ss/discussions/379 Follow the instructions found in this discussion: https://github.com/bmaltais/kohya_ss/discussions/379
### MacOS
In the terminal, run
```
git clone https://github.com/bmaltais/kohya_ss.git
cd kohya_ss
bash macos_setup.sh
```
During the accelerate config screen after running the script answer "This machine", "None", "No" for the remaining questions.
### Ubuntu ### Ubuntu
In the terminal, run In the terminal, run
@ -99,7 +110,17 @@ Run the following commands to install:
python .\tools\cudann_1.8_install.py python .\tools\cudann_1.8_install.py
``` ```
## Upgrading ## Upgrading MacOS
When a new release comes out, you can upgrade your repo with the following commands in the root directory:
```bash
upgrade_macos.sh
```
Once the commands have completed successfully you should be ready to use the new version. MacOS support is not tested and has been mostly taken from https://gist.github.com/jstayco/9f5733f05b9dc29de95c4056a023d645
## Upgrading Windows
When a new release comes out, you can upgrade your repo with the following commands in the root directory: When a new release comes out, you can upgrade your repo with the following commands in the root directory:
@ -192,7 +213,24 @@ This will store your a backup file with your current locally installed pip packa
## Change History ## Change History
* 2023/03/19 (v21.3.0) * 2023/03/26 (v21.3.6)
- Fixed the error while images are ended with capital image extensions. Thanks to @kvzn. https://github.com/bmaltais/kohya_ss/pull/454
* 2023/03/26 (v21.3.5)
- Fix for https://github.com/bmaltais/kohya_ss/issues/230
- Added detection for Google Colab to not bring up the GUI file/folder window on the platform. Instead it will only use the file/folder path provided in the input field.
* 2023/03/25 (v21.3.4)
- Added untested support for MacOS base on this gist: https://gist.github.com/jstayco/9f5733f05b9dc29de95c4056a023d645
Let me know how this work. From the look of it it appear to be well tought out. I modified a few things to make it fit better with the rest of the code in the repo.
- Fix for issue https://github.com/bmaltais/kohya_ss/issues/433 by implementing default of 0.
- Removed non applicable save_model_as choices for LoRA and TI.
* 2023/03/24 (v21.3.3)
- Add support for custom user gui files. THey will be created at installation time or when upgrading is missing. You will see two files in the root of the folder. One named `gui-user.bat` and the other `gui-user.ps1`. Edit the file based on your prefered terminal. Simply add the parameters you want to pass the gui in there and execute it to start the gui with them. Enjoy!
* 2023/03/23 (v21.3.2)
- Fix issue reported: https://github.com/bmaltais/kohya_ss/issues/439
* 2023/03/23 (v21.3.1)
- Merge PR to fix refactor naming issue for basic captions. Thank @zrma
* 2023/03/22 (v21.3.0)
- Add a function to load training config with `.toml` to each training script. Thanks to Linaqruf for this great contribution! - Add a function to load training config with `.toml` to each training script. Thanks to Linaqruf for this great contribution!
- Specify `.toml` file with `--config_file`. `.toml` file has `key=value` entries. Keys are same as command line options. See [#241](https://github.com/kohya-ss/sd-scripts/pull/241) for details. - Specify `.toml` file with `--config_file`. `.toml` file has `key=value` entries. Keys are same as command line options. See [#241](https://github.com/kohya-ss/sd-scripts/pull/241) for details.
- All sub-sections are combined to a single dictionary (the section names are ignored.) - All sub-sections are combined to a single dictionary (the section names are ignored.)
@ -205,125 +243,12 @@ This will store your a backup file with your current locally installed pip packa
- `( )`, `(xxxx:1.2)` and `[ ]` can be used. - `( )`, `(xxxx:1.2)` and `[ ]` can be used.
- Fix exception on training model in diffusers format with `train_network.py` Thanks to orenwang! [#290](https://github.com/kohya-ss/sd-scripts/pull/290) - Fix exception on training model in diffusers format with `train_network.py` Thanks to orenwang! [#290](https://github.com/kohya-ss/sd-scripts/pull/290)
- Add warning if you are about to overwrite an existing model: https://github.com/bmaltais/kohya_ss/issues/404 - Add warning if you are about to overwrite an existing model: https://github.com/bmaltais/kohya_ss/issues/404
* 2023/03/19 (v21.2.5): - Add `--vae_batch_size` for faster latents caching to each training script. This batches VAE calls.
- Fix basic captioning logic - Please start with`2` or `4` depending on the size of VRAM.
- Add possibility to not train TE in Dreamboot by setting `Step text encoder training` to -1. - Fix a number of training steps with `--gradient_accumulation_steps` and `--max_train_epochs`. Thanks to tsukimiya!
- Update linux scripts - Extract parser setup to external scripts. Thanks to robertsmieja!
* 2023/03/12 (v21.2.4): - Fix an issue without `.npz` and with `--full_path` in training.
- Fix issue with kohya locon not training the convolution layers - Support extensions with upper cases for images for not Windows environment.
- Update LyCORIS module version - Fix `resize_lora.py` to work with LoRA with dynamic rank (including `conv_dim != network_dim`). Thanks to toshiaki!
- Update LyCORYS locon extract tool - Fix issue: https://github.com/bmaltais/kohya_ss/issues/406
* 2023/03/12 (v21.2.3): - Add device support to LoRA extract.
- Add validation that all requirements are met before starting the GUI.
* 2023/03/11 (v21.2.2):
- Add support for LoRA LoHa type. See https://github.com/KohakuBlueleaf/LyCORIS for more details.
* 2023/03/10 (v21.2.1):
- Update to latest sd-script code
- Add support for SVD based LoRA merge
* 2023/03/09 (v21.2.0):
- Fix issue https://github.com/bmaltais/kohya_ss/issues/335
- Add option to print LoRA trainer command without executing it
- Add support for samples during trainin via a new `Sample images config` accordion in the `Training parameters` tab.
- Added new `Additional parameters` under the `Advanced Configuration` section of the `Training parameters` tab to allow for the specifications of parameters not handles by the GUI.
- Added support for sample as a new Accordion under the `Training parameters` tab. More info about the prompt options can be found here: https://github.com/kohya-ss/sd-scripts/issues/256#issuecomment-1455005709
- There may be problems due to major changes. If you cannot revert back to the previous version when problems occur, please do not update for a while.
- Minimum metadata (module name, dim, alpha and network_args) is recorded even with `--no_metadata`, issue https://github.com/kohya-ss/sd-scripts/issues/254
- `train_network.py` supports LoRA for Conv2d-3x3 (extended to conv2d with a kernel size not 1x1).
- Same as a current version of [LoCon](https://github.com/KohakuBlueleaf/LoCon). __Thank you very much KohakuBlueleaf for your help!__
- LoCon will be enhanced in the future. Compatibility for future versions is not guaranteed.
- Specify `--network_args` option like: `--network_args "conv_dim=4" "conv_alpha=1"`
- [Additional Networks extension](https://github.com/kohya-ss/sd-webui-additional-networks) version 0.5.0 or later is required to use 'LoRA for Conv2d-3x3' in Stable Diffusion web UI.
- __Stable Diffusion web UI built-in LoRA does not support 'LoRA for Conv2d-3x3' now. Consider carefully whether or not to use it.__
- Merging/extracting scripts also support LoRA for Conv2d-3x3.
- Free CUDA memory after sample generation to reduce VRAM usage, issue https://github.com/kohya-ss/sd-scripts/issues/260
- Empty caption doesn't cause error now, issue https://github.com/kohya-ss/sd-scripts/issues/258
- Fix sample generation is crashing in Textual Inversion training when using templates, or if height/width is not divisible by 8.
- Update documents (Japanese only).
- Dependencies are updated, Please [upgrade](#upgrade) the repo.
- Add detail dataset config feature by extra config file. Thanks to fur0ut0 for this great contribution!
- Documentation is [here](https://github-com.translate.goog/kohya-ss/sd-scripts/blob/main/config_README-ja.md) (only in Japanese currently.)
- Specify `.toml` file with `--dataset_config` option.
- The options supported under the previous release can be used as is instead of the `.toml` config file.
- There might be bugs due to the large scale of update, please report any problems if you find at https://github.com/kohya-ss/sd-scripts/issues.
- Add feature to generate sample images in the middle of training for each training scripts.
- `--sample_every_n_steps` and `--sample_every_n_epochs` options: frequency to generate.
- `--sample_prompts` option: the file contains prompts (each line generates one image.)
- The prompt is subset of `gen_img_diffusers.py`. The prompt options `w, h, d, l, s, n` are supported.
- `--sample_sampler` option: sampler (scheduler) for generating, such as ddim or k_euler. See help for useable samplers.
- Add `--tokenizer_cache_dir` to each training and generation scripts to cache Tokenizer locally from Diffusers.
- Scripts will support offline training/generation after caching.
- Support letents upscaling for highres. fix, and VAE batch size in `gen_img_diffusers.py` (no documentation yet.)
- Sample image generation:
A prompt file might look like this, for example
```
# prompt 1
masterpiece, best quality, 1girl, in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28
# prompt 2
masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40
```
Lines beginning with `#` are comments. You can specify options for the generated image with options like `--n` after the prompt. The following can be used.
* `--n` Negative prompt up to the next option.
* `--w` Specifies the width of the generated image.
* `--h` Specifies the height of the generated image.
* `--d` Specifies the seed of the generated image.
* `--l` Specifies the CFG scale of the generated image.
* `--s` Specifies the number of steps in the generation.
The prompt weighting such as `( )` and `[ ]` are not working.
Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates.
* 2023/03/05 (v21.1.5):
- Add replace underscore with space option to WD14 captioning. Thanks @sALTaccount!
- Improve how custom preset is set and handles.
- Add support for `--listen` argument. This allow gradio to listen for connections from other devices on the network (or internet). For example: `gui.ps1 --listen "0.0.0.0"` will allow anyone to connect to the gradio webui.
- Updated `Resize LoRA` tab to support LoCon resizing. Added new resize
* 2023/03/05 (v21.1.4):
- Removing legacy and confusing use 8bit adam chackbox. It is now configured using the Optimiser drop down list. It will be set properly based on legacy config files.
* 2023/03/04 (v21.1.3):
- Fix progress bar being displayed when not required.
- Add support for linux, thank you @devNegative-asm
* 2023/03/03 (v21.1.2):
- Fix issue https://github.com/bmaltais/kohya_ss/issues/277
- Fix issue https://github.com/bmaltais/kohya_ss/issues/278 introduce by LoCon project switching to pip module. Make sure to run upgrade.ps1 to install the latest pip requirements for LoCon support.
* 2023/03/02 (v21.1.1):
- Emergency fix for https://github.com/bmaltais/kohya_ss/issues/261
* 2023/03/02 (v21.1.0):
- Add LoCon support (https://github.com/KohakuBlueleaf/LoCon.git) to the Dreambooth LoRA tab. This will allow to create a new type of LoRA that include conv layers as part of the LoRA... hence the name LoCon. LoCon will work with the native Auto1111 implementation of LoRA. If you want to use it with the Kohya_ss additionalNetwork you will need to install this other extension... until Kohya_ss support it natively: https://github.com/KohakuBlueleaf/a1111-sd-webui-locon
* 2023/03/01 (v21.0.1):
- Add warning to tensorboard start if the log information is missing
- Fix issue with 8bitadam on older config file load
* 2023/02/27 (v21.0.0):
- Add tensorboard start and stop support to the GUI
* 2023/02/26 (v20.8.2):
- Fix issue https://github.com/bmaltais/kohya_ss/issues/231
- Change default for seed to random
- Add support for --share argument to `kohya_gui.py` and `gui.ps1`
- Implement 8bit adam login to help with the legacy `Use 8bit adam` checkbox that is now superceided by the `Optimizer` dropdown selection. This field will be eventually removed. Kept for now for backward compatibility.
* 2023/02/23 (v20.8.1):
- Fix instability training issue in `train_network.py`.
- `fp16` training is probably not affected by this issue.
- Training with `float` for SD2.x models will work now. Also training with bf16 might be improved.
- This issue seems to have occurred in [PR#190](https://github.com/kohya-ss/sd-scripts/pull/190).
- Add some metadata to LoRA model. Thanks to space-nuko!
- Raise an error if optimizer options conflict (e.g. `--optimizer_type` and `--use_8bit_adam`.)
- Support ControlNet in `gen_img_diffusers.py` (no documentation yet.)
* 2023/02/22 (v20.8.0):
- Add gui support for optimizers: `AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, AdaFactor`
- Add gui support for `--noise_offset`
- Refactor optmizer options. Thanks to mgz-dev!
- Add `--optimizer_type` option for each training script. Please see help. Japanese documentation is [here](https://github-com.translate.goog/kohya-ss/sd-scripts/blob/main/train_network_README-ja.md?_x_tr_sl=fr&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#%E3%82%AA%E3%83%97%E3%83%86%E3%82%A3%E3%83%9E%E3%82%A4%E3%82%B6%E3%81%AE%E6%8C%87%E5%AE%9A%E3%81%AB%E3%81%A4%E3%81%84%E3%81%A6).
- `--use_8bit_adam` and `--use_lion_optimizer` options also work and will override the options above for backward compatibility.
- Add SGDNesterov and its 8bit.
- Add [D-Adaptation](https://github.com/facebookresearch/dadaptation) optimizer. Thanks to BootsofLagrangian and all!
- Please install D-Adaptation optimizer with `pip install dadaptation` (it is not in requirements.txt currently.)
- Please see https://github.com/kohya-ss/sd-scripts/issues/181 for details.
- Add AdaFactor optimizer. Thanks to Toshiaki!
- Extra lr scheduler settings (num_cycles etc.) are working in training scripts other than `train_network.py`.
- Add `--max_grad_norm` option for each training script for gradient clipping. `0.0` disables clipping.
- Symbolic link can be loaded in each training script. Thanks to TkskKurumi!

View File

@ -107,6 +107,7 @@ def save_configuration(
sample_sampler, sample_sampler,
sample_prompts, sample_prompts,
additional_parameters, additional_parameters,
vae_batch_size,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -214,6 +215,7 @@ def open_configuration(
sample_sampler, sample_sampler,
sample_prompts, sample_prompts,
additional_parameters, additional_parameters,
vae_batch_size,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -303,6 +305,7 @@ def train_model(
sample_sampler, sample_sampler,
sample_prompts, sample_prompts,
additional_parameters, additional_parameters,
vae_batch_size,
): ):
if pretrained_model_name_or_path == '': if pretrained_model_name_or_path == '':
msgbox('Source model information is missing') msgbox('Source model information is missing')
@ -328,38 +331,53 @@ def train_model(
if check_if_model_exist(output_name, output_dir, save_model_as): if check_if_model_exist(output_name, output_dir, save_model_as):
return return
# Get a list of all subfolders in train_data_dir # Get a list of all subfolders in train_data_dir, excluding hidden folders
subfolders = [ subfolders = [
f f
for f in os.listdir(train_data_dir) for f in os.listdir(train_data_dir)
if os.path.isdir(os.path.join(train_data_dir, f)) if os.path.isdir(os.path.join(train_data_dir, f)) and not f.startswith('.')
] ]
# Check if subfolders are present. If not let the user know and return
if not subfolders:
print('\033[33mNo subfolders were found in', train_data_dir, ' can\'t train\...033[0m')
return
total_steps = 0 total_steps = 0
# Loop through each subfolder and extract the number of repeats # Loop through each subfolder and extract the number of repeats
for folder in subfolders: for folder in subfolders:
# Extract the number of repeats from the folder name # Extract the number of repeats from the folder name
repeats = int(folder.split('_')[0]) try:
repeats = int(folder.split('_')[0])
except ValueError:
print('\033[33mSubfolder', folder, 'does not have a proper repeat value, please correct the name or remove it... can\'t train...\033[0m')
continue
# Count the number of images in the folder # Count the number of images in the folder
num_images = len( num_images = len(
[ [
f f
for f in os.listdir(os.path.join(train_data_dir, folder)) for f, lower_f in (
if f.endswith('.jpg') (file, file.lower()) for file in os.listdir(os.path.join(train_data_dir, folder))
or f.endswith('.jpeg') )
or f.endswith('.png') if lower_f.endswith(('.jpg', '.jpeg', '.png', '.webp'))
or f.endswith('.webp')
] ]
) )
if num_images == 0:
print(f'{folder} folder contain no images, skipping...')
else:
# Calculate the total number of steps for this folder
steps = repeats * num_images
total_steps += steps
# Calculate the total number of steps for this folder # Print the result
steps = repeats * num_images print('\033[33mFolder', folder, ':', steps, 'steps\033[0m')
total_steps += steps
# Print the result if total_steps == 0:
print(f'Folder {folder}: {steps} steps') print('\033[33mNo images were found in folder', train_data_dir, '... please rectify!\033[0m')
return
# Print the result # Print the result
# print(f"{total_steps} total steps") # print(f"{total_steps} total steps")
@ -367,9 +385,7 @@ def train_model(
if reg_data_dir == '': if reg_data_dir == '':
reg_factor = 1 reg_factor = 1
else: else:
print( print('\033[94mRegularisation images are used... Will double the number of steps required...\033[0m')
'Regularisation images are used... Will double the number of steps required...'
)
reg_factor = 2 reg_factor = 2
# calculate max_train_steps # calculate max_train_steps
@ -480,6 +496,7 @@ def train_model(
caption_dropout_rate=caption_dropout_rate, caption_dropout_rate=caption_dropout_rate,
noise_offset=noise_offset, noise_offset=noise_offset,
additional_parameters=additional_parameters, additional_parameters=additional_parameters,
vae_batch_size=vae_batch_size,
) )
run_cmd += run_cmd_sample( run_cmd += run_cmd_sample(
@ -686,6 +703,7 @@ def dreambooth_tab(
caption_dropout_rate, caption_dropout_rate,
noise_offset, noise_offset,
additional_parameters, additional_parameters,
vae_batch_size,
) = gradio_advanced_training() ) = gradio_advanced_training()
color_aug.change( color_aug.change(
color_aug_changed, color_aug_changed,
@ -786,6 +804,7 @@ def dreambooth_tab(
sample_sampler, sample_sampler,
sample_prompts, sample_prompts,
additional_parameters, additional_parameters,
vae_batch_size,
] ]
button_open_config.click( button_open_config.click(

View File

@ -0,0 +1,13 @@
python tools\lycoris_locon_extract.py --mode quantile --safetensors --linear_ratio 0.9 --conv_ratio 0.9 --device cuda D:/models/v1-5-pruned.ckpt D:/models/cyberrealistic_v12.safetensors "D:/lora/sd1.5/cyberrealistic_v12.safetensors"
python tools\lycoris_locon_extract.py --mode quantile --safetensors --linear_quantile 0.75 --conv_quantile 0.75 --device cuda D:/models/v1-5-pruned.ckpt "C:\Users\berna\Downloads\deliberate_v2.safetensors" "D:/lora/sd1.5/deliberate_v2.safetensors"
python tools\lycoris_locon_extract.py --mode fixed --safetensors --linear_dim 512 --conv_dim 512 --device cuda D:/models/v1-5-pruned.ckpt D:/models/cyberrealistic_v12.safetensors "D:/lora/sd1.5/cyberrealistic_v12.safetensors"
python tools\lycoris_locon_extract.py --use_sparse_bias --sparsity 0.98 --mode quantile --safetensors --linear_quantile 0.75 --conv_quantile 0.75 --device cuda D:/models/v1-5-pruned.ckpt "C:\Users\berna\Downloads\deliberate_v2.safetensors" "D:/lora/sd1.5/deliberate_v2.safetensors"
python tools\lycoris_locon_extract.py --use_sparse_bias --sparsity 0.98 --mode quantile --safetensors --linear_quantile 0.75 --conv_quantile 0.75 --device cuda D:/models/v1-5-pruned.ckpt "D:/models/test\claire_v1.0ee2-000003.safetensors" "D:/lora/sd1.5/claire_v1.0ee2-000003.safetensors"
python tools\lycoris_locon_extract.py --use_sparse_bias --sparsity 0.98 --mode quantile --safetensors --linear_quantile 0.5 --conv_quantile 0.5 --device cuda D:/models/v1-5-pruned.ckpt "D:/models/test\claire_v1.0ee2-000003.safetensors" "D:/lora/sd1.5/claire_v1.0ee2-0.5.safetensors"
python tools\lycoris_locon_extract.py --use_sparse_bias --sparsity 0.98 --mode quantile --safetensors --linear_quantile 0.5 --conv_quantile 0.5 --device cuda D:/models/v1-5-pruned.ckpt "D:/models/test\claire_v1.0f.safetensors" "D:/lora/sd1.5/claire_v1.0f0.5.safetensors"

View File

@ -138,7 +138,7 @@ def train(args):
vae.requires_grad_(False) vae.requires_grad_(False)
vae.eval() vae.eval()
with torch.no_grad(): with torch.no_grad():
train_dataset_group.cache_latents(vae) train_dataset_group.cache_latents(vae, args.vae_batch_size)
vae.to("cpu") vae.to("cpu")
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()
@ -194,7 +194,7 @@ def train(args):
# 学習ステップ数を計算する # 学習ステップ数を計算する
if args.max_train_epochs is not None: if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * len(train_dataloader) args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# lr schedulerを用意する # lr schedulerを用意する
@ -240,7 +240,7 @@ def train(args):
print(f" num epochs / epoch数: {num_train_epochs}") print(f" num epochs / epoch数: {num_train_epochs}")
print(f" batch size per device / バッチサイズ: {args.train_batch_size}") print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
@ -387,7 +387,7 @@ def train(args):
print("model saved.") print("model saved.")
if __name__ == "__main__": def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
train_util.add_sd_models_arguments(parser) train_util.add_sd_models_arguments(parser)
@ -400,6 +400,12 @@ if __name__ == "__main__":
parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する") parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()
args = train_util.read_config_from_file(args, parser) args = train_util.read_config_from_file(args, parser)

View File

@ -163,13 +163,19 @@ def main(args):
print("done!") print("done!")
if __name__ == '__main__': def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
# parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") # parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル") parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
parser.add_argument("--debug", action="store_true", help="debug mode") parser.add_argument("--debug", action="store_true", help="debug mode")
return parser
if __name__ == '__main__':
parser = setup_parser()
args, unknown = parser.parse_known_args() args, unknown = parser.parse_known_args()
if len(unknown) == 1: if len(unknown) == 1:
print("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.") print("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.")

View File

@ -133,7 +133,7 @@ def main(args):
print("done!") print("done!")
if __name__ == '__main__': def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("--caption_weights", type=str, default="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth", parser.add_argument("--caption_weights", type=str, default="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth",
@ -153,6 +153,12 @@ if __name__ == '__main__':
parser.add_argument('--seed', default=42, type=int, help='seed for reproducibility / 再現性を確保するための乱数seed') parser.add_argument('--seed', default=42, type=int, help='seed for reproducibility / 再現性を確保するための乱数seed')
parser.add_argument("--debug", action="store_true", help="debug mode") parser.add_argument("--debug", action="store_true", help="debug mode")
return parser
if __name__ == '__main__':
parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()
# スペルミスしていたオプションを復元する # スペルミスしていたオプションを復元する

View File

@ -127,7 +127,7 @@ def main(args):
print("done!") print("done!")
if __name__ == '__main__': def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子") parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
@ -141,5 +141,11 @@ if __name__ == '__main__':
help="remove like `with the words xxx` from caption / `with the words xxx`のような部分をキャプションから削除する") help="remove like `with the words xxx` from caption / `with the words xxx`のような部分をキャプションから削除する")
parser.add_argument("--debug", action="store_true", help="debug mode") parser.add_argument("--debug", action="store_true", help="debug mode")
return parser
if __name__ == '__main__':
parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)

View File

@ -46,7 +46,7 @@ def main(args):
print("done!") print("done!")
if __name__ == '__main__': def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
@ -61,6 +61,12 @@ if __name__ == '__main__':
help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す") help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す")
parser.add_argument("--debug", action="store_true", help="debug mode") parser.add_argument("--debug", action="store_true", help="debug mode")
return parser
if __name__ == '__main__':
parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()
# スペルミスしていたオプションを復元する # スペルミスしていたオプションを復元する

View File

@ -47,7 +47,7 @@ def main(args):
print("done!") print("done!")
if __name__ == '__main__': def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
@ -61,5 +61,11 @@ if __name__ == '__main__':
help="extension of caption (tag) file / 読み込むキャプション(タグ)ファイルの拡張子") help="extension of caption (tag) file / 読み込むキャプション(タグ)ファイルの拡張子")
parser.add_argument("--debug", action="store_true", help="debug mode, print tags") parser.add_argument("--debug", action="store_true", help="debug mode, print tags")
return parser
if __name__ == '__main__':
parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)

View File

@ -229,7 +229,7 @@ def main(args):
print("done!") print("done!")
if __name__ == '__main__': def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル") parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
@ -257,5 +257,11 @@ if __name__ == '__main__':
parser.add_argument("--skip_existing", action="store_true", parser.add_argument("--skip_existing", action="store_true",
help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップするflip_aug有効時は通常、反転の両方が存在する画像をスキップ") help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップするflip_aug有効時は通常、反転の両方が存在する画像をスキップ")
return parser
if __name__ == '__main__':
parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)

View File

@ -173,7 +173,7 @@ def main(args):
print("done!") print("done!")
if __name__ == '__main__': def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("--repo_id", type=str, default=DEFAULT_WD14_TAGGER_REPO, parser.add_argument("--repo_id", type=str, default=DEFAULT_WD14_TAGGER_REPO,
@ -191,6 +191,12 @@ if __name__ == '__main__':
parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子") parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子")
parser.add_argument("--debug", action="store_true", help="debug mode") parser.add_argument("--debug", action="store_true", help="debug mode")
return parser
if __name__ == '__main__':
parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()
# スペルミスしていたオプションを復元する # スペルミスしていたオプションを復元する

View File

@ -104,7 +104,7 @@ def save_configuration(
sample_every_n_epochs, sample_every_n_epochs,
sample_sampler, sample_sampler,
sample_prompts, sample_prompts,
additional_parameters, additional_parameters,vae_batch_size,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -217,7 +217,7 @@ def open_configuration(
sample_every_n_epochs, sample_every_n_epochs,
sample_sampler, sample_sampler,
sample_prompts, sample_prompts,
additional_parameters, additional_parameters,vae_batch_size,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -312,7 +312,7 @@ def train_model(
sample_every_n_epochs, sample_every_n_epochs,
sample_sampler, sample_sampler,
sample_prompts, sample_prompts,
additional_parameters, additional_parameters,vae_batch_size,
): ):
if check_if_model_exist(output_name, output_dir, save_model_as): if check_if_model_exist(output_name, output_dir, save_model_as):
return return
@ -368,8 +368,10 @@ def train_model(
image_num = len( image_num = len(
[ [
f f
for f in os.listdir(image_folder) for f, lower_f in (
if f.endswith('.jpg') or f.endswith('.png') or f.endswith('.webp') (file, file.lower()) for file in os.listdir(image_folder)
)
if lower_f.endswith(('.jpg', '.jpeg', '.png', '.webp'))
] ]
) )
print(f'image_num = {image_num}') print(f'image_num = {image_num}')
@ -470,6 +472,7 @@ def train_model(
caption_dropout_rate=caption_dropout_rate, caption_dropout_rate=caption_dropout_rate,
noise_offset=noise_offset, noise_offset=noise_offset,
additional_parameters=additional_parameters, additional_parameters=additional_parameters,
vae_batch_size=vae_batch_size,
) )
run_cmd += run_cmd_sample( run_cmd += run_cmd_sample(
@ -686,6 +689,7 @@ def finetune_tab():
caption_dropout_rate, caption_dropout_rate,
noise_offset, noise_offset,
additional_parameters, additional_parameters,
vae_batch_size,
) = gradio_advanced_training() ) = gradio_advanced_training()
color_aug.change( color_aug.change(
color_aug_changed, color_aug_changed,
@ -780,6 +784,7 @@ def finetune_tab():
sample_sampler, sample_sampler,
sample_prompts, sample_prompts,
additional_parameters, additional_parameters,
vae_batch_size,
] ]
button_run.click(train_model, inputs=settings_list) button_run.click(train_model, inputs=settings_list)

View File

@ -2690,7 +2690,7 @@ def main(args):
print("done!") print("done!")
if __name__ == '__main__': def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--v2", action='store_true', help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む') parser.add_argument("--v2", action='store_true', help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む')
@ -2786,5 +2786,11 @@ if __name__ == '__main__':
parser.add_argument("--control_net_ratios", type=float, default=None, nargs='*', parser.add_argument("--control_net_ratios", type=float, default=None, nargs='*',
help='ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率') help='ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率')
return parser
if __name__ == '__main__':
parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)

13
gui.ps1
View File

@ -4,7 +4,14 @@
# Validate the requirements and store the exit code # Validate the requirements and store the exit code
python.exe .\tools\validate_requirements.py python.exe .\tools\validate_requirements.py
# If the exit code is 0, run the kohya_gui.py script with the command-line arguments # If the exit code is 0, read arguments from gui_parameters.txt (if it exists)
# and run the kohya_gui.py script with the command-line arguments
if ($LASTEXITCODE -eq 0) { if ($LASTEXITCODE -eq 0) {
python.exe kohya_gui.py $args $argsFromFile = @()
} if (Test-Path .\gui_parameters.txt) {
$argsFromFile = Get-Content .\gui_parameters.txt -Encoding UTF8 | Where-Object { $_ -notmatch "^#" } | Foreach-Object { $_ -split " " }
}
$args_combo = $argsFromFile + $args
Write-Host "The arguments passed to this script were: $args_combo"
python.exe kohya_gui.py $args_combo
}

13
gui_macos.sh Executable file
View File

@ -0,0 +1,13 @@
#!/bin/bash
# Activate the virtual environment
source venv/bin/activate
# Validate the requirements and store the exit code
python tools/validate_requirements.py --requirements requirements_macos.txt
exit_code=$?
# If the exit code is 0, run the kohya_gui.py script with the command-line arguments
if [ $exit_code -eq 0 ]; then
python kohya_gui.py "$@"
fi

View File

@ -53,10 +53,10 @@ def caption_images(
) )
if find_text: if find_text:
find_replace( find_replace(
folder=images_dir, folder_path=images_dir,
caption_file_ext=caption_ext, caption_file_ext=caption_ext,
find=find_text, search_text=find_text,
replace=replace_text, replace_text=replace_text,
) )
else: else:
if prefix or postfix: if prefix or postfix:

View File

@ -1,4 +1,5 @@
from tkinter import filedialog, Tk from tkinter import filedialog, Tk
from easygui import msgbox
import os import os
import gradio as gr import gradio as gr
import easygui import easygui
@ -30,6 +31,8 @@ V1_MODELS = [
# define a list of substrings to search for # define a list of substrings to search for
ALL_PRESET_MODELS = V2_BASE_MODELS + V_PARAMETERIZATION_MODELS + V1_MODELS ALL_PRESET_MODELS = V2_BASE_MODELS + V_PARAMETERIZATION_MODELS + V1_MODELS
FILE_ENV_EXCLUSION = ['COLAB_GPU', 'RUNPOD_ENVIRONMENT']
def check_if_model_exist(output_name, output_dir, save_model_as): def check_if_model_exist(output_name, output_dir, save_model_as):
if save_model_as in ['diffusers', 'diffusers_safetendors']: if save_model_as in ['diffusers', 'diffusers_safetendors']:
@ -60,28 +63,20 @@ def check_if_model_exist(output_name, output_dir, save_model_as):
def update_my_data(my_data): def update_my_data(my_data):
# Update optimizer based on use_8bit_adam flag # Update the optimizer based on the use_8bit_adam flag
use_8bit_adam = my_data.get('use_8bit_adam', False) use_8bit_adam = my_data.get('use_8bit_adam', False)
if use_8bit_adam: my_data.setdefault('optimizer', 'AdamW8bit' if use_8bit_adam else 'AdamW')
my_data['optimizer'] = 'AdamW8bit'
elif 'optimizer' not in my_data:
my_data['optimizer'] = 'AdamW'
# Update model_list to custom if empty or pretrained_model_name_or_path is not a preset model # Update model_list to custom if empty or pretrained_model_name_or_path is not a preset model
model_list = my_data.get('model_list', []) model_list = my_data.get('model_list', [])
pretrained_model_name_or_path = my_data.get( pretrained_model_name_or_path = my_data.get('pretrained_model_name_or_path', '')
'pretrained_model_name_or_path', '' if not model_list or pretrained_model_name_or_path not in ALL_PRESET_MODELS:
)
if (
not model_list
or pretrained_model_name_or_path not in ALL_PRESET_MODELS
):
my_data['model_list'] = 'custom' my_data['model_list'] = 'custom'
# Convert epoch and save_every_n_epochs values to int if they are strings # Convert epoch and save_every_n_epochs values to int if they are strings
for key in ['epoch', 'save_every_n_epochs']: for key in ['epoch', 'save_every_n_epochs']:
value = my_data.get(key, -1) value = my_data.get(key, -1)
if isinstance(value, str) and value: if isinstance(value, str) and value.isdigit():
my_data[key] = int(value) my_data[key] = int(value)
elif not value: elif not value:
my_data[key] = -1 my_data[key] = -1
@ -90,43 +85,23 @@ def update_my_data(my_data):
if my_data.get('LoRA_type', 'Standard') == 'LoCon': if my_data.get('LoRA_type', 'Standard') == 'LoCon':
my_data['LoRA_type'] = 'LyCORIS/LoCon' my_data['LoRA_type'] = 'LyCORIS/LoCon'
# Update model save choices due to changes for LoRA and TI training
if (
(my_data.get('LoRA_type') or my_data.get('num_vectors_per_token'))
and my_data.get('save_model_as') not in ['safetensors', 'ckpt']
):
message = (
'Updating save_model_as to safetensors because the current value in the config file is no longer applicable to {}'
)
if my_data.get('LoRA_type'):
print(message.format('LoRA'))
if my_data.get('num_vectors_per_token'):
print(message.format('TI'))
my_data['save_model_as'] = 'safetensors'
return my_data return my_data
# def update_my_data(my_data):
# if my_data.get('use_8bit_adam', False) == True:
# my_data['optimizer'] = 'AdamW8bit'
# # my_data['use_8bit_adam'] = False
# if (
# my_data.get('optimizer', 'missing') == 'missing'
# and my_data.get('use_8bit_adam', False) == False
# ):
# my_data['optimizer'] = 'AdamW'
# if my_data.get('model_list', 'custom') == []:
# print('Old config with empty model list. Setting to custom...')
# my_data['model_list'] = 'custom'
# # If Pretrained model name or path is not one of the preset models then set the preset_model to custom
# if not my_data.get('pretrained_model_name_or_path', '') in ALL_PRESET_MODELS:
# my_data['model_list'] = 'custom'
# # Fix old config files that contain epoch as str instead of int
# for key in ['epoch', 'save_every_n_epochs']:
# value = my_data.get(key, -1)
# if type(value) == str:
# if value != '':
# my_data[key] = int(value)
# else:
# my_data[key] = -1
# if my_data.get('LoRA_type', 'Standard') == 'LoCon':
# my_data['LoRA_type'] = 'LyCORIS/LoCon'
# return my_data
def get_dir_and_file(file_path): def get_dir_and_file(file_path):
dir_path, file_name = os.path.split(file_path) dir_path, file_name = os.path.split(file_path)
return (dir_path, file_name) return (dir_path, file_name)
@ -145,54 +120,58 @@ def get_dir_and_file(file_path):
def get_file_path( def get_file_path(
file_path='', default_extension='.json', extension_name='Config files' file_path='', default_extension='.json', extension_name='Config files'
): ):
current_file_path = file_path if not any(var in os.environ for var in FILE_ENV_EXCLUSION):
# print(f'current file path: {current_file_path}') current_file_path = file_path
# print(f'current file path: {current_file_path}')
initial_dir, initial_file = get_dir_and_file(file_path) initial_dir, initial_file = get_dir_and_file(file_path)
# Create a hidden Tkinter root window # Create a hidden Tkinter root window
root = Tk() root = Tk()
root.wm_attributes('-topmost', 1) root.wm_attributes('-topmost', 1)
root.withdraw() root.withdraw()
# Show the open file dialog and get the selected file path # Show the open file dialog and get the selected file path
file_path = filedialog.askopenfilename( file_path = filedialog.askopenfilename(
filetypes=( filetypes=(
(extension_name, f'*{default_extension}'), (extension_name, f'*{default_extension}'),
('All files', '*.*'), ('All files', '*.*'),
), ),
defaultextension=default_extension, defaultextension=default_extension,
initialfile=initial_file, initialfile=initial_file,
initialdir=initial_dir, initialdir=initial_dir,
) )
# Destroy the hidden root window # Destroy the hidden root window
root.destroy() root.destroy()
# If no file is selected, use the current file path # If no file is selected, use the current file path
if not file_path: if not file_path:
file_path = current_file_path file_path = current_file_path
current_file_path = file_path
# print(f'current file path: {current_file_path}')
return file_path return file_path
def get_any_file_path(file_path=''): def get_any_file_path(file_path=''):
current_file_path = file_path if not any(var in os.environ for var in FILE_ENV_EXCLUSION):
# print(f'current file path: {current_file_path}') current_file_path = file_path
# print(f'current file path: {current_file_path}')
initial_dir, initial_file = get_dir_and_file(file_path) initial_dir, initial_file = get_dir_and_file(file_path)
root = Tk() root = Tk()
root.wm_attributes('-topmost', 1) root.wm_attributes('-topmost', 1)
root.withdraw() root.withdraw()
file_path = filedialog.askopenfilename( file_path = filedialog.askopenfilename(
initialdir=initial_dir, initialdir=initial_dir,
initialfile=initial_file, initialfile=initial_file,
) )
root.destroy() root.destroy()
if file_path == '': if file_path == '':
file_path = current_file_path file_path = current_file_path
return file_path return file_path
@ -218,18 +197,19 @@ def remove_doublequote(file_path):
def get_folder_path(folder_path=''): def get_folder_path(folder_path=''):
current_folder_path = folder_path if not any(var in os.environ for var in FILE_ENV_EXCLUSION):
current_folder_path = folder_path
initial_dir, initial_file = get_dir_and_file(folder_path) initial_dir, initial_file = get_dir_and_file(folder_path)
root = Tk() root = Tk()
root.wm_attributes('-topmost', 1) root.wm_attributes('-topmost', 1)
root.withdraw() root.withdraw()
folder_path = filedialog.askdirectory(initialdir=initial_dir) folder_path = filedialog.askdirectory(initialdir=initial_dir)
root.destroy() root.destroy()
if folder_path == '': if folder_path == '':
folder_path = current_folder_path folder_path = current_folder_path
return folder_path return folder_path
@ -237,34 +217,35 @@ def get_folder_path(folder_path=''):
def get_saveasfile_path( def get_saveasfile_path(
file_path='', defaultextension='.json', extension_name='Config files' file_path='', defaultextension='.json', extension_name='Config files'
): ):
current_file_path = file_path if not any(var in os.environ for var in FILE_ENV_EXCLUSION):
# print(f'current file path: {current_file_path}') current_file_path = file_path
# print(f'current file path: {current_file_path}')
initial_dir, initial_file = get_dir_and_file(file_path) initial_dir, initial_file = get_dir_and_file(file_path)
root = Tk() root = Tk()
root.wm_attributes('-topmost', 1) root.wm_attributes('-topmost', 1)
root.withdraw() root.withdraw()
save_file_path = filedialog.asksaveasfile( save_file_path = filedialog.asksaveasfile(
filetypes=( filetypes=(
(f'{extension_name}', f'{defaultextension}'), (f'{extension_name}', f'{defaultextension}'),
('All files', '*'), ('All files', '*'),
), ),
defaultextension=defaultextension, defaultextension=defaultextension,
initialdir=initial_dir, initialdir=initial_dir,
initialfile=initial_file, initialfile=initial_file,
) )
root.destroy() root.destroy()
# print(save_file_path) # print(save_file_path)
if save_file_path == None: if save_file_path == None:
file_path = current_file_path file_path = current_file_path
else: else:
print(save_file_path.name) print(save_file_path.name)
file_path = save_file_path.name file_path = save_file_path.name
# print(file_path) # print(file_path)
return file_path return file_path
@ -272,27 +253,28 @@ def get_saveasfile_path(
def get_saveasfilename_path( def get_saveasfilename_path(
file_path='', extensions='*', extension_name='Config files' file_path='', extensions='*', extension_name='Config files'
): ):
current_file_path = file_path if not any(var in os.environ for var in FILE_ENV_EXCLUSION):
# print(f'current file path: {current_file_path}') current_file_path = file_path
# print(f'current file path: {current_file_path}')
initial_dir, initial_file = get_dir_and_file(file_path) initial_dir, initial_file = get_dir_and_file(file_path)
root = Tk() root = Tk()
root.wm_attributes('-topmost', 1) root.wm_attributes('-topmost', 1)
root.withdraw() root.withdraw()
save_file_path = filedialog.asksaveasfilename( save_file_path = filedialog.asksaveasfilename(
filetypes=((f'{extension_name}', f'{extensions}'), ('All files', '*')), filetypes=((f'{extension_name}', f'{extensions}'), ('All files', '*')),
defaultextension=extensions, defaultextension=extensions,
initialdir=initial_dir, initialdir=initial_dir,
initialfile=initial_file, initialfile=initial_file,
) )
root.destroy() root.destroy()
if save_file_path == '': if save_file_path == '':
file_path = current_file_path file_path = current_file_path
else: else:
# print(save_file_path) # print(save_file_path)
file_path = save_file_path file_path = save_file_path
return file_path return file_path
@ -343,33 +325,6 @@ def add_pre_postfix(
) )
# def add_pre_postfix(
# folder='', prefix='', postfix='', caption_file_ext='.caption'
# ):
# if not has_ext_files(folder, caption_file_ext):
# msgbox(
# f'No files with extension {caption_file_ext} were found in {folder}...'
# )
# return
# if prefix == '' and postfix == '':
# return
# files = [f for f in os.listdir(folder) if f.endswith(caption_file_ext)]
# if not prefix == '':
# prefix = f'{prefix} '
# if not postfix == '':
# postfix = f' {postfix}'
# for file in files:
# with open(os.path.join(folder, file), 'r+') as f:
# content = f.read()
# content = content.rstrip()
# f.seek(0, 0)
# f.write(f'{prefix} {content} {postfix}')
# f.close()
def has_ext_files(folder_path: str, file_extension: str) -> bool: def has_ext_files(folder_path: str, file_extension: str) -> bool:
""" """
Check if there are any files with the specified extension in the given folder. Check if there are any files with the specified extension in the given folder.
@ -429,28 +384,6 @@ def find_replace(
f.write(content) f.write(content)
# def find_replace(folder='', caption_file_ext='.caption', find='', replace=''):
# print('Running caption find/replace')
# if not has_ext_files(folder, caption_file_ext):
# msgbox(
# f'No files with extension {caption_file_ext} were found in {folder}...'
# )
# return
# if find == '':
# return
# files = [f for f in os.listdir(folder) if f.endswith(caption_file_ext)]
# for file in files:
# with open(os.path.join(folder, file), 'r', errors='ignore') as f:
# content = f.read()
# f.close
# content = content.replace(find, replace)
# with open(os.path.join(folder, file), 'w') as f:
# f.write(content)
# f.close()
def color_aug_changed(color_aug): def color_aug_changed(color_aug):
if color_aug: if color_aug:
msgbox( msgbox(
@ -604,7 +537,13 @@ def get_pretrained_model_name_or_path_file(
set_model_list(model_list, pretrained_model_name_or_path) set_model_list(model_list, pretrained_model_name_or_path)
def gradio_source_model(): def gradio_source_model(save_model_as_choices = [
'same as source model',
'ckpt',
'diffusers',
'diffusers_safetensors',
'safetensors',
]):
with gr.Tab('Source model'): with gr.Tab('Source model'):
# Define the input elements # Define the input elements
with gr.Row(): with gr.Row():
@ -646,13 +585,7 @@ def gradio_source_model():
) )
save_model_as = gr.Dropdown( save_model_as = gr.Dropdown(
label='Save trained model as', label='Save trained model as',
choices=[ choices=save_model_as_choices,
'same as source model',
'ckpt',
'diffusers',
'diffusers_safetensors',
'safetensors',
],
value='safetensors', value='safetensors',
) )
@ -928,6 +861,13 @@ def gradio_advanced_training():
caption_dropout_rate = gr.Slider( caption_dropout_rate = gr.Slider(
label='Rate of caption dropout', value=0, minimum=0, maximum=1 label='Rate of caption dropout', value=0, minimum=0, maximum=1
) )
vae_batch_size = gr.Slider(
label='VAE batch size',
minimum=0,
maximum=32,
value=0,
every=1
)
with gr.Row(): with gr.Row():
save_state = gr.Checkbox(label='Save training state', value=False) save_state = gr.Checkbox(label='Save training state', value=False)
resume = gr.Textbox( resume = gr.Textbox(
@ -947,6 +887,7 @@ def gradio_advanced_training():
max_data_loader_n_workers = gr.Textbox( max_data_loader_n_workers = gr.Textbox(
label='Max num workers for DataLoader', label='Max num workers for DataLoader',
placeholder='(Optional) Override number of epoch. Default: 8', placeholder='(Optional) Override number of epoch. Default: 8',
value="0",
) )
return ( return (
# use_8bit_adam, # use_8bit_adam,
@ -972,6 +913,7 @@ def gradio_advanced_training():
caption_dropout_rate, caption_dropout_rate,
noise_offset, noise_offset,
additional_parameters, additional_parameters,
vae_batch_size,
) )
@ -998,8 +940,11 @@ def run_cmd_advanced_training(**kwargs):
f' --caption_dropout_every_n_epochs="{int(kwargs.get("caption_dropout_every_n_epochs", 0))}"' f' --caption_dropout_every_n_epochs="{int(kwargs.get("caption_dropout_every_n_epochs", 0))}"'
if int(kwargs.get('caption_dropout_every_n_epochs', 0)) > 0 if int(kwargs.get('caption_dropout_every_n_epochs', 0)) > 0
else '', else '',
f' --caption_dropout_rate="{kwargs.get("caption_dropout_rate", "")}"' f' --caption_dropout_every_n_epochs="{int(kwargs.get("caption_dropout_every_n_epochs", 0))}"'
if float(kwargs.get('caption_dropout_rate', 0)) > 0 if int(kwargs.get('caption_dropout_every_n_epochs', 0)) > 0
else '',
f' --vae_batch_size="{kwargs.get("vae_batch_size", 0)}"'
if int(kwargs.get('vae_batch_size', 0)) > 0
else '', else '',
f' --bucket_reso_steps={int(kwargs.get("bucket_reso_steps", 1))}' f' --bucket_reso_steps={int(kwargs.get("bucket_reso_steps", 1))}'
if int(kwargs.get('bucket_reso_steps', 64)) >= 1 if int(kwargs.get('bucket_reso_steps', 64)) >= 1

View File

@ -23,6 +23,7 @@ def extract_lora(
dim, dim,
v2, v2,
conv_dim, conv_dim,
device,
): ):
# Check for caption_text_input # Check for caption_text_input
if model_tuned == '': if model_tuned == '':
@ -50,6 +51,7 @@ def extract_lora(
run_cmd += f' --model_org "{model_org}"' run_cmd += f' --model_org "{model_org}"'
run_cmd += f' --model_tuned "{model_tuned}"' run_cmd += f' --model_tuned "{model_tuned}"'
run_cmd += f' --dim {dim}' run_cmd += f' --dim {dim}'
run_cmd += f' --device {device}'
if conv_dim > 0: if conv_dim > 0:
run_cmd += f' --conv_dim {conv_dim}' run_cmd += f' --conv_dim {conv_dim}'
if v2: if v2:
@ -134,7 +136,7 @@ def gradio_extract_lora_tab():
dim = gr.Slider( dim = gr.Slider(
minimum=4, minimum=4,
maximum=1024, maximum=1024,
label='Network Dimension', label='Network Dimension (Rank)',
value=128, value=128,
step=1, step=1,
interactive=True, interactive=True,
@ -142,12 +144,21 @@ def gradio_extract_lora_tab():
conv_dim = gr.Slider( conv_dim = gr.Slider(
minimum=0, minimum=0,
maximum=1024, maximum=1024,
label='Conv Dimension', label='Conv Dimension (Rank)',
value=0, value=128,
step=1, step=1,
interactive=True, interactive=True,
) )
v2 = gr.Checkbox(label='v2', value=False, interactive=True) v2 = gr.Checkbox(label='v2', value=False, interactive=True)
device = gr.Dropdown(
label='Device',
choices=[
'cpu',
'cuda',
],
value='cuda',
interactive=True,
)
extract_button = gr.Button('Extract LoRA model') extract_button = gr.Button('Extract LoRA model')
@ -161,6 +172,7 @@ def gradio_extract_lora_tab():
dim, dim,
v2, v2,
conv_dim, conv_dim,
device
], ],
show_progress=False, show_progress=False,
) )

View File

@ -73,8 +73,7 @@ DEFAULT_LAST_OUTPUT_NAME = "last"
# region dataset # region dataset
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp"] IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"]
# , ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"] # Linux?
class ImageInfo: class ImageInfo:
@ -675,10 +674,19 @@ class BaseDataset(torch.utils.data.Dataset):
def is_latent_cacheable(self): def is_latent_cacheable(self):
return all([not subset.color_aug and not subset.random_crop for subset in self.subsets]) return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])
def cache_latents(self, vae): def cache_latents(self, vae, vae_batch_size=1):
# TODO ここを高速化したい # ちょっと速くした
print("caching latents.") print("caching latents.")
for info in tqdm(self.image_data.values()):
image_infos = list(self.image_data.values())
# sort by resolution
image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1])
# split by resolution
batches = []
batch = []
for info in image_infos:
subset = self.image_to_subset[info.image_key] subset = self.image_to_subset[info.image_key]
if info.latents_npz is not None: if info.latents_npz is not None:
@ -689,18 +697,42 @@ class BaseDataset(torch.utils.data.Dataset):
info.latents_flipped = torch.FloatTensor(info.latents_flipped) info.latents_flipped = torch.FloatTensor(info.latents_flipped)
continue continue
image = self.load_image(info.absolute_path) # if last member of batch has different resolution, flush the batch
image = self.trim_and_resize_if_required(subset, image, info.bucket_reso, info.resized_size) if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso:
batches.append(batch)
batch = []
img_tensor = self.image_transforms(image) batch.append(info)
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu") # if number of data in batch is enough, flush the batch
if len(batch) >= vae_batch_size:
batches.append(batch)
batch = []
if len(batch) > 0:
batches.append(batch)
# iterate batches
for batch in tqdm(batches, smoothing=1, total=len(batches)):
images = []
for info in batch:
image = self.load_image(info.absolute_path)
image = self.trim_and_resize_if_required(subset, image, info.bucket_reso, info.resized_size)
image = self.image_transforms(image)
images.append(image)
img_tensors = torch.stack(images, dim=0)
img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype)
latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
for info, latent in zip(batch, latents):
info.latents = latent
if subset.flip_aug: if subset.flip_aug:
image = image[:, ::-1].copy() # cannot convert to Tensor without copy img_tensors = torch.flip(img_tensors, dims=[3])
img_tensor = self.image_transforms(image) latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype) for info, latent in zip(batch, latents):
info.latents_flipped = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu") info.latents_flipped = latent
def get_image_size(self, image_path): def get_image_size(self, image_path):
image = Image.open(image_path) image = Image.open(image_path)
@ -1197,6 +1229,10 @@ class FineTuningDataset(BaseDataset):
npz_file_flip = None npz_file_flip = None
return npz_file_norm, npz_file_flip return npz_file_norm, npz_file_flip
# if not full path, check image_dir. if image_dir is None, return None
if subset.image_dir is None:
return None, None
# image_key is relative path # image_key is relative path
npz_file_norm = os.path.join(subset.image_dir, image_key + ".npz") npz_file_norm = os.path.join(subset.image_dir, image_key + ".npz")
npz_file_flip = os.path.join(subset.image_dir, image_key + "_flip.npz") npz_file_flip = os.path.join(subset.image_dir, image_key + "_flip.npz")
@ -1237,10 +1273,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
# for dataset in self.datasets: # for dataset in self.datasets:
# dataset.make_buckets() # dataset.make_buckets()
def cache_latents(self, vae): def cache_latents(self, vae, vae_batch_size=1):
for i, dataset in enumerate(self.datasets): for i, dataset in enumerate(self.datasets):
print(f"[Dataset {i}]") print(f"[Dataset {i}]")
dataset.cache_latents(vae) dataset.cache_latents(vae, vae_batch_size)
def is_latent_cacheable(self) -> bool: def is_latent_cacheable(self) -> bool:
return all([dataset.is_latent_cacheable() for dataset in self.datasets]) return all([dataset.is_latent_cacheable() for dataset in self.datasets])
@ -1989,6 +2025,7 @@ def add_dataset_arguments(
action="store_true", action="store_true",
help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheするaugmentationは使用不可", help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheするaugmentationは使用不可",
) )
parser.add_argument("--vae_batch_size", type=int, default=1, help="batch size for caching latents / latentのcache時のバッチサイズ")
parser.add_argument( parser.add_argument(
"--enable_bucket", action="store_true", help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする" "--enable_bucket", action="store_true", help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする"
) )

View File

@ -5,6 +5,16 @@ from .common_gui import get_folder_path
import os import os
def replace_underscore_with_space(folder_path, file_extension):
for file_name in os.listdir(folder_path):
if file_name.endswith(file_extension):
file_path = os.path.join(folder_path, file_name)
with open(file_path, 'r') as file:
file_content = file.read()
new_file_content = file_content.replace('_', ' ')
with open(file_path, 'w') as file:
file.write(new_file_content)
def caption_images( def caption_images(
train_data_dir, caption_extension, batch_size, thresh, replace_underscores train_data_dir, caption_extension, batch_size, thresh, replace_underscores
): ):
@ -26,9 +36,7 @@ def caption_images(
run_cmd = f'accelerate launch "./finetune/tag_images_by_wd14_tagger.py"' run_cmd = f'accelerate launch "./finetune/tag_images_by_wd14_tagger.py"'
run_cmd += f' --batch_size="{int(batch_size)}"' run_cmd += f' --batch_size="{int(batch_size)}"'
run_cmd += f' --thresh="{thresh}"' run_cmd += f' --thresh="{thresh}"'
run_cmd += f' --replace_underscores' if replace_underscores else '' run_cmd += f' --caption_extension="{caption_extension}"'
if caption_extension != '':
run_cmd += f' --caption_extension="{caption_extension}"'
run_cmd += f' "{train_data_dir}"' run_cmd += f' "{train_data_dir}"'
print(run_cmd) print(run_cmd)
@ -38,6 +46,9 @@ def caption_images(
os.system(run_cmd) os.system(run_cmd)
else: else:
subprocess.run(run_cmd) subprocess.run(run_cmd)
if replace_underscores:
replace_underscore_with_space(train_data_dir, caption_extension)
print('...captioning done') print('...captioning done')

View File

@ -123,7 +123,7 @@ def save_configuration(
sample_every_n_epochs, sample_every_n_epochs,
sample_sampler, sample_sampler,
sample_prompts, sample_prompts,
additional_parameters, additional_parameters,vae_batch_size,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -240,7 +240,7 @@ def open_configuration(
sample_every_n_epochs, sample_every_n_epochs,
sample_sampler, sample_sampler,
sample_prompts, sample_prompts,
additional_parameters, additional_parameters,vae_batch_size,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -257,7 +257,8 @@ def open_configuration(
with open(file_path, 'r') as f: with open(file_path, 'r') as f:
my_data = json.load(f) my_data = json.load(f)
print('Loading config...') print('Loading config...')
# Update values to fix deprecated use_8bit_adam checkbox and set appropriate optimizer if it is set to True
# Update values to fix deprecated use_8bit_adam checkbox, set appropriate optimizer if it is set to True, etc.
my_data = update_my_data(my_data) my_data = update_my_data(my_data)
else: else:
file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action
@ -347,7 +348,7 @@ def train_model(
sample_every_n_epochs, sample_every_n_epochs,
sample_sampler, sample_sampler,
sample_prompts, sample_prompts,
additional_parameters, additional_parameters,vae_batch_size,
): ):
print_only_bool = True if print_only.get('label') == 'True' else False print_only_bool = True if print_only.get('label') == 'True' else False
@ -418,14 +419,14 @@ def train_model(
num_images = len( num_images = len(
[ [
f f
for f in os.listdir(os.path.join(train_data_dir, folder)) for f, lower_f in (
if f.endswith('.jpg') (file, file.lower()) for file in os.listdir(os.path.join(train_data_dir, folder))
or f.endswith('.jpeg') )
or f.endswith('.png') if lower_f.endswith(('.jpg', '.jpeg', '.png', '.webp'))
or f.endswith('.webp')
] ]
) )
print(f'Folder {folder}: {num_images} images found') print(f'Folder {folder}: {num_images} images found')
# Calculate the total number of steps for this folder # Calculate the total number of steps for this folder
@ -589,6 +590,7 @@ def train_model(
caption_dropout_rate=caption_dropout_rate, caption_dropout_rate=caption_dropout_rate,
noise_offset=noise_offset, noise_offset=noise_offset,
additional_parameters=additional_parameters, additional_parameters=additional_parameters,
vae_batch_size=vae_batch_size,
) )
run_cmd += run_cmd_sample( run_cmd += run_cmd_sample(
@ -647,7 +649,10 @@ def lora_tab(
v_parameterization, v_parameterization,
save_model_as, save_model_as,
model_list, model_list,
) = gradio_source_model() ) = gradio_source_model(save_model_as_choices = [
'ckpt',
'safetensors',
])
with gr.Tab('Folders'): with gr.Tab('Folders'):
with gr.Row(): with gr.Row():
@ -891,6 +896,7 @@ def lora_tab(
caption_dropout_rate, caption_dropout_rate,
noise_offset, noise_offset,
additional_parameters, additional_parameters,
vae_batch_size,
) = gradio_advanced_training() ) = gradio_advanced_training()
color_aug.change( color_aug.change(
color_aug_changed, color_aug_changed,
@ -1008,6 +1014,7 @@ def lora_tab(
sample_sampler, sample_sampler,
sample_prompts, sample_prompts,
additional_parameters, additional_parameters,
vae_batch_size,
] ]
button_open_config.click( button_open_config.click(
@ -1096,6 +1103,8 @@ def UI(**kwargs):
launch_kwargs['server_port'] = kwargs.get('server_port', 0) launch_kwargs['server_port'] = kwargs.get('server_port', 0)
if kwargs.get('inbrowser', False): if kwargs.get('inbrowser', False):
launch_kwargs['inbrowser'] = kwargs.get('inbrowser', False) launch_kwargs['inbrowser'] = kwargs.get('inbrowser', False)
if kwargs.get('listen', True):
launch_kwargs['server_name'] = "0.0.0.0"
print(launch_kwargs) print(launch_kwargs)
interface.launch(**launch_kwargs) interface.launch(**launch_kwargs)
@ -1118,6 +1127,9 @@ if __name__ == '__main__':
parser.add_argument( parser.add_argument(
'--inbrowser', action='store_true', help='Open in browser' '--inbrowser', action='store_true', help='Open in browser'
) )
parser.add_argument(
'--listen', action='store_true', help='Launch gradio with server name 0.0.0.0, allowing LAN access'
)
args = parser.parse_args() args = parser.parse_args()

38
macos_setup.sh Executable file
View File

@ -0,0 +1,38 @@
#!/bin/bash
# The initial setup script to prep the environment on macOS
# xformers has been omitted as that is for Nvidia GPUs only
if ! command -v brew >/dev/null; then
echo "Please install homebrew first. This is a requirement for the remaining setup."
echo "You can find that here: https://brew.sh"
exit 1
fi
# Install base python packages
echo "Installing Python 3.10 if not found."
brew ls --versions python@3.10 >/dev/null || brew install python@3.10
echo "Installing Python-TK 3.10 if not found."
brew ls --versions python-tk@3.10 >/dev/null || brew install python-tk@3.10
if command -v python3.10 >/dev/null; then
python3.10 -m venv venv
source venv/bin/activate
# DEBUG ONLY
#pip install pydevd-pycharm~=223.8836.43
# Tensorflow installation
if wget https://github.com/apple/tensorflow_macos/releases/download/v0.1alpha3/tensorflow_macos-0.1a3-cp38-cp38-macosx_11_0_arm64.whl /tmp; then
python -m pip install tensorflow==0.1a3 -f https://github.com/apple/tensorflow_macos/releases/download/v0.1alpha3/tensorflow_macos-0.1a3-cp38-cp38-macosx_11_0_arm64.whl
rm -f /tmp/tensorflow_macos-0.1a3-cp38-cp38-macosx_11_0_arm64.whl
fi
pip install torch==2.0.0 torchvision==0.15.1 -f https://download.pytorch.org/whl/cpu/torch_stable.html
python -m pip install --use-pep517 --upgrade -r requirements_macos.txt
accelerate config
echo -e "Setup finished! Run ./gui_macos.sh to start."
else
echo "Python not found. Please ensure you install Python."
echo "The brew command for Python 3.10 is: brew install python@3.10"
exit 1
fi

View File

@ -24,9 +24,16 @@ def main(file):
print(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}") print(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")
if __name__ == '__main__': def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("file", type=str, help="model file to check / 重みを確認するモデルファイル") parser.add_argument("file", type=str, help="model file to check / 重みを確認するモデルファイル")
return parser
if __name__ == '__main__':
parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()
main(args.file) main(args.file)

View File

@ -11,8 +11,8 @@ import library.model_util as model_util
import lora import lora
CLAMP_QUANTILE = 0.99 CLAMP_QUANTILE = 1
MIN_DIFF = 1e-6 MIN_DIFF = 1e-8
def save_to_file(file_name, model, state_dict, dtype): def save_to_file(file_name, model, state_dict, dtype):
@ -113,7 +113,7 @@ def svd(args):
else: else:
mat = mat.squeeze() mat = mat.squeeze()
U, S, Vh = torch.linalg.svd(mat.to("cuda")) U, S, Vh = torch.linalg.svd(mat)
U = U[:, :rank] U = U[:, :rank]
S = S[:rank] S = S[:rank]
@ -121,7 +121,7 @@ def svd(args):
Vh = Vh[:rank, :] Vh = Vh[:rank, :]
dist = torch.cat([U.flatten(), Vh.flatten()]) # dist = torch.cat([U.flatten(), Vh.flatten()])
# hi_val = torch.quantile(dist, CLAMP_QUANTILE) # hi_val = torch.quantile(dist, CLAMP_QUANTILE)
# low_val = -hi_val # low_val = -hi_val
@ -132,8 +132,8 @@ def svd(args):
U = U.reshape(out_dim, rank, 1, 1) U = U.reshape(out_dim, rank, 1, 1)
Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1]) Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
U = U.to("cuda").contiguous() U = U.to("cpu").contiguous()
Vh = Vh.to("cuda").contiguous() Vh = Vh.to("cpu").contiguous()
lora_weights[lora_name] = (U, Vh) lora_weights[lora_name] = (U, Vh)
@ -162,7 +162,7 @@ def svd(args):
print(f"LoRA weights are saved to: {args.save_to}") print(f"LoRA weights are saved to: {args.save_to}")
if __name__ == '__main__': def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--v2", action='store_true', parser.add_argument("--v2", action='store_true',
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む') help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
@ -179,5 +179,11 @@ if __name__ == '__main__':
help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数rankデフォルトNone、適用なし") help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数rankデフォルトNone、適用なし")
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
return parser
if __name__ == '__main__':
parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()
svd(args) svd(args)

View File

@ -105,7 +105,7 @@ def interrogate(args):
print(f"[{i:3d}]: {token:5d} {string:<20s}: {diff:.5f}") print(f"[{i:3d}]: {token:5d} {string:<20s}: {diff:.5f}")
if __name__ == '__main__': def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--v2", action='store_true', parser.add_argument("--v2", action='store_true',
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む') help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
@ -118,5 +118,11 @@ if __name__ == '__main__':
parser.add_argument("--clip_skip", type=int, default=None, parser.add_argument("--clip_skip", type=int, default=None,
help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いるnは1以上") help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いるnは1以上")
return parser
if __name__ == '__main__':
parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()
interrogate(args) interrogate(args)

View File

@ -197,7 +197,7 @@ def merge(args):
save_to_file(args.save_to, state_dict, state_dict, save_dtype) save_to_file(args.save_to, state_dict, state_dict, save_dtype)
if __name__ == '__main__': def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--v2", action='store_true', parser.add_argument("--v2", action='store_true',
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む') help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
@ -214,5 +214,11 @@ if __name__ == '__main__':
parser.add_argument("--ratios", type=float, nargs='*', parser.add_argument("--ratios", type=float, nargs='*',
help="ratios for each model / それぞれのLoRAモデルの比率") help="ratios for each model / それぞれのLoRAモデルの比率")
return parser
if __name__ == '__main__':
parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()
merge(args) merge(args)

View File

@ -158,7 +158,7 @@ def merge(args):
save_to_file(args.save_to, state_dict, state_dict, save_dtype) save_to_file(args.save_to, state_dict, state_dict, save_dtype)
if __name__ == '__main__': def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--v2", action='store_true', parser.add_argument("--v2", action='store_true',
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む') help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
@ -175,5 +175,11 @@ if __name__ == '__main__':
parser.add_argument("--ratios", type=float, nargs='*', parser.add_argument("--ratios", type=float, nargs='*',
help="ratios for each model / それぞれのLoRAモデルの比率") help="ratios for each model / それぞれのLoRAモデルの比率")
return parser
if __name__ == '__main__':
parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()
merge(args) merge(args)

View File

@ -208,18 +208,28 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
with torch.no_grad(): with torch.no_grad():
for key, value in tqdm(lora_sd.items()): for key, value in tqdm(lora_sd.items()):
weight_name = None
if 'lora_down' in key: if 'lora_down' in key:
block_down_name = key.split(".")[0] block_down_name = key.split(".")[0]
weight_name = key.split(".")[-1]
lora_down_weight = value lora_down_weight = value
if 'lora_up' in key: else:
block_up_name = key.split(".")[0] continue
lora_up_weight = value
# find corresponding lora_up and alpha
block_up_name = block_down_name
lora_up_weight = lora_sd.get(block_up_name + '.lora_up.' + weight_name, None)
lora_alpha = lora_sd.get(block_down_name + '.alpha', None)
weights_loaded = (lora_down_weight is not None and lora_up_weight is not None) weights_loaded = (lora_down_weight is not None and lora_up_weight is not None)
if (block_down_name == block_up_name) and weights_loaded: if weights_loaded:
conv2d = (len(lora_down_weight.size()) == 4) conv2d = (len(lora_down_weight.size()) == 4)
if lora_alpha is None:
scale = 1.0
else:
scale = lora_alpha/lora_down_weight.size()[0]
if conv2d: if conv2d:
full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device) full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
@ -311,7 +321,7 @@ def resize(args):
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata) save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
if __name__ == '__main__': def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--save_precision", type=str, default=None, parser.add_argument("--save_precision", type=str, default=None,
@ -329,7 +339,12 @@ if __name__ == '__main__':
help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank") help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank")
parser.add_argument("--dynamic_param", type=float, default=None, parser.add_argument("--dynamic_param", type=float, default=None,
help="Specify target for dynamic reduction") help="Specify target for dynamic reduction")
return parser
if __name__ == '__main__':
parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()
resize(args) resize(args)

View File

@ -76,7 +76,11 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
down_weight = down_weight.to(device) down_weight = down_weight.to(device)
# W <- W + U * D # W <- W + U * D
scale = (alpha / network_dim).to(device) scale = (alpha / network_dim)
if device: # and isinstance(scale, torch.Tensor):
scale = scale.to(device)
if not conv2d: # linear if not conv2d: # linear
weight = weight + ratio * (up_weight @ down_weight) * scale weight = weight + ratio * (up_weight @ down_weight) * scale
elif kernel_size == (1, 1): elif kernel_size == (1, 1):
@ -115,12 +119,12 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
Vh = Vh[:module_new_rank, :] Vh = Vh[:module_new_rank, :]
# dist = torch.cat([U.flatten(), Vh.flatten()]) dist = torch.cat([U.flatten(), Vh.flatten()])
# hi_val = torch.quantile(dist, CLAMP_QUANTILE) hi_val = torch.quantile(dist, CLAMP_QUANTILE)
# low_val = -hi_val low_val = -hi_val
# U = U.clamp(low_val, hi_val) U = U.clamp(low_val, hi_val)
# Vh = Vh.clamp(low_val, hi_val) Vh = Vh.clamp(low_val, hi_val)
if conv2d: if conv2d:
U = U.reshape(out_dim, module_new_rank, 1, 1) U = U.reshape(out_dim, module_new_rank, 1, 1)
@ -160,7 +164,7 @@ def merge(args):
save_to_file(args.save_to, state_dict, save_dtype) save_to_file(args.save_to, state_dict, save_dtype)
if __name__ == '__main__': def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--save_precision", type=str, default=None, parser.add_argument("--save_precision", type=str, default=None,
choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ") choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ")
@ -178,5 +182,11 @@ if __name__ == '__main__':
help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ") help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ")
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
return parser
if __name__ == '__main__':
parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()
merge(args) merge(args)

32
requirements_macos.txt Normal file
View File

@ -0,0 +1,32 @@
accelerate==0.15.0
albumentations==1.3.0
altair==4.2.2
bitsandbytes==0.35.0
dadaptation==1.5
diffusers[torch]==0.10.2
easygui==0.98.3
einops==0.6.0
ftfy==6.1.1
gradio==3.19.1; sys_platform != 'darwin'
gradio==3.23.0; sys_platform == 'darwin'
lion-pytorch==0.0.6
opencv-python==4.7.0.68
pytorch-lightning==1.9.0
safetensors==0.2.6
tensorboard==2.10.1
tk==0.1.0
toml==0.10.2
transformers==4.26.0
voluptuous==0.13.1
# for BLIP captioning
fairscale==0.4.13
requests==2.28.2
timm==0.6.12
# tensorflow<2.11
huggingface-hub==0.12.0; sys_platform != 'darwin'
huggingface-hub==0.13.0; sys_platform == 'darwin'
tensorflow==2.10.1; sys_platform != 'darwin'
# For locon support
lycoris_lora==0.1.2
# for kohya_ss library
.

View File

@ -1,3 +1,10 @@
from setuptools import setup, find_packages from setuptools import setup, find_packages
import subprocess
setup(name = "library", version="1.0.2", packages = find_packages()) import os
import sys
# Call the create_user_files.py script
script_path = os.path.join("tools", "create_user_files.py")
subprocess.run([sys.executable, script_path])
setup(name="library", version="1.0.3", packages=find_packages())

View File

@ -112,7 +112,7 @@ def save_configuration(
sample_every_n_epochs, sample_every_n_epochs,
sample_sampler, sample_sampler,
sample_prompts, sample_prompts,
additional_parameters, additional_parameters,vae_batch_size,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -225,7 +225,7 @@ def open_configuration(
sample_every_n_epochs, sample_every_n_epochs,
sample_sampler, sample_sampler,
sample_prompts, sample_prompts,
additional_parameters, additional_parameters,vae_batch_size,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -320,7 +320,7 @@ def train_model(
sample_every_n_epochs, sample_every_n_epochs,
sample_sampler, sample_sampler,
sample_prompts, sample_prompts,
additional_parameters, additional_parameters,vae_batch_size,
): ):
if pretrained_model_name_or_path == '': if pretrained_model_name_or_path == '':
msgbox('Source model information is missing') msgbox('Source model information is missing')
@ -375,11 +375,10 @@ def train_model(
num_images = len( num_images = len(
[ [
f f
for f in os.listdir(os.path.join(train_data_dir, folder)) for f, lower_f in (
if f.endswith('.jpg') (file, file.lower()) for file in os.listdir(os.path.join(train_data_dir, folder))
or f.endswith('.jpeg') )
or f.endswith('.png') if lower_f.endswith(('.jpg', '.jpeg', '.png', '.webp'))
or f.endswith('.webp')
] ]
) )
@ -511,6 +510,7 @@ def train_model(
caption_dropout_rate=caption_dropout_rate, caption_dropout_rate=caption_dropout_rate,
noise_offset=noise_offset, noise_offset=noise_offset,
additional_parameters=additional_parameters, additional_parameters=additional_parameters,
vae_batch_size=vae_batch_size,
) )
run_cmd += f' --token_string="{token_string}"' run_cmd += f' --token_string="{token_string}"'
run_cmd += f' --init_word="{init_word}"' run_cmd += f' --init_word="{init_word}"'
@ -569,7 +569,10 @@ def ti_tab(
v_parameterization, v_parameterization,
save_model_as, save_model_as,
model_list, model_list,
) = gradio_source_model() ) = gradio_source_model(save_model_as_choices = [
'ckpt',
'safetensors',
])
with gr.Tab('Folders'): with gr.Tab('Folders'):
with gr.Row(): with gr.Row():
@ -770,6 +773,7 @@ def ti_tab(
caption_dropout_rate, caption_dropout_rate,
noise_offset, noise_offset,
additional_parameters, additional_parameters,
vae_batch_size,
) = gradio_advanced_training() ) = gradio_advanced_training()
color_aug.change( color_aug.change(
color_aug_changed, color_aug_changed,
@ -876,6 +880,7 @@ def ti_tab(
sample_sampler, sample_sampler,
sample_prompts, sample_prompts,
additional_parameters, additional_parameters,
vae_batch_size,
] ]
button_open_config.click( button_open_config.click(

View File

@ -13,12 +13,18 @@ def canny(args):
print("done!") print("done!")
if __name__ == '__main__': def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--input", type=str, default=None, help="input path") parser.add_argument("--input", type=str, default=None, help="input path")
parser.add_argument("--output", type=str, default=None, help="output path") parser.add_argument("--output", type=str, default=None, help="output path")
parser.add_argument("--thres1", type=int, default=32, help="thres1") parser.add_argument("--thres1", type=int, default=32, help="thres1")
parser.add_argument("--thres2", type=int, default=224, help="thres2") parser.add_argument("--thres2", type=int, default=224, help="thres2")
return parser
if __name__ == '__main__':
parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()
canny(args) canny(args)

View File

@ -61,7 +61,7 @@ def convert(args):
print(f"model saved.") print(f"model saved.")
if __name__ == '__main__': def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--v1", action='store_true', parser.add_argument("--v1", action='store_true',
help='load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み込む') help='load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み込む')
@ -84,6 +84,11 @@ if __name__ == '__main__':
help="model to load: checkpoint file or Diffusers model's directory / 読み込むモデル、checkpointかDiffusers形式モデルのディレクトリ") help="model to load: checkpoint file or Diffusers model's directory / 読み込むモデル、checkpointかDiffusers形式モデルのディレクトリ")
parser.add_argument("model_to_save", type=str, default=None, parser.add_argument("model_to_save", type=str, default=None,
help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension) / 変換後のモデル、拡張子がある場合はcheckpoint、ない場合はDiffusesモデルとして保存") help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension) / 変換後のモデル、拡張子がある場合はcheckpoint、ない場合はDiffusesモデルとして保存")
return parser
if __name__ == '__main__':
parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()
convert(args) convert(args)

View File

@ -0,0 +1,37 @@
import os
bat_content = r'''@echo off
REM Example of how to start the GUI with custom arguments. In this case how to auto launch the browser:
REM call gui.bat --inbrowser
REM
REM You can add many arguments on the same line
REM
call gui.bat --inbrowser
'''
ps1_content = r'''# Example of how to start the GUI with custom arguments. In this case how to auto launch the browser:
# .\gui.ps1 --inbrowser
#
# You can add many arguments on the same line
#
# & .\gui.ps1 --inbrowser --server_port 2345
& .\gui.ps1 --inbrowser
'''
bat_filename = 'gui-user.bat'
ps1_filename = 'gui-user.ps1'
if not os.path.exists(bat_filename):
with open(bat_filename, 'w') as bat_file:
bat_file.write(bat_content)
print(f"File created: {bat_filename}")
else:
print(f"File already exists: {bat_filename}")
if not os.path.exists(ps1_filename):
with open(ps1_filename, 'w') as ps1_file:
ps1_file.write(ps1_content)
print(f"File created: {ps1_filename}")
else:
print(f"File already exists: {ps1_filename}")

View File

@ -214,7 +214,7 @@ def process(args):
buf.tofile(f) buf.tofile(f)
if __name__ == '__main__': def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--src_dir", type=str, help="directory to load images / 画像を読み込むディレクトリ") parser.add_argument("--src_dir", type=str, help="directory to load images / 画像を読み込むディレクトリ")
parser.add_argument("--dst_dir", type=str, help="directory to save images / 画像を保存するディレクトリ") parser.add_argument("--dst_dir", type=str, help="directory to save images / 画像を保存するディレクトリ")
@ -234,6 +234,13 @@ if __name__ == '__main__':
parser.add_argument("--multiple_faces", action="store_true", parser.add_argument("--multiple_faces", action="store_true",
help="output each faces / 複数の顔が見つかった場合、それぞれを切り出す") help="output each faces / 複数の顔が見つかった場合、それぞれを切り出す")
parser.add_argument("--debug", action="store_true", help="render rect for face / 処理後画像の顔位置に矩形を描画します") parser.add_argument("--debug", action="store_true", help="render rect for face / 処理後画像の顔位置に矩形を描画します")
return parser
if __name__ == '__main__':
parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()
process(args) process(args)

View File

@ -98,7 +98,7 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
shutil.copy(os.path.join(src_img_folder, asoc_file), os.path.join(dst_img_folder, new_asoc_file)) shutil.copy(os.path.join(src_img_folder, asoc_file), os.path.join(dst_img_folder, new_asoc_file))
def main(): def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='Resize images in a folder to a specified max resolution(s) / 指定されたフォルダ内の画像を指定した最大画像サイズ(面積)以下にアスペクト比を維持したままリサイズします') description='Resize images in a folder to a specified max resolution(s) / 指定されたフォルダ内の画像を指定した最大画像サイズ(面積)以下にアスペクト比を維持したままリサイズします')
parser.add_argument('src_img_folder', type=str, help='Source folder containing the images / 元画像のフォルダ') parser.add_argument('src_img_folder', type=str, help='Source folder containing the images / 元画像のフォルダ')
@ -113,6 +113,12 @@ def main():
parser.add_argument('--copy_associated_files', action='store_true', parser.add_argument('--copy_associated_files', action='store_true',
help='Copy files with same base name to images (captions etc) / 画像と同じファイル名(拡張子を除く)のファイルもコピーする') help='Copy files with same base name to images (captions etc) / 画像と同じファイル名(拡張子を除く)のファイルもコピーする')
return parser
def main():
parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()
resize_images(args.src_img_folder, args.dst_img_folder, args.max_resolution, resize_images(args.src_img_folder, args.dst_img_folder, args.max_resolution,
args.divisible_by, args.interpolation, args.save_as_png, args.copy_associated_files) args.divisible_by, args.interpolation, args.save_as_png, args.copy_associated_files)

View File

@ -1,11 +1,17 @@
import os import os
import sys import sys
import pkg_resources import pkg_resources
import argparse
# Parse command line arguments
parser = argparse.ArgumentParser(description="Validate that requirements are satisfied.")
parser.add_argument('-r', '--requirements', type=str, default='requirements.txt', help="Path to the requirements file.")
args = parser.parse_args()
print("Validating that requirements are satisfied.") print("Validating that requirements are satisfied.")
# Load the requirements from the requirements.txt file # Load the requirements from the specified requirements file
with open('requirements.txt') as f: with open(args.requirements) as f:
requirements = f.readlines() requirements = f.readlines()
# Check each requirement against the installed packages # Check each requirement against the installed packages
@ -34,7 +40,7 @@ if missing_requirements or wrong_version_requirements:
for requirement, expected_version, actual_version in wrong_version_requirements: for requirement, expected_version, actual_version in wrong_version_requirements:
print(f" - {requirement} (expected version {expected_version}, found version {actual_version})") print(f" - {requirement} (expected version {expected_version}, found version {actual_version})")
upgrade_script = "upgrade.ps1" if os.name == "nt" else "upgrade.sh" upgrade_script = "upgrade.ps1" if os.name == "nt" else "upgrade.sh"
print(f"\nRun \033[33m{upgrade_script}\033[0m or \033[33mpip install -U -r requirements.txt\033[0m to resolve the missing requirements listed above...") print(f"\nRun \033[33m{upgrade_script}\033[0m or \033[33mpip install -U -r {args.requirements}\033[0m to resolve the missing requirements listed above...")
sys.exit(1) sys.exit(1)

View File

@ -115,7 +115,7 @@ def train(args):
vae.requires_grad_(False) vae.requires_grad_(False)
vae.eval() vae.eval()
with torch.no_grad(): with torch.no_grad():
train_dataset_group.cache_latents(vae) train_dataset_group.cache_latents(vae, args.vae_batch_size)
vae.to("cpu") vae.to("cpu")
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()
@ -160,7 +160,7 @@ def train(args):
# 学習ステップ数を計算する # 学習ステップ数を計算する
if args.max_train_epochs is not None: if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * len(train_dataloader) args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
if args.stop_text_encoder_training is None: if args.stop_text_encoder_training is None:
@ -384,7 +384,7 @@ def train(args):
print("model saved.") print("model saved.")
if __name__ == "__main__": def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
train_util.add_sd_models_arguments(parser) train_util.add_sd_models_arguments(parser)
@ -406,6 +406,12 @@ if __name__ == "__main__":
help="steps to stop text encoder training, -1 for no training / Text Encoderの学習を止めるステップ数、-1で最初から学習しない", help="steps to stop text encoder training, -1 for no training / Text Encoderの学習を止めるステップ数、-1で最初から学習しない",
) )
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()
args = train_util.read_config_from_file(args, parser) args = train_util.read_config_from_file(args, parser)

View File

@ -140,7 +140,7 @@ def train(args):
vae.requires_grad_(False) vae.requires_grad_(False)
vae.eval() vae.eval()
with torch.no_grad(): with torch.no_grad():
train_dataset_group.cache_latents(vae) train_dataset_group.cache_latents(vae, args.vae_batch_size)
vae.to("cpu") vae.to("cpu")
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()
@ -197,7 +197,7 @@ def train(args):
# 学習ステップ数を計算する # 学習ステップ数を計算する
if args.max_train_epochs is not None: if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes) args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
if is_main_process: if is_main_process:
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
@ -647,7 +647,7 @@ def train(args):
print("model saved.") print("model saved.")
if __name__ == "__main__": def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
train_util.add_sd_models_arguments(parser) train_util.add_sd_models_arguments(parser)
@ -690,6 +690,12 @@ if __name__ == "__main__":
"--training_comment", type=str, default=None, help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列" "--training_comment", type=str, default=None, help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列"
) )
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()
args = train_util.read_config_from_file(args, parser) args = train_util.read_config_from_file(args, parser)

View File

@ -228,7 +228,7 @@ def train(args):
vae.requires_grad_(False) vae.requires_grad_(False)
vae.eval() vae.eval()
with torch.no_grad(): with torch.no_grad():
train_dataset_group.cache_latents(vae) train_dataset_group.cache_latents(vae, args.vae_batch_size)
vae.to("cpu") vae.to("cpu")
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()
@ -257,7 +257,7 @@ def train(args):
# 学習ステップ数を計算する # 学習ステップ数を計算する
if args.max_train_epochs is not None: if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * len(train_dataloader) args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# lr schedulerを用意する # lr schedulerを用意する
@ -526,7 +526,7 @@ def load_weights(file):
return emb return emb
if __name__ == "__main__": def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
train_util.add_sd_models_arguments(parser) train_util.add_sd_models_arguments(parser)
@ -565,6 +565,12 @@ if __name__ == "__main__":
help="ignore caption and use default templates for stype / キャプションは使わずデフォルトのスタイル用テンプレートで学習する", help="ignore caption and use default templates for stype / キャプションは使わずデフォルトのスタイル用テンプレートで学習する",
) )
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()
args = train_util.read_config_from_file(args, parser) args = train_util.read_config_from_file(args, parser)

16
upgrade_macos.sh Executable file
View File

@ -0,0 +1,16 @@
#!/bin/bash
# Check if there are any changes that need to be committed
if [[ -n $(git status --short) ]]; then
echo "There are changes that need to be committed. Please stash or undo your changes before running this script." >&2
exit 1
fi
# Pull the latest changes from the remote repository
git pull
# Activate the virtual environment
source venv/bin/activate
# Upgrade the required packages
pip install --upgrade -r requirements_macos.txt