From 4277645b2a7c859b10e3da3425ebf7985ef60617 Mon Sep 17 00:00:00 2001 From: Vertana Date: Mon, 27 Mar 2023 12:37:00 -0700 Subject: [PATCH 01/44] Consolidated Install Scripts and Improve README Install scripts have been consolidated for every non-Windows OS. Python Requirements were consolidated. README improved to work locally and provide more information. --- README.md | 122 ++++++++++++++++++++++++----------------- gui.sh | 14 ++--- gui_macos.sh | 13 ----- macos_setup.sh | 38 ------------- requirements.txt | 8 ++- requirements_macos.txt | 32 ----------- setup.sh | 96 ++++++++++++++++++++++++++++++++ ubuntu_setup.sh | 12 ---- upgrade.sh | 2 +- upgrade_macos.sh | 16 ------ 10 files changed, 178 insertions(+), 175 deletions(-) delete mode 100755 gui_macos.sh delete mode 100755 macos_setup.sh delete mode 100644 requirements_macos.txt create mode 100755 setup.sh delete mode 100755 ubuntu_setup.sh delete mode 100755 upgrade_macos.sh diff --git a/README.md b/README.md index f3af693..92cfecd 100644 --- a/README.md +++ b/README.md @@ -6,21 +6,29 @@ If you run on Linux and would like to use the GUI, there is now a port of it as ### Table of Contents -- [Tutorials](https://github.com/bmaltais/kohya_ss#tutorials) -- [Required Dependencies](https://github.com/bmaltais/kohya_ss#required-dependencies) -- [Installation](https://github.com/bmaltais/kohya_ss#installation) - - [CUDNN 8.6](https://github.com/bmaltais/kohya_ss#optional-cudnn-86) -- [Upgrading](https://github.com/bmaltais/kohya_ss#upgrading) -- [Launching the GUI](https://github.com/bmaltais/kohya_ss#launching-the-gui) -- [Dreambooth](https://github.com/bmaltais/kohya_ss#dreambooth) -- [Finetune](https://github.com/bmaltais/kohya_ss#finetune) -- [Train Network](https://github.com/bmaltais/kohya_ss#train-network) -- [LoRA](https://github.com/bmaltais/kohya_ss#lora) -- [Troubleshooting](https://github.com/bmaltais/kohya_ss#troubleshooting) - - [Page File Limit](https://github.com/bmaltais/kohya_ss#page-file-limit) - - [No module called tkinter](https://github.com/bmaltais/kohya_ss#no-module-called-tkinter) - - [FileNotFoundError](https://github.com/bmaltais/kohya_ss#filenotfounderror) -- [Change History](https://github.com/bmaltais/kohya_ss#change-history) +- [Tutorials](#tutorials) +- [Required Dependencies](#required-dependencies) + - [Linux/macOS](#linux-and-macos-dependencies) +- [Installation](#installation) + - [Linux/macOS](#linux-and-macos) + - [Windows](#windows) + - [CUDNN 8.6](#optional--cudnn-86) +- [Upgrading](#upgrading) + - [Windows](#windows-upgrade) + - [Linux/macOS](#linux-and-macos-upgrade) +- [Launching the GUI](#starting-gui-service) + - [Windows](#launching-the-gui-on-windows) + - [Linux/macOS](#launching-the-gui-on-linux-and-macos) + - [Direct Launch via Python Script](#launching-the-gui-directly-using-kohyaguipy) +- [Dreambooth](#dreambooth) +- [Finetune](#finetune) +- [Train Network](#train-network) +- [LoRA](#lora) +- [Troubleshooting](#troubleshooting) + - [Page File Limit](#page-file-limit) + - [No module called tkinter](#no-module-called-tkinter) + - [FileNotFoundError](#filenotfounderror) +- [Change History](#change-history) ## Tutorials @@ -39,32 +47,27 @@ If you run on Linux and would like to use the GUI, there is now a port of it as - Install [Git](https://git-scm.com/download/win) - Install [Visual Studio 2015, 2017, 2019, and 2022 redistributable](https://aka.ms/vs/17/release/vc_redist.x64.exe) +### Linux and macOS dependencies + +These dependencies are taken care of via `setup.sh` in the installation section. No additional steps should be needed unless the scripts inform you otherwise. + ## Installation ### Runpod Follow the instructions found in this discussion: https://github.com/bmaltais/kohya_ss/discussions/379 -### MacOS +### Linux and macOS In the terminal, run ``` git clone https://github.com/bmaltais/kohya_ss.git cd kohya_ss -bash macos_setup.sh +# May need to chmod +x ./setup.sh if you're on a machine with stricter security. +./setup.sh ``` During the accelerate config screen after running the script answer "This machine", "None", "No" for the remaining questions. - -### Ubuntu -In the terminal, run - -``` -git clone https://github.com/bmaltais/kohya_ss.git -cd kohya_ss -bash ubuntu_setup.sh -``` - -then configure accelerate with the same answers as in the Windows instructions when prompted. +These are the same answers as the Windows install. ### Windows @@ -110,21 +113,13 @@ Run the following commands to install: python .\tools\cudann_1.8_install.py ``` -## Upgrading MacOS - -When a new release comes out, you can upgrade your repo with the following commands in the root directory: - -```bash -upgrade_macos.sh -``` - Once the commands have completed successfully you should be ready to use the new version. MacOS support is not tested and has been mostly taken from https://gist.github.com/jstayco/9f5733f05b9dc29de95c4056a023d645 -## Upgrading Windows +## Upgrading -When a new release comes out, you can upgrade your repo with the following commands in the root directory: - -```powershell +The following commands will work from the root directory of the project if you'd prefer to not run scripts. +These commands will work on any OS. +```bash git pull .\venv\Scripts\activate @@ -132,20 +127,40 @@ git pull pip install --use-pep517 --upgrade -r requirements.txt ``` +### Windows Upgrade +When a new release comes out, you can upgrade your repo with the following commands in the root directory: + +```powershell +./upgrade.ps1 +``` + +### Linux and macOS Upgrade +You can cd into the root directory and simply run + +```bash +./upgrade.sh +``` + Once the commands have completed successfully you should be ready to use the new version. -## Launching the GUI using gui.bat or gui.ps1 - -The script can be run with several optional command line arguments: +# Starting GUI Service +The following command line arguments can be passed to the scripts on any OS to configure the underlying service. +``` --listen: the IP address to listen on for connections to Gradio. ---username: a username for authentication. ---password: a password for authentication. ---server_port: the port to run the server listener on. ---inbrowser: opens the Gradio UI in a web browser. +--username: a username for authentication. +--password: a password for authentication. +--server_port: the port to run the server listener on. +--inbrowser: opens the Gradio UI in a web browser. --share: shares the Gradio UI. +``` -These command line arguments can be passed to the UI function as keyword arguments. To launch the Gradio UI, run the script in a terminal with the desired command line arguments, for example: +### Launching the GUI on Windows + +The two scripts to launch the GUI on Windows are gui.ps1 and gui.bat in the root directory. +You can use whichever script you prefer. + +To launch the Gradio UI, run the script in a terminal with the desired command line arguments, for example: `gui.ps1 --listen 127.0.0.1 --server_port 7860 --inbrowser --share` @@ -153,14 +168,19 @@ or `gui.bat --listen 127.0.0.1 --server_port 7860 --inbrowser --share` -## Launching the GUI using kohya_gui.py +## Launching the GUI on Linux and macOS -To run the GUI, simply use this command: +Run the launcher script with the desired command line arguments similar to Windows. +`gui.sh --listen 127.0.0.1 --server_port 7860 --inbrowser --share` + +## Launching the GUI directly using kohya_gui.py + +To run the GUI directly bypassing the wrapper scripts, simply use this command from the root project directory: ``` .\venv\Scripts\activate -python.exe .\kohya_gui.py +python .\kohya_gui.py ``` ## Dreambooth diff --git a/gui.sh b/gui.sh index e4eca6f..bd7078f 100755 --- a/gui.sh +++ b/gui.sh @@ -1,13 +1,9 @@ -#!/bin/bash +#!/usr/bin/env bash # Activate the virtual environment -source venv/bin/activate - -# Validate the requirements and store the exit code -python tools/validate_requirements.py -exit_code=$? - -# If the exit code is 0, run the kohya_gui.py script with the command-line arguments -if [ $exit_code -eq 0 ]; then +source ./venv/bin/activate +python -V +# If the requirements are validated, run the kohya_gui.py script with the command-line arguments +if python tools/validate_requirements.py; then python kohya_gui.py "$@" fi diff --git a/gui_macos.sh b/gui_macos.sh deleted file mode 100755 index 4a0bfb8..0000000 --- a/gui_macos.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/bin/bash - -# Activate the virtual environment -source venv/bin/activate - -# Validate the requirements and store the exit code -python tools/validate_requirements.py --requirements requirements_macos.txt -exit_code=$? - -# If the exit code is 0, run the kohya_gui.py script with the command-line arguments -if [ $exit_code -eq 0 ]; then - python kohya_gui.py "$@" -fi diff --git a/macos_setup.sh b/macos_setup.sh deleted file mode 100755 index 4de8417..0000000 --- a/macos_setup.sh +++ /dev/null @@ -1,38 +0,0 @@ -#!/bin/bash -# The initial setup script to prep the environment on macOS -# xformers has been omitted as that is for Nvidia GPUs only - -if ! command -v brew >/dev/null; then - echo "Please install homebrew first. This is a requirement for the remaining setup." - echo "You can find that here: https://brew.sh" - exit 1 -fi - -# Install base python packages -echo "Installing Python 3.10 if not found." -brew ls --versions python@3.10 >/dev/null || brew install python@3.10 -echo "Installing Python-TK 3.10 if not found." -brew ls --versions python-tk@3.10 >/dev/null || brew install python-tk@3.10 - -if command -v python3.10 >/dev/null; then - python3.10 -m venv venv - source venv/bin/activate - - # DEBUG ONLY - #pip install pydevd-pycharm~=223.8836.43 - - # Tensorflow installation - if wget https://github.com/apple/tensorflow_macos/releases/download/v0.1alpha3/tensorflow_macos-0.1a3-cp38-cp38-macosx_11_0_arm64.whl /tmp; then - python -m pip install tensorflow==0.1a3 -f https://github.com/apple/tensorflow_macos/releases/download/v0.1alpha3/tensorflow_macos-0.1a3-cp38-cp38-macosx_11_0_arm64.whl - rm -f /tmp/tensorflow_macos-0.1a3-cp38-cp38-macosx_11_0_arm64.whl - fi - - pip install torch==2.0.0 torchvision==0.15.1 -f https://download.pytorch.org/whl/cpu/torch_stable.html - python -m pip install --use-pep517 --upgrade -r requirements_macos.txt - accelerate config - echo -e "Setup finished! Run ./gui_macos.sh to start." -else - echo "Python not found. Please ensure you install Python." - echo "The brew command for Python 3.10 is: brew install python@3.10" - exit 1 -fi \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 8e47439..4ee4eec 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,8 @@ diffusers[torch]==0.10.2 easygui==0.98.3 einops==0.6.0 ftfy==6.1.1 -gradio==3.19.1 +gradio==3.19.1; sys_platform != 'darwin' +gradio==3.23.0; sys_platform == 'darwin' lion-pytorch==0.0.6 opencv-python==4.7.0.68 pytorch-lightning==1.9.0 @@ -22,8 +23,9 @@ fairscale==0.4.13 requests==2.28.2 timm==0.6.12 # tensorflow<2.11 -huggingface-hub==0.12.0 -tensorflow==2.10.1 +huggingface-hub==0.12.0; sys_platform != 'darwin' +huggingface-hub==0.13.0; sys_platform == 'darwin' +tensorflow==2.10.1; sys_platform != 'darwin' # For locon support lycoris_lora==0.1.2 # for kohya_ss library diff --git a/requirements_macos.txt b/requirements_macos.txt deleted file mode 100644 index 4ee4eec..0000000 --- a/requirements_macos.txt +++ /dev/null @@ -1,32 +0,0 @@ -accelerate==0.15.0 -albumentations==1.3.0 -altair==4.2.2 -bitsandbytes==0.35.0 -dadaptation==1.5 -diffusers[torch]==0.10.2 -easygui==0.98.3 -einops==0.6.0 -ftfy==6.1.1 -gradio==3.19.1; sys_platform != 'darwin' -gradio==3.23.0; sys_platform == 'darwin' -lion-pytorch==0.0.6 -opencv-python==4.7.0.68 -pytorch-lightning==1.9.0 -safetensors==0.2.6 -tensorboard==2.10.1 -tk==0.1.0 -toml==0.10.2 -transformers==4.26.0 -voluptuous==0.13.1 -# for BLIP captioning -fairscale==0.4.13 -requests==2.28.2 -timm==0.6.12 -# tensorflow<2.11 -huggingface-hub==0.12.0; sys_platform != 'darwin' -huggingface-hub==0.13.0; sys_platform == 'darwin' -tensorflow==2.10.1; sys_platform != 'darwin' -# For locon support -lycoris_lora==0.1.2 -# for kohya_ss library -. \ No newline at end of file diff --git a/setup.sh b/setup.sh new file mode 100755 index 0000000..6cba233 --- /dev/null +++ b/setup.sh @@ -0,0 +1,96 @@ +#!/usr/bin/env bash + +if [[ "$OSTYPE" == "linux-gnu"* ]]; then + # Check if root or sudo + root=true + if [ "$EUID" -ne 0 ]; then + root=false + fi + + distro="$(python -mplatform)" + if "$distro" | grep -qi "Ubuntu"; then + echo "Ubuntu detected." + echo "Installing Python TK if not found on the system." + if [ ! $(dpkg-query -W -f='${Status}' python3-tk 2>/dev/null | grep -c "ok installed") -eq 0 ]; then + if [ root = true ]; then + apt-get install python3-tk + else + echo "This script needs to be run as root or via sudo to install packages." + exit 1 + fi + else + echo "Python TK found! Skipping install!" + fi + elif "$distro" | grep -Eqi "Fedora|CentOS|Redhat"; then + if ! rpm -qa | grep -qi python3-tkinter; then + if [ root = true ]; then + dnf install python3-tkinter + else + echo "This script needs to be run as root or via sudo to install packages." + exit 1 + fi + fi + fi + + python3 -m venv venv + source venv/bin/activate + pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 + pip install --use-pep517 --upgrade -r requirements.txt + pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/linux/xformers-0.0.14.dev0-cp310-cp310-linux_x86_64.whl + accelerate config + + echo -e "Setup finished! Run \e[0;92m./gui.sh\e[0m to start." +elif [[ "$OSTYPE" == "darwin"* ]]; then + # The initial setup script to prep the environment on macOS + # xformers has been omitted as that is for Nvidia GPUs only + + if ! command -v brew >/dev/null; then + echo "Please install homebrew first. This is a requirement for the remaining setup." + echo "You can find that here: https://brew.sh" + exit 1 + fi + + # Install base python packages + echo "Installing Python 3.10 if not found." + if ! brew ls --versions python@3.10 >/dev/null; then + brew install python@3.10 + else + echo "Python 3.10 found!" + fi + echo "Installing Python-TK 3.10 if not found." + if ! brew ls --versions python-tk@3.10 >/dev/null; then + brew install python-tk@3.10 + else + echo "Python Tkinter 3.10 found!" + fi + + if command -v python3.10 >/dev/null; then + python3.10 -m venv venv + source venv/bin/activate + + # DEBUG ONLY + #pip install pydevd-pycharm~=223.8836.43 + + # Tensorflow installation + if wget https://github.com/apple/tensorflow_macos/releases/download/v0.1alpha3/tensorflow_macos-0.1a3-cp38-cp38-macosx_11_0_arm64.whl /tmp; then + python -m pip install tensorflow==0.1a3 -f https://github.com/apple/tensorflow_macos/releases/download/v0.1alpha3/tensorflow_macos-0.1a3-cp38-cp38-macosx_11_0_arm64.whl + rm -f /tmp/tensorflow_macos-0.1a3-cp38-cp38-macosx_11_0_arm64.whl + fi + + pip install torch==2.0.0 torchvision==0.15.1 -f https://download.pytorch.org/whl/cpu/torch_stable.html + python -m pip install --use-pep517 --upgrade -r requirements.txt + accelerate config + echo -e "Setup finished! Run ./gui.sh to start." + else + echo "Python not found. Please ensure you install Python." + echo "The brew command for Python 3.10 is: brew install python@3.10" + exit 1 + fi +elif [[ "$OSTYPE" == "cygwin" ]]; then + # Cygwin is a standalone suite of Linux utilies on Windows + echo "This hasn't been validated on cygwin yet." +elif [[ "$OSTYPE" == "msys" ]]; then + # MinGW has the msys environment which is a standalone suite of Linux utilies on Windows + # "git bash" on Windows may also be detected as msys. + echo "This hasn't been validated in msys (mingw) on Windows yet." +fi diff --git a/ubuntu_setup.sh b/ubuntu_setup.sh deleted file mode 100755 index 1431155..0000000 --- a/ubuntu_setup.sh +++ /dev/null @@ -1,12 +0,0 @@ -#!/bin/bash -echo installing tk -sudo apt install python3-tk -python3 -m venv venv -source venv/bin/activate -pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 -pip install --use-pep517 --upgrade -r requirements.txt -pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/linux/xformers-0.0.14.dev0-cp310-cp310-linux_x86_64.whl - -accelerate config - -echo -e "setup finished! run \e[0;92m./gui.sh\e[0m to start" diff --git a/upgrade.sh b/upgrade.sh index f01e7b7..8ed545f 100755 --- a/upgrade.sh +++ b/upgrade.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # Check if there are any changes that need to be committed if [[ -n $(git status --short) ]]; then diff --git a/upgrade_macos.sh b/upgrade_macos.sh deleted file mode 100755 index 2e26c55..0000000 --- a/upgrade_macos.sh +++ /dev/null @@ -1,16 +0,0 @@ -#!/bin/bash - -# Check if there are any changes that need to be committed -if [[ -n $(git status --short) ]]; then - echo "There are changes that need to be committed. Please stash or undo your changes before running this script." >&2 - exit 1 -fi - -# Pull the latest changes from the remote repository -git pull - -# Activate the virtual environment -source venv/bin/activate - -# Upgrade the required packages -pip install --upgrade -r requirements_macos.txt From 8168e36326b11d1ce40c0009ed05ad69c048d013 Mon Sep 17 00:00:00 2001 From: jstayco <127801635+jstayco@users.noreply.github.com> Date: Mon, 27 Mar 2023 18:58:03 -0700 Subject: [PATCH 02/44] Update setup.sh Small comment added to clarify the purpose of the script for casual viewers. --- setup.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/setup.sh b/setup.sh index 6cba233..87985ad 100755 --- a/setup.sh +++ b/setup.sh @@ -1,5 +1,7 @@ #!/usr/bin/env bash +# This file will be the host environment setup file for all operating systems other than base Windows. + if [[ "$OSTYPE" == "linux-gnu"* ]]; then # Check if root or sudo root=true From c79181b7ea06732575b2f529387573eacdfcb8ea Mon Sep 17 00:00:00 2001 From: jstayco <127801635+jstayco@users.noreply.github.com> Date: Mon, 27 Mar 2023 18:58:30 -0700 Subject: [PATCH 03/44] Update gui.sh Removed an unnecessary debug line. --- gui.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gui.sh b/gui.sh index bd7078f..4fa2e35 100755 --- a/gui.sh +++ b/gui.sh @@ -2,7 +2,7 @@ # Activate the virtual environment source ./venv/bin/activate -python -V + # If the requirements are validated, run the kohya_gui.py script with the command-line arguments if python tools/validate_requirements.py; then python kohya_gui.py "$@" From 8bf3ecfd1f3725cfa442bd4ce720f81dd378ef4d Mon Sep 17 00:00:00 2001 From: JSTayco Date: Tue, 28 Mar 2023 15:52:16 -0700 Subject: [PATCH 04/44] Update setup.sh Linux distribution detection is much more robust. We also now include the Linux distribution family. That should help downstream Linux distributions as well. For example, Manjaro will now be detected as arch and Mint Linux will be detected as Ubuntu. --- setup.sh | 80 ++++++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 78 insertions(+), 2 deletions(-) diff --git a/setup.sh b/setup.sh index 87985ad..60a3c08 100755 --- a/setup.sh +++ b/setup.sh @@ -9,8 +9,56 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then root=false fi - distro="$(python -mplatform)" - if "$distro" | grep -qi "Ubuntu"; then + get_distro_name() { + local line + if [ -f /etc/os-release ]; then + # We search for the line starting with ID= + # Then we remove the ID= prefix to get the name itself + line="$(grep -Ei '^ID=' /etc/os-release)" + line=${line##*=} + echo "$line" + return 0 + elif command -v python >/dev/null; then + line="$(python -mplatform)" + echo "$line" + return 0 + elif command -v python3 >/dev/null; then + line="$(python3 -mplatform)" + echo "$line" + return 0 + else + line="None" + echo "$line" + return 1 + fi + } + + get_distro_family() { + local line + if [ -f /etc/os-release ]; then + # We search for the line starting with ID_LIKE= + # Then we remove the ID_LIKE= prefix to get the name itself + # This is the "type" of distro. For example, Ubuntu returns "debian". + if grep -Eiq '^ID_LIKE=' /etc/os-release >/dev/null; then + line="$(grep -Ei '^ID_LIKE=' /etc/os-release)" + line=${line##*=} + echo "$line" + return 0 + else + line="None" + echo "$line" + return 1 + fi + else + line="None" + echo "$line" + return 1 + fi + } + + distro=get_distro_name + family=get_distro_family + if "$distro" | grep -qi "Ubuntu" || "$family" | grep -qi "Ubuntu"; then echo "Ubuntu detected." echo "Installing Python TK if not found on the system." if [ ! $(dpkg-query -W -f='${Status}' python3-tk 2>/dev/null | grep -c "ok installed") -eq 0 ]; then @@ -32,6 +80,32 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then exit 1 fi fi + elif "$distro" | grep -Eqi "arch" || "$family" | grep -qi "arch"; then + if ! pacman -Qi tk >/dev/null; then + if [ root = true ]; then + pacman -S tk + else + echo "This script needs to be run as root or via sudo to install packages." + exit 1 + fi + fi + elif "$distro" | grep -Eqi "opensuse" || "$family" | grep -qi "opensuse"; then + if ! rpm -qa | grep -qi python-tk; then + if [ root = true ]; then + zypper install python-tk + else + echo "This script needs to be run as root or via sudo to install packages." + exit 1 + fi + fi + elif [ "$distro" = "None" ] || [ "$family" = "None" ]; then + if [ "$distro" = "None" ]; then + echo "We could not detect your distribution of Linux. Please file a bug report on github with the contents of your /etc/os-release file." + fi + + if [ "$family" = "None" ]; then + echo "We could not detect the family of your Linux distribution. Please file a bug report on github with the contents of your /etc/os-release file." + fi fi python3 -m venv venv @@ -49,6 +123,8 @@ elif [[ "$OSTYPE" == "darwin"* ]]; then if ! command -v brew >/dev/null; then echo "Please install homebrew first. This is a requirement for the remaining setup." echo "You can find that here: https://brew.sh" + #shellcheck disable=SC2016 + echo 'The "brew" command should be in $PATH to be detected.' exit 1 fi From 51e843fe2ed68edbe38ef9cb49faaae61ab70578 Mon Sep 17 00:00:00 2001 From: JSTayco Date: Tue, 28 Mar 2023 16:02:20 -0700 Subject: [PATCH 05/44] Update setup.sh Fixed install commands to not require user input. --- setup.sh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/setup.sh b/setup.sh index 60a3c08..9a2602e 100755 --- a/setup.sh +++ b/setup.sh @@ -63,7 +63,7 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then echo "Installing Python TK if not found on the system." if [ ! $(dpkg-query -W -f='${Status}' python3-tk 2>/dev/null | grep -c "ok installed") -eq 0 ]; then if [ root = true ]; then - apt-get install python3-tk + apt install -y python3-tk else echo "This script needs to be run as root or via sudo to install packages." exit 1 @@ -74,7 +74,7 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then elif "$distro" | grep -Eqi "Fedora|CentOS|Redhat"; then if ! rpm -qa | grep -qi python3-tkinter; then if [ root = true ]; then - dnf install python3-tkinter + dnf install python3-tkinter -y else echo "This script needs to be run as root or via sudo to install packages." exit 1 @@ -83,7 +83,7 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then elif "$distro" | grep -Eqi "arch" || "$family" | grep -qi "arch"; then if ! pacman -Qi tk >/dev/null; then if [ root = true ]; then - pacman -S tk + pacman --noconfirm -S tk else echo "This script needs to be run as root or via sudo to install packages." exit 1 @@ -92,7 +92,7 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then elif "$distro" | grep -Eqi "opensuse" || "$family" | grep -qi "opensuse"; then if ! rpm -qa | grep -qi python-tk; then if [ root = true ]; then - zypper install python-tk + zypper install -y python-tk else echo "This script needs to be run as root or via sudo to install packages." exit 1 From 307c433254af05735e70d0db174b6a31523878e2 Mon Sep 17 00:00:00 2001 From: JSTayco Date: Tue, 28 Mar 2023 18:07:18 -0700 Subject: [PATCH 06/44] The runpod update! Script now accepts long and short arguments. Script should now help with a runpod environment and ensure apt cache is updated before package install attempt. --- README.md | 2 ++ setup.sh | 95 ++++++++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 96 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 92cfecd..0a866b2 100644 --- a/README.md +++ b/README.md @@ -63,6 +63,8 @@ In the terminal, run git clone https://github.com/bmaltais/kohya_ss.git cd kohya_ss # May need to chmod +x ./setup.sh if you're on a machine with stricter security. +# There are additional options if needed for a runpod environment. +# Call 'setup.sh -h' or 'setup.sh --help' for more information. ./setup.sh ``` diff --git a/setup.sh b/setup.sh index 9a2602e..0fa1e4c 100755 --- a/setup.sh +++ b/setup.sh @@ -2,6 +2,51 @@ # This file will be the host environment setup file for all operating systems other than base Windows. +display_help() { + cat </dev/null | grep -c "ok installed") -eq 0 ]; then if [ root = true ]; then - apt install -y python3-tk + apt update -y && apt install -y python3-tk else echo "This script needs to be run as root or via sudo to install packages." exit 1 @@ -113,6 +183,29 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 pip install --use-pep517 --upgrade -r requirements.txt pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/linux/xformers-0.0.14.dev0-cp310-cp310-linux_x86_64.whl + + # We need this extra package and setup if we are running in a runpod + if env_var_exists RUNPOD_POD_ID || env_var_exists RUNPOD_API_KEY; then + pip install tensorrt + ln -s "$VENV_DIR/lib/python3.10/site-packages/tensorrt/libnvinfer_plugin.so.8" \ + "$VENV_DIR/lib/python3.10/site-packages/tensorrt/libnvinfer_plugin.so.7" + ln -s "$VENV_DIR/lib/python3.10/site-packages/tensorrt/libnvinfer.so.8" \ + "$VENV_DIR/lib/python3.10/site-packages/tensorrt/libnvinfer.so.7" + ln -s "$VENV_DIR/lib/python3.10/site-packages/nvidia/cuda_runtime/lib/libcudart.so.12" \ + "$VENV_DIR/lib/python3.10/site-packages/nvidia/cuda_runtime/lib/libcudart.so.11.0" + + export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$VENV_DIR/lib/python3.10/site-packages/tensorrt/" + export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$VENV_DIR/lib/python3.10/site-packages/nvidia/cuda_runtime/lib/" + + # This is a non-interactive environment, so just directly call gui.sh after all setup steps are complete. + if command -v bash >/dev/null; then + bash "$DIR"/gui.sh + else + # This shouldn't happen, but we're going to try to help. + sh "$DIR"/gui.sh + fi + fi + accelerate config echo -e "Setup finished! Run \e[0;92m./gui.sh\e[0m to start." From 0b5a418b39586375c1380315bbda26d7834a2750 Mon Sep 17 00:00:00 2001 From: JSTayco Date: Wed, 29 Mar 2023 00:47:03 -0700 Subject: [PATCH 07/44] More options, better var handlng. Changed runpod detection to a variable to simplify maintenance and provide a mechanism for the user to force a runpod installation. Also, updated help message to acount for the change. --- setup.sh | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/setup.sh b/setup.sh index 0fa1e4c..9af9320 100755 --- a/setup.sh +++ b/setup.sh @@ -10,22 +10,25 @@ The following options are useful in a runpod environment, but will not affect a local machine install. Usage: - setup.sh -b dev -d /workspace/kohya_ss - setup.sh --branch=dev --dir=/workspace/kohya_ss + setup.sh -b dev -d /workspace/kohya_ss -g https://mycustom.repo.tld/custom_fork.git + setup.sh --branch=dev --dir=/workspace/kohya_ss --git-repo=https://mycustom.repo.tld/custom_fork.git Options: -b BRANCH, --branch=BRANCH Select which branch of kohya to checkout on new installs. -d DIR, --dir=DIR The full path you want kohya_ss installed to. - -h, --help Show this screen. + -g, --git_repo You can optionally provide a git repo to checkout for runpod installation. Useful for custom forks. + -r, --runpod Forces a runpod installation. Useful if detection fails for any reason. + -h, --help Show this screen. EOF } # Variables defined before the getopts loop, so we have sane default values. DIR="/workspace/kohya_ss" BRANCH="dev" -REPO="https://github.com/bmaltais/kohya_ss.git" +GIT_REPO="https://github.com/bmaltais/kohya_ss.git" +RUNPOD=false -while getopts "b:d:-:" opt; do +while getopts "b:d:g:r-:" opt; do # support long options: https://stackoverflow.com/a/28466267/519360 if [ "$opt" = "-" ]; then # long option: reformulate OPT and OPTARG opt="${OPTARG%%=*}" # extract long option name @@ -36,7 +39,8 @@ while getopts "b:d:-:" opt; do # note the leading colon b | branch) BRANCH="$OPTARG" ;; d | dir) DIR="$OPTARG" ;; - r | repo) REPO="$OPTARG" ;; + g | git-repo) GIT_REPO="$OPTARG" ;; + r | runpod) RUNPOD=true ;; h) display_help && exit 0 ;; *) display_help && exit 0 ;; esac @@ -54,6 +58,10 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then root=false fi + if env_var_exists RUNPOD_POD_ID || env_var_exists RUNPOD_API_KEY; then + RUNPOD=true + fi + env_var_exists() { local env_var= env_var=$(declare -p "$1") @@ -110,17 +118,17 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then } # This is the pre-install work for a kohya installation on a runpod - if env_var_exists RUNPOD_POD_ID || env_var_exists RUNPOD_API_KEY; then + if [ "$RUNPOD" = true ]; then if [ -d "$VENV_DIR" ]; then echo "Pre-existing installation on a runpod detected." export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:"$VENV_DIR"/lib/python3.10/site-packages/tensorrt/ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:"$VENV_DIR"/lib/python3.10/site-packages/nvidia/cuda_runtime/lib/ cd "$DIR" || exit 1 - sed -i "s/interface.launch(\*\*launch_kwargs)/interface.launch(\*\*launch_kwargs,share=True)/g" kohya_gui.py + sed -i "s/interface.launch(\*\*launch_kwargs)/interface.launch(\*\*launch_kwargs,share=True)/g" ./kohya_gui.py else echo "Clean installation on a runpod detected." cd "$BASE_DIR" || exit 1 - git clone "$REPO" + git clone "$GIT_REPO" cd "$DIR" || exit 1 git checkout "$BRANCH" fi @@ -185,7 +193,7 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/linux/xformers-0.0.14.dev0-cp310-cp310-linux_x86_64.whl # We need this extra package and setup if we are running in a runpod - if env_var_exists RUNPOD_POD_ID || env_var_exists RUNPOD_API_KEY; then + if [ "$RUNPOD" = true ]; then pip install tensorrt ln -s "$VENV_DIR/lib/python3.10/site-packages/tensorrt/libnvinfer_plugin.so.8" \ "$VENV_DIR/lib/python3.10/site-packages/tensorrt/libnvinfer_plugin.so.7" From 981f634c31ae3ff8fc1bd23a1e87084d583bf8f9 Mon Sep 17 00:00:00 2001 From: jstayco <127801635+jstayco@users.noreply.github.com> Date: Wed, 29 Mar 2023 09:55:04 -0700 Subject: [PATCH 08/44] Removed superfluous comment --- setup.sh | 1 - 1 file changed, 1 deletion(-) diff --git a/setup.sh b/setup.sh index 9af9320..afaed56 100755 --- a/setup.sh +++ b/setup.sh @@ -36,7 +36,6 @@ while getopts "b:d:g:r-:" opt; do OPTARG="${OPTARG#=}" # if long option argument, remove assigning `=` fi case $opt in - # note the leading colon b | branch) BRANCH="$OPTARG" ;; d | dir) DIR="$OPTARG" ;; g | git-repo) GIT_REPO="$OPTARG" ;; From 20c525bd59c282855310b9f34b506e4ee0fc6cd8 Mon Sep 17 00:00:00 2001 From: JSTayco Date: Wed, 29 Mar 2023 10:53:15 -0700 Subject: [PATCH 09/44] Enable more arbitrary install locations Default to master branch (user can override with an argument), update the repo if we find a git folder but no venv folder (indicating blank env), rename BASE_DIR to PARENT_DIR to be more obvious, enable PARENT_DIR to account for an arbitrary amount of folders. --- setup.sh | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/setup.sh b/setup.sh index 9af9320..7965fbd 100755 --- a/setup.sh +++ b/setup.sh @@ -24,7 +24,7 @@ EOF # Variables defined before the getopts loop, so we have sane default values. DIR="/workspace/kohya_ss" -BRANCH="dev" +BRANCH="master" GIT_REPO="https://github.com/bmaltais/kohya_ss.git" RUNPOD=false @@ -48,7 +48,7 @@ done shift $((OPTIND - 1)) # This must be set after the getopts loop to account for $DIR changes. -BASE_DIR="$(echo "$DIR" | cut -d "/" -f2)" +PARENT_DIR="$(dirname "${DIR}")" VENV_DIR="$DIR/venv" if [[ "$OSTYPE" == "linux-gnu"* ]]; then @@ -127,10 +127,18 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then sed -i "s/interface.launch(\*\*launch_kwargs)/interface.launch(\*\*launch_kwargs,share=True)/g" ./kohya_gui.py else echo "Clean installation on a runpod detected." - cd "$BASE_DIR" || exit 1 - git clone "$GIT_REPO" - cd "$DIR" || exit 1 - git checkout "$BRANCH" + cd "$PARENT_DIR" || exit 1 + if [ ! -d "$DIR/.git" ]; then + echo "Cloning $GIT_REPO." + git clone "$GIT_REPO" + cd "$DIR" || exit 1 + git checkout "$BRANCH" + else + cd "$DIR" || exit 1 + echo "git repo detected. Attempting tp update repo instead." + echo "Updating: $GIT_REPO" + git pull "$GIT_REPO" + fi fi fi From 9a9976bb1c71af7c4dea2f5c1fc5653d1771b8ad Mon Sep 17 00:00:00 2001 From: JSTayco Date: Wed, 29 Mar 2023 11:11:10 -0700 Subject: [PATCH 10/44] Move env_var_exist first usage after the function is defined --- setup.sh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/setup.sh b/setup.sh index 50ab23a..be8cb56 100755 --- a/setup.sh +++ b/setup.sh @@ -57,10 +57,6 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then root=false fi - if env_var_exists RUNPOD_POD_ID || env_var_exists RUNPOD_API_KEY; then - RUNPOD=true - fi - env_var_exists() { local env_var= env_var=$(declare -p "$1") @@ -116,6 +112,10 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then fi } + if env_var_exists RUNPOD_POD_ID || env_var_exists RUNPOD_API_KEY; then + RUNPOD=true + fi + # This is the pre-install work for a kohya installation on a runpod if [ "$RUNPOD" = true ]; then if [ -d "$VENV_DIR" ]; then From 403c2b051ce14a38b7996b5b229283c2b546817e Mon Sep 17 00:00:00 2001 From: JSTayco Date: Wed, 29 Mar 2023 11:42:36 -0700 Subject: [PATCH 11/44] More output and fixed package detection on Ubuntu Ubuntu dpkg detection fixed and more output for detected distros. --- setup.sh | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/setup.sh b/setup.sh index be8cb56..4bf0ca4 100755 --- a/setup.sh +++ b/setup.sh @@ -143,10 +143,12 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then distro=get_distro_name family=get_distro_family + + echo "Installing Python TK if not found on the system." + if "$distro" | grep -qi "Ubuntu" || "$family" | grep -qi "Ubuntu"; then echo "Ubuntu detected." - echo "Installing Python TK if not found on the system." - if [ ! $(dpkg-query -W -f='${Status}' python3-tk 2>/dev/null | grep -c "ok installed") -eq 0 ]; then + if [ $(dpkg-query -W -f='${Status}' python3-tk 2>/dev/null | grep -c "ok installed") = 0 ]; then if [ root = true ]; then apt update -y && apt install -y python3-tk else @@ -157,6 +159,7 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then echo "Python TK found! Skipping install!" fi elif "$distro" | grep -Eqi "Fedora|CentOS|Redhat"; then + echo "Redhat or Redhat base detected." if ! rpm -qa | grep -qi python3-tkinter; then if [ root = true ]; then dnf install python3-tkinter -y @@ -166,6 +169,7 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then fi fi elif "$distro" | grep -Eqi "arch" || "$family" | grep -qi "arch"; then + echo "Arch Linux or Arch base detected." if ! pacman -Qi tk >/dev/null; then if [ root = true ]; then pacman --noconfirm -S tk @@ -175,6 +179,7 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then fi fi elif "$distro" | grep -Eqi "opensuse" || "$family" | grep -qi "opensuse"; then + echo "OpenSUSE detected." if ! rpm -qa | grep -qi python-tk; then if [ root = true ]; then zypper install -y python-tk From a99ae6cff769ba5c500206210c02b64d863cf8ba Mon Sep 17 00:00:00 2001 From: JSTayco Date: Wed, 29 Mar 2023 12:07:00 -0700 Subject: [PATCH 12/44] Fixed root detection and made more robust --- setup.sh | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/setup.sh b/setup.sh index 4bf0ca4..6bdb21b 100755 --- a/setup.sh +++ b/setup.sh @@ -52,9 +52,13 @@ VENV_DIR="$DIR/venv" if [[ "$OSTYPE" == "linux-gnu"* ]]; then # Check if root or sudo - root=true - if [ "$EUID" -ne 0 ]; then - root=false + root=false + if [ "$EUID" = 0 ]; then + root=true + elif command -v id >/dev/null && [ "$(id -u)" = 0 ]; then + root=true + elif [ "$UID" = 0 ]; then + root=true fi env_var_exists() { @@ -149,7 +153,7 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then if "$distro" | grep -qi "Ubuntu" || "$family" | grep -qi "Ubuntu"; then echo "Ubuntu detected." if [ $(dpkg-query -W -f='${Status}' python3-tk 2>/dev/null | grep -c "ok installed") = 0 ]; then - if [ root = true ]; then + if [ "$root" = true ]; then apt update -y && apt install -y python3-tk else echo "This script needs to be run as root or via sudo to install packages." @@ -161,7 +165,7 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then elif "$distro" | grep -Eqi "Fedora|CentOS|Redhat"; then echo "Redhat or Redhat base detected." if ! rpm -qa | grep -qi python3-tkinter; then - if [ root = true ]; then + if [ "$root" = true ]; then dnf install python3-tkinter -y else echo "This script needs to be run as root or via sudo to install packages." @@ -171,7 +175,7 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then elif "$distro" | grep -Eqi "arch" || "$family" | grep -qi "arch"; then echo "Arch Linux or Arch base detected." if ! pacman -Qi tk >/dev/null; then - if [ root = true ]; then + if [ "$root" = true ]; then pacman --noconfirm -S tk else echo "This script needs to be run as root or via sudo to install packages." @@ -181,7 +185,7 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then elif "$distro" | grep -Eqi "opensuse" || "$family" | grep -qi "opensuse"; then echo "OpenSUSE detected." if ! rpm -qa | grep -qi python-tk; then - if [ root = true ]; then + if [ "$root" = true ]; then zypper install -y python-tk else echo "This script needs to be run as root or via sudo to install packages." From ee1ac64034322f0732087c4e28f521a563cd7ba4 Mon Sep 17 00:00:00 2001 From: JSTayco Date: Wed, 29 Mar 2023 12:23:43 -0700 Subject: [PATCH 13/44] Added setup.sh -h to the README Some environments can't run setup.sh interactively. This will be convenient for those users. --- README.md | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/README.md b/README.md index 0a866b2..1ffc7e2 100644 --- a/README.md +++ b/README.md @@ -68,6 +68,26 @@ cd kohya_ss ./setup.sh ``` +Setup.sh help included here: + +```bash +Kohya_SS Installation Script for POSIX operating systems. + +The following options are useful in a runpod environment, +but will not affect a local machine install. + +Usage: + setup.sh -b dev -d /workspace/kohya_ss -g https://mycustom.repo.tld/custom_fork.git + setup.sh --branch=dev --dir=/workspace/kohya_ss --git-repo=https://mycustom.repo.tld/custom_fork.git + +Options: + -b BRANCH, --branch=BRANCH Select which branch of kohya to checkout on new installs. + -d DIR, --dir=DIR The full path you want kohya_ss installed to. + -g, --git_repo You can optionally provide a git repo to checkout for runpod installation. Useful for custom forks. + -r, --runpod Forces a runpod installation. Useful if detection fails for any reason. + -h, --help Show this screen. +``` + During the accelerate config screen after running the script answer "This machine", "None", "No" for the remaining questions. These are the same answers as the Windows install. From 71c9459db23a02e210607d0070f50082dfdee909 Mon Sep 17 00:00:00 2001 From: JSTayco Date: Wed, 29 Mar 2023 13:26:16 -0700 Subject: [PATCH 14/44] Non-interactive mode, new warning, new default config for accel We now warn the user with a nicely formatted message if they have less than 10Gb space free and offer a 10s window to cancel operations. We now try to configure accelerate with no user input by default, but allow an override. --- README.md | 3 +- config_files/accelerate/default_config.yaml | 22 +++++++++ setup.sh | 53 +++++++++++++++++++-- 3 files changed, 74 insertions(+), 4 deletions(-) create mode 100644 config_files/accelerate/default_config.yaml diff --git a/README.md b/README.md index 1ffc7e2..9a04735 100644 --- a/README.md +++ b/README.md @@ -85,10 +85,11 @@ Options: -d DIR, --dir=DIR The full path you want kohya_ss installed to. -g, --git_repo You can optionally provide a git repo to checkout for runpod installation. Useful for custom forks. -r, --runpod Forces a runpod installation. Useful if detection fails for any reason. + -i, --interactive Interactively configure accelerate instead of using default config file. -h, --help Show this screen. ``` -During the accelerate config screen after running the script answer "This machine", "None", "No" for the remaining questions. +If you are using the interactive mode, our default values for the accelerate config screen after running the script answer "This machine", "None", "No" for the remaining questions. These are the same answers as the Windows install. ### Windows diff --git a/config_files/accelerate/default_config.yaml b/config_files/accelerate/default_config.yaml new file mode 100644 index 0000000..a31ddd0 --- /dev/null +++ b/config_files/accelerate/default_config.yaml @@ -0,0 +1,22 @@ +command_file: null +commands: null +compute_environment: LOCAL_MACHINE +deepspeed_config: {} +distributed_type: 'NO' +downcast_bf16: 'no' +dynamo_backend: 'NO' +fsdp_config: {} +gpu_ids: all +machine_rank: 0 +main_process_ip: null +main_process_port: null +main_training_function: main +megatron_lm_config: {} +mixed_precision: 'no' +num_machines: 1 +num_processes: 1 +rdzv_backend: static +same_network: true +tpu_name: null +tpu_zone: null +use_cpu: false diff --git a/setup.sh b/setup.sh index 6bdb21b..83a92d8 100755 --- a/setup.sh +++ b/setup.sh @@ -18,6 +18,7 @@ Options: -d DIR, --dir=DIR The full path you want kohya_ss installed to. -g, --git_repo You can optionally provide a git repo to checkout for runpod installation. Useful for custom forks. -r, --runpod Forces a runpod installation. Useful if detection fails for any reason. + -i, --interactive Interactively configure accelerate instead of using default config file. -h, --help Show this screen. EOF } @@ -27,8 +28,9 @@ DIR="/workspace/kohya_ss" BRANCH="master" GIT_REPO="https://github.com/bmaltais/kohya_ss.git" RUNPOD=false +INTERACTIVE=false -while getopts "b:d:g:r-:" opt; do +while getopts "b:d:g:ir-:" opt; do # support long options: https://stackoverflow.com/a/28466267/519360 if [ "$opt" = "-" ]; then # long option: reformulate OPT and OPTARG opt="${OPTARG%%=*}" # extract long option name @@ -39,6 +41,7 @@ while getopts "b:d:g:r-:" opt; do b | branch) BRANCH="$OPTARG" ;; d | dir) DIR="$OPTARG" ;; g | git-repo) GIT_REPO="$OPTARG" ;; + i | interactive) INTERACTIVE=true ;; r | runpod) RUNPOD=true ;; h) display_help && exit 0 ;; *) display_help && exit 0 ;; @@ -116,10 +119,29 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then fi } + # This checks for free space on the installation drive and returns that in Gb. + size_available() { + local FREESPACEINKB="$(df -Pk "$DIR" | sed 1d | grep -v used | awk '{ print $4 "\t" }')" + local FREESPACEINGB=$((FREESPACEINKB / 1024 / 1024)) + echo "$FREESPACEINGB" + } + if env_var_exists RUNPOD_POD_ID || env_var_exists RUNPOD_API_KEY; then RUNPOD=true fi + # Offer a warning and opportunity to cancel the installation if < 10Gb of Free Space detected + if [ "$(size_available)" -lt 10 ]; then + echo "You have less than 10Gb of free space. This installation may fail." + MSGTIMEOUT=10 # In seconds + MESSAGE="Continuing in..." + echo "Press control-c to cancel the installation." + for ((i = $MSGTIMEOUT; i >= 0; i--)); do + printf "\r${MESSAGE} %ss. " "${i}" + sleep 1 + done + fi + # This is the pre-install work for a kohya installation on a runpod if [ "$RUNPOD" = true ]; then if [ -d "$VENV_DIR" ]; then @@ -221,6 +243,33 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$VENV_DIR/lib/python3.10/site-packages/tensorrt/" export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$VENV_DIR/lib/python3.10/site-packages/nvidia/cuda_runtime/lib/" + # Attempt to non-interactively install a default accelerate config file unless specified otherwise. + # Documentation for order of precedence locations for configuration file for automated installation: + # https://huggingface.co/docs/accelerate/basic_tutorials/launch#custom-configurations + if [ "$INTERACTIVE" = true ]; then + accelerate config + else + if env_var_exists HF_HOME; then + if [ ! -f "$HF_HOME/accelerate/default_config.yaml" ]; then + mkdir -p "$HF_HOME/accelerate/" && + cp ./config_files/accelerate/default_config.yaml "$HF_HOME/accelerate/default_config.yaml" + fi + elif env_var_exists XDG_CACHE_HOME; then + if [ ! -f "$XDG_CACHE_HOME/huggingface/accelerate" ]; then + mkdir -p "$XDG_CACHE_HOME/huggingface/accelerate" && + cp ./config_files/accelerate/default_config.yaml "$XDG_CACHE_HOME/huggingface/accelerate/default_config.yaml" + fi + elif env_var_exists HOME; then + if [ ! -f "$HOME/.cache/huggingface/accelerate" ]; then + mkdir -p "$HOME/.cache/huggingface/accelerate" && + cp ./config_files/accelerate/default_config.yaml "$HOME/.cache/huggingface/accelerate/default_config.yaml" + fi + else + echo "Could not place the accelerate configuration file. Please configure manually." + accelerate config + fi + fi + # This is a non-interactive environment, so just directly call gui.sh after all setup steps are complete. if command -v bash >/dev/null; then bash "$DIR"/gui.sh @@ -230,8 +279,6 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then fi fi - accelerate config - echo -e "Setup finished! Run \e[0;92m./gui.sh\e[0m to start." elif [[ "$OSTYPE" == "darwin"* ]]; then # The initial setup script to prep the environment on macOS From 094528c7cd449ef450053ccdce885a5770462bfe Mon Sep 17 00:00:00 2001 From: JSTayco Date: Wed, 29 Mar 2023 13:35:10 -0700 Subject: [PATCH 15/44] Inform user of config copy Better inform the user what is happening. Upon successful configuration file copy operation, we notify the user where that config is located. Also added small sleep step before calling accelerate config to give the user a chance to read the message. --- setup.sh | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/setup.sh b/setup.sh index 83a92d8..564000b 100755 --- a/setup.sh +++ b/setup.sh @@ -252,20 +252,24 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then if env_var_exists HF_HOME; then if [ ! -f "$HF_HOME/accelerate/default_config.yaml" ]; then mkdir -p "$HF_HOME/accelerate/" && - cp ./config_files/accelerate/default_config.yaml "$HF_HOME/accelerate/default_config.yaml" + cp ./config_files/accelerate/default_config.yaml "$HF_HOME/accelerate/default_config.yaml" && + echo "Copied accelerate config file to: $HF_HOME/accelerate/default_config.yaml" fi elif env_var_exists XDG_CACHE_HOME; then if [ ! -f "$XDG_CACHE_HOME/huggingface/accelerate" ]; then mkdir -p "$XDG_CACHE_HOME/huggingface/accelerate" && - cp ./config_files/accelerate/default_config.yaml "$XDG_CACHE_HOME/huggingface/accelerate/default_config.yaml" + cp ./config_files/accelerate/default_config.yaml "$XDG_CACHE_HOME/huggingface/accelerate/default_config.yaml" && + echo "Copied accelerate config file to: $XDG_CACHE_HOME/huggingface/accelerate/default_config.yaml" fi elif env_var_exists HOME; then if [ ! -f "$HOME/.cache/huggingface/accelerate" ]; then mkdir -p "$HOME/.cache/huggingface/accelerate" && - cp ./config_files/accelerate/default_config.yaml "$HOME/.cache/huggingface/accelerate/default_config.yaml" + cp ./config_files/accelerate/default_config.yaml "$HOME/.cache/huggingface/accelerate/default_config.yaml" && + echo "Copying accelerate config file to: $HOME/.cache/huggingface/accelerate/default_config.yaml" fi else echo "Could not place the accelerate configuration file. Please configure manually." + sleep 2 accelerate config fi fi From a58b3b616aa382f8c0963d36321b86c418d36502 Mon Sep 17 00:00:00 2001 From: JSTayco Date: Wed, 29 Mar 2023 14:41:17 -0700 Subject: [PATCH 16/44] Add -p switch to expose public URL This adds the ability for an extra switch to expose a public URL. We default to private. --- README.md | 1 + setup.sh | 16 +++++++++++++--- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 9a04735..a81de9b 100644 --- a/README.md +++ b/README.md @@ -84,6 +84,7 @@ Options: -b BRANCH, --branch=BRANCH Select which branch of kohya to checkout on new installs. -d DIR, --dir=DIR The full path you want kohya_ss installed to. -g, --git_repo You can optionally provide a git repo to checkout for runpod installation. Useful for custom forks. + -p, --public Expose public URL in runpod mode. Won't have an effect in other modes. -r, --runpod Forces a runpod installation. Useful if detection fails for any reason. -i, --interactive Interactively configure accelerate instead of using default config file. -h, --help Show this screen. diff --git a/setup.sh b/setup.sh index 564000b..ffa9236 100755 --- a/setup.sh +++ b/setup.sh @@ -17,6 +17,7 @@ Options: -b BRANCH, --branch=BRANCH Select which branch of kohya to checkout on new installs. -d DIR, --dir=DIR The full path you want kohya_ss installed to. -g, --git_repo You can optionally provide a git repo to checkout for runpod installation. Useful for custom forks. + -p, --public Expose public URL in runpod mode. Won't have an effect in other modes. -r, --runpod Forces a runpod installation. Useful if detection fails for any reason. -i, --interactive Interactively configure accelerate instead of using default config file. -h, --help Show this screen. @@ -29,6 +30,7 @@ BRANCH="master" GIT_REPO="https://github.com/bmaltais/kohya_ss.git" RUNPOD=false INTERACTIVE=false +PUBLIC=false while getopts "b:d:g:ir-:" opt; do # support long options: https://stackoverflow.com/a/28466267/519360 @@ -42,6 +44,7 @@ while getopts "b:d:g:ir-:" opt; do d | dir) DIR="$OPTARG" ;; g | git-repo) GIT_REPO="$OPTARG" ;; i | interactive) INTERACTIVE=true ;; + p | public) PUBLIC=true ;; r | runpod) RUNPOD=true ;; h) display_help && exit 0 ;; *) display_help && exit 0 ;; @@ -149,7 +152,6 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:"$VENV_DIR"/lib/python3.10/site-packages/tensorrt/ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:"$VENV_DIR"/lib/python3.10/site-packages/nvidia/cuda_runtime/lib/ cd "$DIR" || exit 1 - sed -i "s/interface.launch(\*\*launch_kwargs)/interface.launch(\*\*launch_kwargs,share=True)/g" ./kohya_gui.py else echo "Clean installation on a runpod detected." cd "$PARENT_DIR" || exit 1 @@ -276,10 +278,18 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then # This is a non-interactive environment, so just directly call gui.sh after all setup steps are complete. if command -v bash >/dev/null; then - bash "$DIR"/gui.sh + if [ "$PUBLIC" = false ]; then + bash "$DIR"/gui.sh + else + bash "$DIR"/gui.sh --share + fi else # This shouldn't happen, but we're going to try to help. - sh "$DIR"/gui.sh + if [ "$PUBLIC" = false ]; then + sh "$DIR"/gui.sh + else + sh "$DIR"/gui.sh --share + fi fi fi From 1f67aaa43b2864d9352a6d1283765e8ee46d2101 Mon Sep 17 00:00:00 2001 From: JSTayco Date: Wed, 29 Mar 2023 14:47:10 -0700 Subject: [PATCH 17/44] More robust df check --- setup.sh | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/setup.sh b/setup.sh index ffa9236..8f5344d 100755 --- a/setup.sh +++ b/setup.sh @@ -124,7 +124,19 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then # This checks for free space on the installation drive and returns that in Gb. size_available() { - local FREESPACEINKB="$(df -Pk "$DIR" | sed 1d | grep -v used | awk '{ print $4 "\t" }')" + local folder + if [ -d "$DIR" ]; then + folder="$DIR" + elif [ -d "$PARENT_DIR" ]; then + folder="$PARENT_DIR" + elif [ -d "$(echo "$DIR" | cut -d "/" -f2)" ]; then + folder="$(echo "$DIR" | cut -d "/" -f2)" + else + echo "We are assuming a root drive install for space-checking purposes." + folder='/' + fi + + local FREESPACEINKB="$(df -Pk "$folder" | sed 1d | grep -v used | awk '{ print $4 "\t" }')" local FREESPACEINGB=$((FREESPACEINKB / 1024 / 1024)) echo "$FREESPACEINGB" } From c490764c7ee56be2c5e502f21db5cb1f4acfcc11 Mon Sep 17 00:00:00 2001 From: JSTayco Date: Wed, 29 Mar 2023 15:07:04 -0700 Subject: [PATCH 18/44] Simplify environment variable check --- setup.sh | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/setup.sh b/setup.sh index 8f5344d..83bb992 100755 --- a/setup.sh +++ b/setup.sh @@ -67,11 +67,12 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then root=true fi + # Checks to see if variable is set and non-empty. env_var_exists() { - local env_var= - env_var=$(declare -p "$1") - if ! [[ -v $1 && $env_var =~ ^declare\ -x ]]; then + if [[ ! -v "$1" ]] || [[ -z "$1" ]]; then return 1 + else + return 0 fi } From dfe96a581d78196e06fbbbcad3850a083ad4b4f1 Mon Sep 17 00:00:00 2001 From: JSTayco Date: Wed, 29 Mar 2023 15:09:42 -0700 Subject: [PATCH 19/44] Remove hard-coded config file path --- setup.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/setup.sh b/setup.sh index 83bb992..152eb67 100755 --- a/setup.sh +++ b/setup.sh @@ -267,19 +267,19 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then if env_var_exists HF_HOME; then if [ ! -f "$HF_HOME/accelerate/default_config.yaml" ]; then mkdir -p "$HF_HOME/accelerate/" && - cp ./config_files/accelerate/default_config.yaml "$HF_HOME/accelerate/default_config.yaml" && + cp "$DIR/config_files/accelerate/default_config.yaml" "$HF_HOME/accelerate/default_config.yaml" && echo "Copied accelerate config file to: $HF_HOME/accelerate/default_config.yaml" fi elif env_var_exists XDG_CACHE_HOME; then if [ ! -f "$XDG_CACHE_HOME/huggingface/accelerate" ]; then mkdir -p "$XDG_CACHE_HOME/huggingface/accelerate" && - cp ./config_files/accelerate/default_config.yaml "$XDG_CACHE_HOME/huggingface/accelerate/default_config.yaml" && + cp "$DIR/config_files/accelerate/default_config.yaml" "$XDG_CACHE_HOME/huggingface/accelerate/default_config.yaml" && echo "Copied accelerate config file to: $XDG_CACHE_HOME/huggingface/accelerate/default_config.yaml" fi elif env_var_exists HOME; then if [ ! -f "$HOME/.cache/huggingface/accelerate" ]; then mkdir -p "$HOME/.cache/huggingface/accelerate" && - cp ./config_files/accelerate/default_config.yaml "$HOME/.cache/huggingface/accelerate/default_config.yaml" && + cp "$DIR/config_files/accelerate/default_config.yaml" "$HOME/.cache/huggingface/accelerate/default_config.yaml" && echo "Copying accelerate config file to: $HOME/.cache/huggingface/accelerate/default_config.yaml" fi else From 38aab1c30a7038cb38d16f1e1643777fe733f463 Mon Sep 17 00:00:00 2001 From: jstayco <127801635+jstayco@users.noreply.github.com> Date: Wed, 29 Mar 2023 15:30:41 -0700 Subject: [PATCH 20/44] Added forgotten getopt letter --- setup.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.sh b/setup.sh index 152eb67..6cf0a5b 100755 --- a/setup.sh +++ b/setup.sh @@ -32,7 +32,7 @@ RUNPOD=false INTERACTIVE=false PUBLIC=false -while getopts "b:d:g:ir-:" opt; do +while getopts "b:d:g:ipr-:" opt; do # support long options: https://stackoverflow.com/a/28466267/519360 if [ "$opt" = "-" ]; then # long option: reformulate OPT and OPTARG opt="${OPTARG%%=*}" # extract long option name From bff107878b517a220c8eb5b126fa27fdeb2d71fc Mon Sep 17 00:00:00 2001 From: JSTayco Date: Wed, 29 Mar 2023 16:05:57 -0700 Subject: [PATCH 21/44] Default install location is now environment-based Better default install locations and updated the README to reflect. --- README.md | 8 ++++++++ setup.sh | 54 ++++++++++++++++++++++++++++++++++++++---------------- 2 files changed, 46 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index a81de9b..dca0b1d 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,7 @@ If you run on Linux and would like to use the GUI, there is now a port of it as - [Linux/macOS](#linux-and-macos-dependencies) - [Installation](#installation) - [Linux/macOS](#linux-and-macos) + - [Default Install Locations](#install-location) - [Windows](#windows) - [CUDNN 8.6](#optional--cudnn-86) - [Upgrading](#upgrading) @@ -90,6 +91,13 @@ Options: -h, --help Show this screen. ``` +#### Install location + +The default install location for Linux is `/opt/kohya_ss`. If /opt is not writeable, the fallback is `$HOME/kohya_ss`. Lastly, if all else fails it will simply install to the current folder you are in. + +On macOS and other non-Linux machines, it will default install to `$HOME/kohya_ss` followed by where you're currently at if there's no access to $HOME. +You can override this behavior by specifying an install directory with the -d option. + If you are using the interactive mode, our default values for the accelerate config screen after running the script answer "This machine", "None", "No" for the remaining questions. These are the same answers as the Windows install. diff --git a/setup.sh b/setup.sh index 152eb67..112d942 100755 --- a/setup.sh +++ b/setup.sh @@ -24,15 +24,50 @@ Options: EOF } +# Checks to see if variable is set and non-empty. +# This is defined first, so we can use the function for some default variable values +env_var_exists() { + if [[ ! -v "$1" ]] || [[ -z "$1" ]]; then + return 1 + else + return 0 + fi +} + +# Need RUNPOD to have a default value before first access +RUNPOD=false +if env_var_exists RUNPOD_POD_ID || env_var_exists RUNPOD_API_KEY; then + RUNPOD=true +fi + # Variables defined before the getopts loop, so we have sane default values. -DIR="/workspace/kohya_ss" +# Default installation locations based on OS and environment +if [[ "$OSTYPE" == "linux-gnu"* ]]; then + if [ "$RUNPOD" = true ]; then + DIR="/workspace/kohya_ss" + elif [ -w "/opt" ]; then + DIR="/opt/kohya_ss" + elif env_var_exists HOME; then + DIR="$HOME/kohya_ss" + else + # The last fallback is simply PWD + DIR="$(PWD)" + fi +else + if env_var_exists HOME; then + DIR="$HOME/kohya_ss" + else + # The last fallback is simply PWD + DIR="$(PWD)" + fi +fi + BRANCH="master" GIT_REPO="https://github.com/bmaltais/kohya_ss.git" -RUNPOD=false INTERACTIVE=false PUBLIC=false -while getopts "b:d:g:ir-:" opt; do +while getopts "b:d:g:ipr-:" opt; do # support long options: https://stackoverflow.com/a/28466267/519360 if [ "$opt" = "-" ]; then # long option: reformulate OPT and OPTARG opt="${OPTARG%%=*}" # extract long option name @@ -67,15 +102,6 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then root=true fi - # Checks to see if variable is set and non-empty. - env_var_exists() { - if [[ ! -v "$1" ]] || [[ -z "$1" ]]; then - return 1 - else - return 0 - fi - } - get_distro_name() { local line if [ -f /etc/os-release ]; then @@ -142,10 +168,6 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then echo "$FREESPACEINGB" } - if env_var_exists RUNPOD_POD_ID || env_var_exists RUNPOD_API_KEY; then - RUNPOD=true - fi - # Offer a warning and opportunity to cancel the installation if < 10Gb of Free Space detected if [ "$(size_available)" -lt 10 ]; then echo "You have less than 10Gb of free space. This installation may fail." From 9f6e0c1c8f6af6854987a5caa94bb4eb00b3817f Mon Sep 17 00:00:00 2001 From: bmaltais Date: Thu, 30 Mar 2023 07:23:37 -0400 Subject: [PATCH 22/44] Fix issue with LyCORIS version --- README.md | 2 ++ requirements.txt | 3 ++- tools/validate_requirements.py | 14 +++++++++++++- 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 15030b7..c6e89c7 100644 --- a/README.md +++ b/README.md @@ -213,6 +213,8 @@ This will store your a backup file with your current locally installed pip packa ## Change History +* 2023/03/30 (v21.3.8) + - Fix issue with LyCORIS version not being found: https://github.com/bmaltais/kohya_ss/issues/481 * 2023/03/29 (v21.3.7) - Allow for 0.1 increment in Network and Conv alpha values: https://github.com/bmaltais/kohya_ss/pull/471 Thanks to @srndpty - Updated Lycoris module version diff --git a/requirements.txt b/requirements.txt index f8a1e88..5881d6e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,6 +25,7 @@ timm==0.6.12 huggingface-hub==0.13.0 tensorflow==2.10.1 # For locon support -lycoris_lora==0.1.4 +lycoris-lora @ git+https://github.com/KohakuBlueleaf/LyCORIS.git@c3d925421209a22a60d863ffa3de0b3e7e89f047 +# lycoris_lora==0.1.4 # for kohya_ss library . \ No newline at end of file diff --git a/tools/validate_requirements.py b/tools/validate_requirements.py index 948c21f..86f09d5 100644 --- a/tools/validate_requirements.py +++ b/tools/validate_requirements.py @@ -25,7 +25,19 @@ for requirement in requirements: try: pkg_resources.require(requirement) except pkg_resources.DistributionNotFound: - missing_requirements.append(requirement) + # Check if the requirement contains a VCS URL + if "@" in requirement: + # If it does, split the requirement into two parts: the package name and the VCS URL + package_name, vcs_url = requirement.split("@", 1) + # Use pip to install the package from the VCS URL + os.system(f"pip install -e {vcs_url}") + # Try to require the package again + try: + pkg_resources.require(package_name) + except pkg_resources.DistributionNotFound: + missing_requirements.append(requirement) + else: + missing_requirements.append(requirement) except pkg_resources.VersionConflict as e: wrong_version_requirements.append((requirement, str(e.req), e.dist.version)) From 90e34cc6f72ffc39ebea23b5624dea4ccecffe77 Mon Sep 17 00:00:00 2001 From: jstayco <127801635+jstayco@users.noreply.github.com> Date: Thu, 30 Mar 2023 16:20:13 -0700 Subject: [PATCH 23/44] Small improvement to an stdout line. Just a typo fix and added a word in an echo statement. --- setup.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.sh b/setup.sh index 112d942..c24d48b 100755 --- a/setup.sh +++ b/setup.sh @@ -197,7 +197,7 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then git checkout "$BRANCH" else cd "$DIR" || exit 1 - echo "git repo detected. Attempting tp update repo instead." + echo "git repo detected. Attempting to update repository instead." echo "Updating: $GIT_REPO" git pull "$GIT_REPO" fi From 6e7e25cba927d97ef599a48a7de567f3d5687e6c Mon Sep 17 00:00:00 2001 From: JSTayco Date: Thu, 30 Mar 2023 18:20:42 -0700 Subject: [PATCH 24/44] Big update: Verbosity, space check skip, Huge reorg to account for macOS and other non-Linux OSs. Verbosity levels 1-3 added (used by FDs 3-5). Simplified code by creating more shared functions. --- setup.sh | 291 ++++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 201 insertions(+), 90 deletions(-) diff --git a/setup.sh b/setup.sh index 112d942..6a5c392 100755 --- a/setup.sh +++ b/setup.sh @@ -6,21 +6,26 @@ display_help() { cat <&2" #Don't change anything higher than the maximum verbosity allowed. +done + +for v in $( #From the verbosity level one higher than requested, through the maximum; + seq $((VERBOSITY + 1)) $MAXVERBOSITY +); do + (("$v" > "2")) && eval exec "$v>/dev/null" #Redirect these to bitbucket, provided that they don't match stdout and stderr. +done + +# Example of how to use the verbosity levels. +# printf "%s\n" "This message is seen at verbosity level 1 and above." >&3 +# printf "%s\n" "This message is seen at verbosity level 2 and above." >&4 +# printf "%s\n" "This message is seen at verbosity level 3 and above." >&5 + +# Debug variable dump at max verbosity +echo "BRANCH: $BRANCH +DIR: $DIR +GIT_REPO: $GIT_REPO +INTERACTIVE: $INTERACTIVE +PUBLIC: $PUBLIC +RUNPOD: $RUNPOD +SKIP_SPACE_CHECK: $SKIP_SPACE_CHECK +VERBOSITY: $VERBOSITY" >&5 + # This must be set after the getopts loop to account for $DIR changes. PARENT_DIR="$(dirname "${DIR}")" VENV_DIR="$DIR/venv" +# Shared functions +# This checks for free space on the installation drive and returns that in Gb. +size_available() { + local folder + if [ -d "$DIR" ]; then + folder="$DIR" + elif [ -d "$PARENT_DIR" ]; then + folder="$PARENT_DIR" + elif [ -d "$(echo "$DIR" | cut -d "/" -f2)" ]; then + folder="$(echo "$DIR" | cut -d "/" -f2)" + else + echo "We are assuming a root drive install for space-checking purposes." + folder='/' + fi + + local FREESPACEINKB="$(df -Pk "$folder" | sed 1d | grep -v used | awk '{ print $4 "\t" }')" + echo "Detected available space in Kb: $FREESPACEINKB" >&5 + local FREESPACEINGB=$((FREESPACEINKB / 1024 / 1024)) + echo "$FREESPACEINGB" +} + +# The expected usage is create_symlinks $symlink $target_file +create_symlinks() { + echo "Checking symlinks now." + # Next line checks for valid symlink + if [ -L "$1" ]; then + # Check if the linked file exists and points to the expected file + if [ -e "$1" ] && [ "$(readlink "$1")" == "$2" ]; then + echo "$(basename "$1") symlink looks fine. Skipping." + else + if [ -f "$2" ]; then + echo "Broken symlink detected. Recreating $(basename "$1")." + rm "$1" && + ln -s "$2" "$1" + else + echo "$2 does not exist. Nothing to link." + fi + fi + else + echo "Linking $(basename "$1")." + ln -s "$2" "$1" + fi +} + +# Attempt to non-interactively install a default accelerate config file unless specified otherwise. +# Documentation for order of precedence locations for configuration file for automated installation: +# https://huggingface.co/docs/accelerate/basic_tutorials/launch#custom-configurations +configure_accelerate() { + echo "Source accelerate config location: $DIR/config_files/accelerate/default_config.yaml" >&3 + if [ "$INTERACTIVE" = true ]; then + accelerate config + else + if env_var_exists HF_HOME; then + if [ ! -f "$HF_HOME/accelerate/default_config.yaml" ]; then + mkdir -p "$HF_HOME/accelerate/" && + echo "Target accelerate config location: $HF_HOME/accelerate/default_config.yaml" >&3 + cp "$DIR/config_files/accelerate/default_config.yaml" "$HF_HOME/accelerate/default_config.yaml" && + echo "Copied accelerate config file to: $HF_HOME/accelerate/default_config.yaml" + fi + elif env_var_exists XDG_CACHE_HOME; then + if [ ! -f "$XDG_CACHE_HOME/huggingface/accelerate" ]; then + mkdir -p "$XDG_CACHE_HOME/huggingface/accelerate" && + echo "Target accelerate config location: $XDG_CACHE_HOME/accelerate/default_config.yaml" >&3 + cp "$DIR/config_files/accelerate/default_config.yaml" "$XDG_CACHE_HOME/huggingface/accelerate/default_config.yaml" && + echo "Copied accelerate config file to: $XDG_CACHE_HOME/huggingface/accelerate/default_config.yaml" + fi + elif env_var_exists HOME; then + if [ ! -f "$HOME/.cache/huggingface/accelerate" ]; then + mkdir -p "$HOME/.cache/huggingface/accelerate" && + echo "Target accelerate config location: $HOME/accelerate/default_config.yaml" >&3 + cp "$DIR/config_files/accelerate/default_config.yaml" "$HOME/.cache/huggingface/accelerate/default_config.yaml" && + echo "Copying accelerate config file to: $HOME/.cache/huggingface/accelerate/default_config.yaml" + fi + else + echo "Could not place the accelerate configuration file. Please configure manually." + sleep 2 + accelerate config + fi + fi +} + +# Offer a warning and opportunity to cancel the installation if < 10Gb of Free Space detected +check_storage_space() { + if [ "$SKIP_SPACE_CHECK" = false ]; then + if [ "$(size_available)" -lt 10 ]; then + echo "You have less than 10Gb of free space. This installation may fail." + MSGTIMEOUT=10 # In seconds + MESSAGE="Continuing in..." + echo "Press control-c to cancel the installation." + for ((i = $MSGTIMEOUT; i >= 0; i--)); do + printf "\r${MESSAGE} %ss. " "${i}" + sleep 1 + done + fi + fi +} + +# Start OS-specific detection and work if [[ "$OSTYPE" == "linux-gnu"* ]]; then # Check if root or sudo root=false @@ -108,6 +244,7 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then # We search for the line starting with ID= # Then we remove the ID= prefix to get the name itself line="$(grep -Ei '^ID=' /etc/os-release)" + echo "Raw detected os-release distro line: $line" >&5 line=${line##*=} echo "$line" return 0 @@ -134,6 +271,7 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then # This is the "type" of distro. For example, Ubuntu returns "debian". if grep -Eiq '^ID_LIKE=' /etc/os-release >/dev/null; then line="$(grep -Ei '^ID_LIKE=' /etc/os-release)" + echo "Raw detected os-release distro family line: $line" >&5 line=${line##*=} echo "$line" return 0 @@ -149,36 +287,7 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then fi } - # This checks for free space on the installation drive and returns that in Gb. - size_available() { - local folder - if [ -d "$DIR" ]; then - folder="$DIR" - elif [ -d "$PARENT_DIR" ]; then - folder="$PARENT_DIR" - elif [ -d "$(echo "$DIR" | cut -d "/" -f2)" ]; then - folder="$(echo "$DIR" | cut -d "/" -f2)" - else - echo "We are assuming a root drive install for space-checking purposes." - folder='/' - fi - - local FREESPACEINKB="$(df -Pk "$folder" | sed 1d | grep -v used | awk '{ print $4 "\t" }')" - local FREESPACEINGB=$((FREESPACEINKB / 1024 / 1024)) - echo "$FREESPACEINGB" - } - - # Offer a warning and opportunity to cancel the installation if < 10Gb of Free Space detected - if [ "$(size_available)" -lt 10 ]; then - echo "You have less than 10Gb of free space. This installation may fail." - MSGTIMEOUT=10 # In seconds - MESSAGE="Continuing in..." - echo "Press control-c to cancel the installation." - for ((i = $MSGTIMEOUT; i >= 0; i--)); do - printf "\r${MESSAGE} %ss. " "${i}" - sleep 1 - done - fi + check_storage_space # This is the pre-install work for a kohya installation on a runpod if [ "$RUNPOD" = true ]; then @@ -192,28 +301,29 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then cd "$PARENT_DIR" || exit 1 if [ ! -d "$DIR/.git" ]; then echo "Cloning $GIT_REPO." - git clone "$GIT_REPO" + git clone "$GIT_REPO" >&3 cd "$DIR" || exit 1 - git checkout "$BRANCH" + git checkout "$BRANCH" >&3 else cd "$DIR" || exit 1 echo "git repo detected. Attempting tp update repo instead." echo "Updating: $GIT_REPO" - git pull "$GIT_REPO" + git pull "$GIT_REPO" >&3 fi fi fi distro=get_distro_name family=get_distro_family + echo "Raw detected distro string: $distro" >&4 + echo "Raw detected distro family string: $family" >&4 echo "Installing Python TK if not found on the system." - if "$distro" | grep -qi "Ubuntu" || "$family" | grep -qi "Ubuntu"; then echo "Ubuntu detected." if [ $(dpkg-query -W -f='${Status}' python3-tk 2>/dev/null | grep -c "ok installed") = 0 ]; then if [ "$root" = true ]; then - apt update -y && apt install -y python3-tk + apt update -y >&3 && apt install -y python3-tk >&3 else echo "This script needs to be run as root or via sudo to install packages." exit 1 @@ -225,7 +335,7 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then echo "Redhat or Redhat base detected." if ! rpm -qa | grep -qi python3-tkinter; then if [ "$root" = true ]; then - dnf install python3-tkinter -y + dnf install python3-tkinter -y >&3 else echo "This script needs to be run as root or via sudo to install packages." exit 1 @@ -235,7 +345,7 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then echo "Arch Linux or Arch base detected." if ! pacman -Qi tk >/dev/null; then if [ "$root" = true ]; then - pacman --noconfirm -S tk + pacman --noconfirm -S tk >&3 else echo "This script needs to be run as root or via sudo to install packages." exit 1 @@ -245,7 +355,7 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then echo "OpenSUSE detected." if ! rpm -qa | grep -qi python-tk; then if [ "$root" = true ]; then - zypper install -y python-tk + zypper install -y python-tk >&3 else echo "This script needs to be run as root or via sudo to install packages." exit 1 @@ -263,53 +373,40 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then python3 -m venv venv source venv/bin/activate - pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 - pip install --use-pep517 --upgrade -r requirements.txt - pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/linux/xformers-0.0.14.dev0-cp310-cp310-linux_x86_64.whl + + # Updating pip if there is one + echo "Checking for pip updates before Python operations." + python3 -m pip install --upgrade pip >&3 + + echo "Installing python dependencies. This could take a few minutes as it downloads files." + echo "If this operation ever runs too long, you can rerun this script in verbose mode to check." + pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 >&3 + pip install --use-pep517 --upgrade -r requirements.txt >&3 + pip install -U -I --no-deps \ + https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/linux/xformers-0.0.14.dev0-cp310-cp310-linux_x86_64.whl >&3 # We need this extra package and setup if we are running in a runpod if [ "$RUNPOD" = true ]; then pip install tensorrt - ln -s "$VENV_DIR/lib/python3.10/site-packages/tensorrt/libnvinfer_plugin.so.8" \ - "$VENV_DIR/lib/python3.10/site-packages/tensorrt/libnvinfer_plugin.so.7" - ln -s "$VENV_DIR/lib/python3.10/site-packages/tensorrt/libnvinfer.so.8" \ - "$VENV_DIR/lib/python3.10/site-packages/tensorrt/libnvinfer.so.7" - ln -s "$VENV_DIR/lib/python3.10/site-packages/nvidia/cuda_runtime/lib/libcudart.so.12" \ - "$VENV_DIR/lib/python3.10/site-packages/nvidia/cuda_runtime/lib/libcudart.so.11.0" + # Symlink paths + libnvinfer_plugin_symlink="$VENV_DIR/lib/python3.10/site-packages/tensorrt/libnvinfer_plugin.so.7" + libnvinfer_symlink="$VENV_DIR/lib/python3.10/site-packages/tensorrt/libnvinfer.so.7" + libcudart_symlink="$VENV_DIR/lib/python3.10/site-packages/nvidia/cuda_runtime/lib/libcudart.so.11.0" + + #Target file paths + libnvinfer_plugin_target="$VENV_DIR/lib/python3.10/site-packages/tensorrt/libnvinfer_plugin.so.8" + libnvinfer_target="$VENV_DIR/lib/python3.10/site-packages/tensorrt/libnvinfer.so.8" + libcudart_target="$VENV_DIR/lib/python3.10/site-packages/nvidia/cuda_runtime/lib/libcudart.so.12" + + echo "Checking symlinks now." + create_symlinks "$libnvinfer_plugin_symlink" "$libnvinfer_plugin_target" + create_symlinks "$libnvinfer_symlink" "$libnvinfer_target" + create_symlinks "$libcudart_symlink" "$libcudart_target" export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$VENV_DIR/lib/python3.10/site-packages/tensorrt/" export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$VENV_DIR/lib/python3.10/site-packages/nvidia/cuda_runtime/lib/" - # Attempt to non-interactively install a default accelerate config file unless specified otherwise. - # Documentation for order of precedence locations for configuration file for automated installation: - # https://huggingface.co/docs/accelerate/basic_tutorials/launch#custom-configurations - if [ "$INTERACTIVE" = true ]; then - accelerate config - else - if env_var_exists HF_HOME; then - if [ ! -f "$HF_HOME/accelerate/default_config.yaml" ]; then - mkdir -p "$HF_HOME/accelerate/" && - cp "$DIR/config_files/accelerate/default_config.yaml" "$HF_HOME/accelerate/default_config.yaml" && - echo "Copied accelerate config file to: $HF_HOME/accelerate/default_config.yaml" - fi - elif env_var_exists XDG_CACHE_HOME; then - if [ ! -f "$XDG_CACHE_HOME/huggingface/accelerate" ]; then - mkdir -p "$XDG_CACHE_HOME/huggingface/accelerate" && - cp "$DIR/config_files/accelerate/default_config.yaml" "$XDG_CACHE_HOME/huggingface/accelerate/default_config.yaml" && - echo "Copied accelerate config file to: $XDG_CACHE_HOME/huggingface/accelerate/default_config.yaml" - fi - elif env_var_exists HOME; then - if [ ! -f "$HOME/.cache/huggingface/accelerate" ]; then - mkdir -p "$HOME/.cache/huggingface/accelerate" && - cp "$DIR/config_files/accelerate/default_config.yaml" "$HOME/.cache/huggingface/accelerate/default_config.yaml" && - echo "Copying accelerate config file to: $HOME/.cache/huggingface/accelerate/default_config.yaml" - fi - else - echo "Could not place the accelerate configuration file. Please configure manually." - sleep 2 - accelerate config - fi - fi + configure_accelerate # This is a non-interactive environment, so just directly call gui.sh after all setup steps are complete. if command -v bash >/dev/null; then @@ -329,6 +426,7 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then fi echo -e "Setup finished! Run \e[0;92m./gui.sh\e[0m to start." + echo "Please note if you'd like to expose your public server you need to run ./gui.sh --share" elif [[ "$OSTYPE" == "darwin"* ]]; then # The initial setup script to prep the environment on macOS # xformers has been omitted as that is for Nvidia GPUs only @@ -341,16 +439,20 @@ elif [[ "$OSTYPE" == "darwin"* ]]; then exit 1 fi + check_storage_space + # Install base python packages echo "Installing Python 3.10 if not found." if ! brew ls --versions python@3.10 >/dev/null; then - brew install python@3.10 + echo "Installing Python 3.10." + brew install python@3.10 >&3 else echo "Python 3.10 found!" fi echo "Installing Python-TK 3.10 if not found." if ! brew ls --versions python-tk@3.10 >/dev/null; then - brew install python-tk@3.10 + echo "Installing Python TK 3.10." + brew install python-tk@3.10 >&3 else echo "Python Tkinter 3.10 found!" fi @@ -362,15 +464,24 @@ elif [[ "$OSTYPE" == "darwin"* ]]; then # DEBUG ONLY #pip install pydevd-pycharm~=223.8836.43 + # Updating pip if there is one + echo "Checking for pip updates before Python operations." + python3 -m pip install --upgrade pip >&3 + # Tensorflow installation - if wget https://github.com/apple/tensorflow_macos/releases/download/v0.1alpha3/tensorflow_macos-0.1a3-cp38-cp38-macosx_11_0_arm64.whl /tmp; then - python -m pip install tensorflow==0.1a3 -f https://github.com/apple/tensorflow_macos/releases/download/v0.1alpha3/tensorflow_macos-0.1a3-cp38-cp38-macosx_11_0_arm64.whl + echo "Downloading and installing macOS Tensorflow." + if wget https://github.com/apple/tensorflow_macos/releases/download/v0.1alpha3/tensorflow_macos-0.1a3-cp38-cp38-macosx_11_0_arm64.whl /tmp &>3; then + python -m pip install tensorflow==0.1a3 \ + -f https://github.com/apple/tensorflow_macos/releases/download/v0.1alpha3/tensorflow_macos-0.1a3-cp38-cp38-macosx_11_0_arm64.whl >&3 rm -f /tmp/tensorflow_macos-0.1a3-cp38-cp38-macosx_11_0_arm64.whl fi - pip install torch==2.0.0 torchvision==0.15.1 -f https://download.pytorch.org/whl/cpu/torch_stable.html - python -m pip install --use-pep517 --upgrade -r requirements.txt - accelerate config + echo "Installing python dependencies. This could take a few minutes as it downloads files." + echo "If this operation ever runs too long, you can rerun this script in verbose mode to check." + pip install torch==2.0.0 torchvision==0.15.1 \ + -f https://download.pytorch.org/whl/cpu/torch_stable.html >&3 + python -m pip install --use-pep517 --upgrade -r requirements.txt >&3 + configure_accelerate echo -e "Setup finished! Run ./gui.sh to start." else echo "Python not found. Please ensure you install Python." From e6da2d135b13b178122b808b91e25e3a1f3320b1 Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Fri, 31 Mar 2023 16:45:04 -0600 Subject: [PATCH 25/44] Add setup/upgrade batch files --- README.md | 35 +++++++---------------------------- setup.bat | 13 +++++++++++++ upgrade.bat | 16 ++++++++++++++++ 3 files changed, 36 insertions(+), 28 deletions(-) create mode 100644 setup.bat create mode 100644 upgrade.bat diff --git a/README.md b/README.md index c6e89c7..3614d94 100644 --- a/README.md +++ b/README.md @@ -64,36 +64,19 @@ cd kohya_ss bash ubuntu_setup.sh ``` -then configure accelerate with the same answers as in the Windows instructions when prompted. +then configure accelerate with the same answers as in the MacOS instructions when prompted. ### Windows +In the terminal, run -Give unrestricted script access to powershell so venv can work: - -- Run PowerShell as an administrator -- Run `Set-ExecutionPolicy Unrestricted` and answer 'A' -- Close PowerShell - -Open a regular user Powershell terminal and run the following commands: - -```powershell +``` git clone https://github.com/bmaltais/kohya_ss.git cd kohya_ss - -python -m venv venv -.\venv\Scripts\activate - -pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 -pip install --use-pep517 --upgrade -r requirements.txt -pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl - -cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\ -cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py -cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py - -accelerate config +setup.bat ``` +then configure accelerate with the same answers as in the MacOS instructions when prompted. + ### Optional: CUDNN 8.6 This step is optional but can improve the learning speed for NVIDIA 30X0/40X0 owners. It allows for larger training batch size and faster training speed. @@ -125,11 +108,7 @@ Once the commands have completed successfully you should be ready to use the new When a new release comes out, you can upgrade your repo with the following commands in the root directory: ```powershell -git pull - -.\venv\Scripts\activate - -pip install --use-pep517 --upgrade -r requirements.txt +upgrade.bat ``` Once the commands have completed successfully you should be ready to use the new version. diff --git a/setup.bat b/setup.bat new file mode 100644 index 0000000..2c84356 --- /dev/null +++ b/setup.bat @@ -0,0 +1,13 @@ +@echo off +python -m venv venv +call .\venv\Scripts\activate.bat + +pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 +pip install --use-pep517 --upgrade -r requirements.txt +pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl + +copy /y .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\ +copy /y .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py +copy /y .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py + +accelerate config \ No newline at end of file diff --git a/upgrade.bat b/upgrade.bat new file mode 100644 index 0000000..787df73 --- /dev/null +++ b/upgrade.bat @@ -0,0 +1,16 @@ +@echo off +:: Check if there are any changes that need to be committed +git status --short +if %errorlevel%==1 ( + echo There are changes that need to be committed. Please stash or undo your changes before running this script. + exit +) + +:: Pull the latest changes from the remote repository +git pull + +:: Activate the virtual environment +call .\venv\Scripts\activate.baT + +:: Upgrade the required packages +pip install --upgrade -r requirements.txt \ No newline at end of file From e5b2257d7d7e8f85f0e019d83e40d9c1a23cd4d8 Mon Sep 17 00:00:00 2001 From: JSTayco Date: Fri, 31 Mar 2023 16:39:28 -0700 Subject: [PATCH 26/44] Integrated upgrade.sh Integrated upgrade.sh functionality, consolidated code to install dependencies, added ability to skip git operations, ensured the script can run from anywhere aimed at installation anywhere, ensured all git commnds worked from anywhere aimed at target folder, normalized specified install directory names (always get the absolute path). --- .gitignore | 3 +- README.md | 17 ++++-- requirements.txt | 1 + setup.sh | 153 +++++++++++++++++++++++++++++------------------ upgrade.sh | 16 ----- 5 files changed, 111 insertions(+), 79 deletions(-) delete mode 100755 upgrade.sh diff --git a/.gitignore b/.gitignore index 71fe116..5ed6370 100644 --- a/.gitignore +++ b/.gitignore @@ -8,4 +8,5 @@ wd14_tagger_model .DS_Store locon gui-user.bat -gui-user.ps1 \ No newline at end of file +gui-user.ps1 +.idea \ No newline at end of file diff --git a/README.md b/README.md index 399b304..cfe0f85 100644 --- a/README.md +++ b/README.md @@ -82,13 +82,16 @@ Usage: setup.sh --branch=dev --dir=/workspace/kohya_ss --git-repo=https://mycustom.repo.tld/custom_fork.git Options: - -b BRANCH, --branch=BRANCH Select which branch of kohya to checkout on new installs. + -b BRANCH, --branch=BRANCH Select which branch of kohya to check out on new installs. -d DIR, --dir=DIR The full path you want kohya_ss installed to. - -g, --git_repo You can optionally provide a git repo to checkout for runpod installation. Useful for custom forks. + -g, --git_repo You can optionally provide a git repo to check out for runpod installation. Useful for custom forks. + -h, --help Show this screen. + -i, --interactive Interactively configure accelerate instead of using default config file. + -n, --no-update Do not update kohya_ss repo. No git pull or clone operations. -p, --public Expose public URL in runpod mode. Won't have an effect in other modes. -r, --runpod Forces a runpod installation. Useful if detection fails for any reason. - -i, --interactive Interactively configure accelerate instead of using default config file. - -h, --help Show this screen. + -s, --skip-space-check Skip the 10Gb minimum storage space check. + -v, --verbose Increase verbosity levels up to 3. ``` #### Install location @@ -170,7 +173,11 @@ When a new release comes out, you can upgrade your repo with the following comma You can cd into the root directory and simply run ```bash -./upgrade.sh +# Refresh and update everything +./setup.sh + +# This will refresh everything, but NOT close or pull the git repo. +./setup.sh --no-git-update ``` Once the commands have completed successfully you should be ready to use the new version. diff --git a/requirements.txt b/requirements.txt index 303b2b4..882617b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,6 +25,7 @@ timm==0.6.12 # tensorflow<2.11 huggingface-hub==0.13.0 tensorflow==2.10.1; sys_platform != 'darwin' +tensorflow-macos==2.12.0; sys_platform == 'darwin' # For locon support lycoris_lora==0.1.4 # for kohya_ss library diff --git a/setup.sh b/setup.sh index 686d141..650bf4c 100755 --- a/setup.sh +++ b/setup.sh @@ -13,15 +13,16 @@ Usage: # Same as example 1, but uses long options setup.sh --branch=dev --dir=/workspace/kohya_ss --git-repo=https://mycustom.repo.tld/custom_fork.git - # Maximum verbosity, fully automated install in a runpod environment skipping the runpod env checks + # Maximum verbosity, fully automated installation in a runpod environment skipping the runpod env checks setup.sh -vvv --skip-space-check --runpod Options: - -b BRANCH, --branch=BRANCH Select which branch of kohya to checkout on new installs. + -b BRANCH, --branch=BRANCH Select which branch of kohya to check out on new installs. -d DIR, --dir=DIR The full path you want kohya_ss installed to. - -g, --git_repo You can optionally provide a git repo to checkout for runpod installation. Useful for custom forks. + -g, --git_repo You can optionally provide a git repo to check out for runpod installation. Useful for custom forks. -h, --help Show this screen. -i, --interactive Interactively configure accelerate instead of using default config file. + -n, --no-git-update Do not update kohya_ss repo. No git pull or clone operations. -p, --public Expose public URL in runpod mode. Won't have an effect in other modes. -r, --runpod Forces a runpod installation. Useful if detection fails for any reason. -s, --skip-space-check Skip the 10Gb minimum storage space check. @@ -45,11 +46,15 @@ if env_var_exists RUNPOD_POD_ID || env_var_exists RUNPOD_API_KEY; then RUNPOD=true fi +SCRIPT_DIR="$(cd -- $(dirname -- "$0") && pwd)" + # Variables defined before the getopts loop, so we have sane default values. # Default installation locations based on OS and environment if [[ "$OSTYPE" == "linux-gnu"* ]]; then if [ "$RUNPOD" = true ]; then DIR="/workspace/kohya_ss" + elif [ -d "$SCRIPT_DIR/.git" ]; then + DIR="$SCRIPT_DIR" elif [ -w "/opt" ]; then DIR="/opt/kohya_ss" elif env_var_exists HOME; then @@ -59,7 +64,9 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then DIR="$(PWD)" fi else - if env_var_exists HOME; then + if [ -d "$SCRIPT_DIR/.git" ]; then + DIR="$SCRIPT_DIR" + elif env_var_exists HOME; then DIR="$HOME/kohya_ss" else # The last fallback is simply PWD @@ -75,8 +82,9 @@ GIT_REPO="https://github.com/bmaltais/kohya_ss.git" INTERACTIVE=false PUBLIC=false SKIP_SPACE_CHECK=false +SKIP_GIT_UPDATE=false -while getopts ":vb:d:g:iprs-:" opt; do +while getopts ":vb:d:g:inprs-:" opt; do # support long options: https://stackoverflow.com/a/28466267/519360 if [ "$opt" = "-" ]; then # long option: reformulate OPT and OPTARG opt="${OPTARG%%=*}" # extract long option name @@ -88,6 +96,7 @@ while getopts ":vb:d:g:iprs-:" opt; do d | dir) DIR="$OPTARG" ;; g | git-repo) GIT_REPO="$OPTARG" ;; i | interactive) INTERACTIVE=true ;; + n | no-git-update) SKIP_GIT_UPDATE=true ;; p | public) PUBLIC=true ;; r | runpod) RUNPOD=true ;; s | skip-space-check) SKIP_SPACE_CHECK=true ;; @@ -98,6 +107,15 @@ while getopts ":vb:d:g:iprs-:" opt; do done shift $((OPTIND - 1)) +# Just in case someone puts in a relative path into $DIR, +# we're going to get the absolute path of that. +if [[ "$DIR" != /* ]] && [[ "$DIR" != ~* ]]; then + DIR="$( + cd "$(dirname "$DIR")" || exit 1 + pwd + )/$(basename "$DIR")" +fi + for v in $( #Start counting from 3 since 1 and 2 are standards (stdout/stderr). seq 3 $VERBOSITY ); do @@ -123,7 +141,8 @@ INTERACTIVE: $INTERACTIVE PUBLIC: $PUBLIC RUNPOD: $RUNPOD SKIP_SPACE_CHECK: $SKIP_SPACE_CHECK -VERBOSITY: $VERBOSITY" >&5 +VERBOSITY: $VERBOSITY +Script directory is ${SCRIPT_DIR}." >&5 # This must be set after the getopts loop to account for $DIR changes. PARENT_DIR="$(dirname "${DIR}")" @@ -144,13 +163,15 @@ size_available() { folder='/' fi - local FREESPACEINKB="$(df -Pk "$folder" | sed 1d | grep -v used | awk '{ print $4 "\t" }')" + local FREESPACEINKB + FREESPACEINKB="$(df -Pk "$folder" | sed 1d | grep -v used | awk '{ print $4 "\t" }')" echo "Detected available space in Kb: $FREESPACEINKB" >&5 - local FREESPACEINGB=$((FREESPACEINKB / 1024 / 1024)) + local FREESPACEINGB + FREESPACEINGB=$((FREESPACEINKB / 1024 / 1024)) echo "$FREESPACEINGB" } -# The expected usage is create_symlinks $symlink $target_file +# The expected usage is create_symlinks symlink target_file create_symlinks() { echo "Checking symlinks now." # Next line checks for valid symlink @@ -173,6 +194,33 @@ create_symlinks() { fi } +install_pip_dependencies() { + # Updating pip if there is one + echo "Checking for pip updates before Python operations." + python3 -m pip install --upgrade pip >&3 + + echo "Installing python dependencies. This could take a few minutes as it downloads files." + echo "If this operation ever runs too long, you can rerun this script in verbose mode to check." + case "$OSTYPE" in + "linux-gnu"*) pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 \ + --extra-index-url https://download.pytorch.org/whl/cu116 >&3 && + pip install -U -I --no-deps \ + https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/downloadlinux/xformers-0.0.14.dev0-cp310-cp310-linux_x86_64.whl >&3 ;; + "darwin"*) pip install torch==2.0.0 torchvision==0.15.1 \ + -f https://download.pytorch.org/whl/cpu/torch_stable.html >&3 ;; + "cygwin") + : + ;; + "msys") + : + ;; + esac + + # DEBUG ONLY (Update this version number to whatever PyCharm recommends) + # pip install pydevd-pycharm~=223.8836.43 + python -m pip install --use-pep517 --upgrade -r requirements.txt >&3 +} + # Attempt to non-interactively install a default accelerate config file unless specified otherwise. # Documentation for order of precedence locations for configuration file for automated installation: # https://huggingface.co/docs/accelerate/basic_tutorials/launch#custom-configurations @@ -218,7 +266,7 @@ check_storage_space() { MSGTIMEOUT=10 # In seconds MESSAGE="Continuing in..." echo "Press control-c to cancel the installation." - for ((i = $MSGTIMEOUT; i >= 0; i--)); do + for ((i = MSGTIMEOUT; i >= 0; i--)); do printf "\r${MESSAGE} %ss. " "${i}" sleep 1 done @@ -226,6 +274,36 @@ check_storage_space() { fi } +update_kohya_ss() { + if [ "$SKIP_GIT_UPDATE" = false ]; then + if command -v git >/dev/null; then + # First, we make sure there are no changes that need to be made in git, so no work is lost. + if [ -z "$(git -c "$DIR" status --porcelain=v1 2>/dev/null)" ]; then + echo "There are changes that need to be committed." + echo "Commit those changes or run this script with -n to skip git operations entirely." + exit 1 + fi + + cd "$PARENT_DIR" || exit 1 + echo "Attempting to clone $GIT_REPO." + if [ ! -d "$DIR/.git" ]; then + git -c "$DIR" clone "$GIT_REPO" "$(basename "$DIR")" >&3 + cd "$DIR" || exit 1 + git -c "$DIR" checkout -b "$BRANCH" >&3 + else + cd "$DIR" || exit 1 + echo "git repo detected. Attempting to update repository instead." + echo "Updating: $GIT_REPO" + git pull "$GIT_REPO" >&3 + git checkout -b "$BRANCH" + fi + else + echo "You need to install git." + echo "Rerun this after installing git or run this script with -n to skip the git operations." + fi + fi +} + # Start OS-specific detection and work if [[ "$OSTYPE" == "linux-gnu"* ]]; then # Check if root or sudo @@ -296,23 +374,11 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:"$VENV_DIR"/lib/python3.10/site-packages/tensorrt/ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:"$VENV_DIR"/lib/python3.10/site-packages/nvidia/cuda_runtime/lib/ cd "$DIR" || exit 1 - else - echo "Clean installation on a runpod detected." - cd "$PARENT_DIR" || exit 1 - if [ ! -d "$DIR/.git" ]; then - echo "Cloning $GIT_REPO." - git clone "$GIT_REPO" >&3 - cd "$DIR" || exit 1 - git checkout "$BRANCH" >&3 - else - cd "$DIR" || exit 1 - echo "git repo detected. Attempting to update repository instead." - echo "Updating: $GIT_REPO" - git pull "$GIT_REPO" >&3 - fi fi fi + update_kohya_ss + distro=get_distro_name family=get_distro_family echo "Raw detected distro string: $distro" >&4 @@ -373,21 +439,12 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then python3 -m venv venv source venv/bin/activate - - # Updating pip if there is one - echo "Checking for pip updates before Python operations." - python3 -m pip install --upgrade pip >&3 - - echo "Installing python dependencies. This could take a few minutes as it downloads files." - echo "If this operation ever runs too long, you can rerun this script in verbose mode to check." - pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 >&3 - pip install --use-pep517 --upgrade -r requirements.txt >&3 - pip install -U -I --no-deps \ - https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/linux/xformers-0.0.14.dev0-cp310-cp310-linux_x86_64.whl >&3 + install_pip_dependencies # We need this extra package and setup if we are running in a runpod if [ "$RUNPOD" = true ]; then - pip install tensorrt + echo "Installing tenssort." + pip install tensorrt >&3 # Symlink paths libnvinfer_plugin_symlink="$VENV_DIR/lib/python3.10/site-packages/tensorrt/libnvinfer_plugin.so.7" libnvinfer_symlink="$VENV_DIR/lib/python3.10/site-packages/tensorrt/libnvinfer.so.7" @@ -457,30 +514,12 @@ elif [[ "$OSTYPE" == "darwin"* ]]; then echo "Python Tkinter 3.10 found!" fi + update_kohya_ss + if command -v python3.10 >/dev/null; then python3.10 -m venv venv source venv/bin/activate - - # DEBUG ONLY - #pip install pydevd-pycharm~=223.8836.43 - - # Updating pip if there is one - echo "Checking for pip updates before Python operations." - python3 -m pip install --upgrade pip >&3 - - # Tensorflow installation - echo "Downloading and installing macOS Tensorflow." - if wget https://github.com/apple/tensorflow_macos/releases/download/v0.1alpha3/tensorflow_macos-0.1a3-cp38-cp38-macosx_11_0_arm64.whl /tmp &>3; then - python -m pip install tensorflow==0.1a3 \ - -f https://github.com/apple/tensorflow_macos/releases/download/v0.1alpha3/tensorflow_macos-0.1a3-cp38-cp38-macosx_11_0_arm64.whl >&3 - rm -f /tmp/tensorflow_macos-0.1a3-cp38-cp38-macosx_11_0_arm64.whl - fi - - echo "Installing python dependencies. This could take a few minutes as it downloads files." - echo "If this operation ever runs too long, you can rerun this script in verbose mode to check." - pip install torch==2.0.0 torchvision==0.15.1 \ - -f https://download.pytorch.org/whl/cpu/torch_stable.html >&3 - python -m pip install --use-pep517 --upgrade -r requirements.txt >&3 + install_pip_dependencies configure_accelerate echo -e "Setup finished! Run ./gui.sh to start." else diff --git a/upgrade.sh b/upgrade.sh deleted file mode 100755 index 8ed545f..0000000 --- a/upgrade.sh +++ /dev/null @@ -1,16 +0,0 @@ -#!/usr/bin/env bash - -# Check if there are any changes that need to be committed -if [[ -n $(git status --short) ]]; then - echo "There are changes that need to be committed. Please stash or undo your changes before running this script." >&2 - exit 1 -fi - -# Pull the latest changes from the remote repository -git pull - -# Activate the virtual environment -source venv/bin/activate - -# Upgrade the required packages -pip install --upgrade -r requirements.txt From fcffc131e119d69402387fe7c3dc8baff873c6c6 Mon Sep 17 00:00:00 2001 From: jstayco <127801635+jstayco@users.noreply.github.com> Date: Fri, 31 Mar 2023 16:59:04 -0700 Subject: [PATCH 27/44] Typo fix in README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index cfe0f85..94b9466 100644 --- a/README.md +++ b/README.md @@ -176,7 +176,7 @@ You can cd into the root directory and simply run # Refresh and update everything ./setup.sh -# This will refresh everything, but NOT close or pull the git repo. +# This will refresh everything, but NOT clone or pull the git repo. ./setup.sh --no-git-update ``` From bd2e829ae32014f613500dad3188804bbb4de6a6 Mon Sep 17 00:00:00 2001 From: JSTayco Date: Fri, 31 Mar 2023 17:24:09 -0700 Subject: [PATCH 28/44] More safeties around Python ops More safeties and more code consolidation. Now we try to exit the python venv after all the python operations. All the python operations were consolidated to facilitate this. --- requirements.txt | 3 ++- setup.sh | 58 +++++++++++++++++++++++++++++++++--------------- 2 files changed, 42 insertions(+), 19 deletions(-) diff --git a/requirements.txt b/requirements.txt index 882617b..83508f5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,7 +13,8 @@ lion-pytorch==0.0.6 opencv-python==4.7.0.68 pytorch-lightning==1.9.0 safetensors==0.2.6 -tensorboard==2.10.1 +tensorboard==2.10.1 ; sys_platform != 'darwin' +tensorboard==2.12.1 ; sys_platform == 'darwin' tk==0.1.0 toml==0.10.2 transformers==4.26.0 diff --git a/setup.sh b/setup.sh index 650bf4c..14de32d 100755 --- a/setup.sh +++ b/setup.sh @@ -194,7 +194,22 @@ create_symlinks() { fi } -install_pip_dependencies() { +install_python_dependencies() { + # Switch to local virtual env + echo "Switching to virtual Python environment." + if command -v python3 >/dev/null; then + python3 -m venv venv + elif command -v python3.10 >/dev/null; then + python3.10 -m venv venv + else + echo "Valid python3 or python3.10 binary not found." + echo "Cannot proceed with the python steps." + return 1 + fi + + # Activate the virtual environment + source venv/bin/activate + # Updating pip if there is one echo "Checking for pip updates before Python operations." python3 -m pip install --upgrade pip >&3 @@ -216,9 +231,23 @@ install_pip_dependencies() { ;; esac + if [ "$RUNPOD" = true ]; then + echo "Installing tenssort." + pip install tensorrt >&3 + fi + # DEBUG ONLY (Update this version number to whatever PyCharm recommends) # pip install pydevd-pycharm~=223.8836.43 - python -m pip install --use-pep517 --upgrade -r requirements.txt >&3 + python -m pip install --use-pep517 --upgrade -r "$DIR/requirements.txt" >&3 + + if [ -n "$VIRTUAL_ENV" ]; then + if command -v deactivate >/dev/null; then + echo "Exiting Python virtual environment." + deactivate + else + echo "deactivate command not found. Could still be in the Python virtual environment." + fi + fi } # Attempt to non-interactively install a default accelerate config file unless specified otherwise. @@ -301,6 +330,8 @@ update_kohya_ss() { echo "You need to install git." echo "Rerun this after installing git or run this script with -n to skip the git operations." fi + else + echo "Skipping git operations." fi } @@ -437,14 +468,10 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then fi fi - python3 -m venv venv - source venv/bin/activate - install_pip_dependencies + install_python_dependencies - # We need this extra package and setup if we are running in a runpod + # We need just a little bit more setup for non-interactive environments if [ "$RUNPOD" = true ]; then - echo "Installing tenssort." - pip install tensorrt >&3 # Symlink paths libnvinfer_plugin_symlink="$VENV_DIR/lib/python3.10/site-packages/tensorrt/libnvinfer_plugin.so.7" libnvinfer_symlink="$VENV_DIR/lib/python3.10/site-packages/tensorrt/libnvinfer.so.7" @@ -516,17 +543,12 @@ elif [[ "$OSTYPE" == "darwin"* ]]; then update_kohya_ss - if command -v python3.10 >/dev/null; then - python3.10 -m venv venv - source venv/bin/activate - install_pip_dependencies - configure_accelerate - echo -e "Setup finished! Run ./gui.sh to start." - else - echo "Python not found. Please ensure you install Python." - echo "The brew command for Python 3.10 is: brew install python@3.10" - exit 1 + if ! install_python_dependencies; then + echo "You may need to install Python. The command for this is brew install python@3.10." fi + + configure_accelerate + echo -e "Setup finished! Run ./gui.sh to start." elif [[ "$OSTYPE" == "cygwin" ]]; then # Cygwin is a standalone suite of Linux utilies on Windows echo "This hasn't been validated on cygwin yet." From b02fb86765569e685f42313b3b6a5a2bf0644aea Mon Sep 17 00:00:00 2001 From: JSTayco Date: Fri, 31 Mar 2023 19:30:43 -0700 Subject: [PATCH 29/44] Small README update Updates README to cover new location change and adds one small comment to clarify a variable --- README.md | 6 ++++-- setup.sh | 1 + 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 94b9466..66554b3 100644 --- a/README.md +++ b/README.md @@ -96,9 +96,11 @@ Options: #### Install location -The default install location for Linux is `/opt/kohya_ss`. If /opt is not writeable, the fallback is `$HOME/kohya_ss`. Lastly, if all else fails it will simply install to the current folder you are in. +The default install location for Linux is where the script is located if a previous installation is detected that location. +Otherwise, it will fall to `/opt/kohya_ss`. If /opt is not writeable, the fallback is `$HOME/kohya_ss`. Lastly, if all else fails it will simply install to the current folder you are in (PWD). -On macOS and other non-Linux machines, it will default install to `$HOME/kohya_ss` followed by where you're currently at if there's no access to $HOME. +On macOS and other non-Linux machines, it will first try to detect an install where the script is run from and then run setup there if that's detected. +If a previous install isn't found at that location, then it will default install to `$HOME/kohya_ss` followed by where you're currently at if there's no access to $HOME. You can override this behavior by specifying an install directory with the -d option. If you are using the interactive mode, our default values for the accelerate config screen after running the script answer "This machine", "None", "No" for the remaining questions. diff --git a/setup.sh b/setup.sh index 14de32d..c0f0f79 100755 --- a/setup.sh +++ b/setup.sh @@ -46,6 +46,7 @@ if env_var_exists RUNPOD_POD_ID || env_var_exists RUNPOD_API_KEY; then RUNPOD=true fi +# This gets the directory the script is run from so pathing can work relative to the script where needed. SCRIPT_DIR="$(cd -- $(dirname -- "$0") && pwd)" # Variables defined before the getopts loop, so we have sane default values. From febd553864425a717c8a5806cd656d0b28935944 Mon Sep 17 00:00:00 2001 From: JSTayco Date: Fri, 31 Mar 2023 20:00:56 -0700 Subject: [PATCH 30/44] Python is now more dynamic Made python and the requirements.txt location independent. --- setup.sh | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/setup.sh b/setup.sh index c0f0f79..c9146c3 100755 --- a/setup.sh +++ b/setup.sh @@ -199,9 +199,9 @@ install_python_dependencies() { # Switch to local virtual env echo "Switching to virtual Python environment." if command -v python3 >/dev/null; then - python3 -m venv venv + python3 -m venv "$DIR/venv" elif command -v python3.10 >/dev/null; then - python3.10 -m venv venv + python3.10 -m venv "$DIR/venv" else echo "Valid python3 or python3.10 binary not found." echo "Cannot proceed with the python steps." @@ -209,7 +209,7 @@ install_python_dependencies() { fi # Activate the virtual environment - source venv/bin/activate + source "$DIR/venv/bin/activate" # Updating pip if there is one echo "Checking for pip updates before Python operations." @@ -239,7 +239,17 @@ install_python_dependencies() { # DEBUG ONLY (Update this version number to whatever PyCharm recommends) # pip install pydevd-pycharm~=223.8836.43 - python -m pip install --use-pep517 --upgrade -r "$DIR/requirements.txt" >&3 + + #This will copy our requirements.txt file out, make the khoya_ss lib a dynamic location then cleanup. + echo "Copying $DIR/requirements.txt to /tmp/requirements_tmp.txt" >&3 + echo "Replacing the . for lib to our DIR variable in tmp/requirements_tmp.txt." >&3 + awk -v dir="$DIR" '/#.*kohya_ss.*library/{print; getline; sub(/^\.$/, dir)}1' "$DIR/requirements.txt" >/tmp/requirements_tmp.txt + python -m pip install --use-pep517 --upgrade -r /tmp/requirements_tmp.txt >&3 + + echo "Removing the temp requirements file." + if [ -f /tmp/requirements_tmp.txt ]; then + rm /tmp/requirements_tmp.txt + fi if [ -n "$VIRTUAL_ENV" ]; then if command -v deactivate >/dev/null; then @@ -304,6 +314,7 @@ check_storage_space() { fi } +# These are the git operations that will run to update or clone the repo update_kohya_ss() { if [ "$SKIP_GIT_UPDATE" = false ]; then if command -v git >/dev/null; then @@ -373,12 +384,12 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then fi } + # We search for the line starting with ID_LIKE= + # Then we remove the ID_LIKE= prefix to get the name itself + # This is the "type" of distro. For example, Ubuntu returns "debian". get_distro_family() { local line if [ -f /etc/os-release ]; then - # We search for the line starting with ID_LIKE= - # Then we remove the ID_LIKE= prefix to get the name itself - # This is the "type" of distro. For example, Ubuntu returns "debian". if grep -Eiq '^ID_LIKE=' /etc/os-release >/dev/null; then line="$(grep -Ei '^ID_LIKE=' /etc/os-release)" echo "Raw detected os-release distro family line: $line" >&5 From 035dad220a18fa53765529332c9316907c813922 Mon Sep 17 00:00:00 2001 From: JSTayco Date: Fri, 31 Mar 2023 20:51:49 -0700 Subject: [PATCH 31/44] git is now location independent Removed the cd commands in the process --- setup.sh | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/setup.sh b/setup.sh index c9146c3..2bc0c60 100755 --- a/setup.sh +++ b/setup.sh @@ -319,24 +319,23 @@ update_kohya_ss() { if [ "$SKIP_GIT_UPDATE" = false ]; then if command -v git >/dev/null; then # First, we make sure there are no changes that need to be made in git, so no work is lost. - if [ -z "$(git -c "$DIR" status --porcelain=v1 2>/dev/null)" ]; then - echo "There are changes that need to be committed." + if [ -z "$(git -C "$DIR" status --porcelain=v1 >/dev/null)" ]; then + echo "There are changes that need to be committed or discarded in the repo in $DIR." echo "Commit those changes or run this script with -n to skip git operations entirely." exit 1 fi - cd "$PARENT_DIR" || exit 1 echo "Attempting to clone $GIT_REPO." if [ ! -d "$DIR/.git" ]; then - git -c "$DIR" clone "$GIT_REPO" "$(basename "$DIR")" >&3 - cd "$DIR" || exit 1 - git -c "$DIR" checkout -b "$BRANCH" >&3 + git -C "$DIR" clone -b "$BRANCH" "$GIT_REPO" "$(basename "$DIR")" >&3 + git -C "$DIR" switch "$BRANCH" >&3 else - cd "$DIR" || exit 1 echo "git repo detected. Attempting to update repository instead." echo "Updating: $GIT_REPO" - git pull "$GIT_REPO" >&3 - git checkout -b "$BRANCH" + git -C "$DIR" pull "$GIT_REPO" "$BRANCH" >&3 + if ! git -C "$DIR" switch "$BRANCH" >/dev/null; then + git -C "$DIR" switch -c "$BRANCH" >/dev/null + fi fi else echo "You need to install git." From 4559528d336fd8614308d56c86207ace221fdaa7 Mon Sep 17 00:00:00 2001 From: JSTayco Date: Fri, 31 Mar 2023 20:54:21 -0700 Subject: [PATCH 32/44] git is now location indpendent removed all cd commands in process --- setup.sh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/setup.sh b/setup.sh index 2bc0c60..5f58944 100755 --- a/setup.sh +++ b/setup.sh @@ -319,7 +319,7 @@ update_kohya_ss() { if [ "$SKIP_GIT_UPDATE" = false ]; then if command -v git >/dev/null; then # First, we make sure there are no changes that need to be made in git, so no work is lost. - if [ -z "$(git -C "$DIR" status --porcelain=v1 >/dev/null)" ]; then + if [ -z "$(git -C "$DIR" status --porcelain=v1 >&4)" ]; then echo "There are changes that need to be committed or discarded in the repo in $DIR." echo "Commit those changes or run this script with -n to skip git operations entirely." exit 1 @@ -328,13 +328,13 @@ update_kohya_ss() { echo "Attempting to clone $GIT_REPO." if [ ! -d "$DIR/.git" ]; then git -C "$DIR" clone -b "$BRANCH" "$GIT_REPO" "$(basename "$DIR")" >&3 - git -C "$DIR" switch "$BRANCH" >&3 + git -C "$DIR" switch "$BRANCH" >&4 else echo "git repo detected. Attempting to update repository instead." echo "Updating: $GIT_REPO" git -C "$DIR" pull "$GIT_REPO" "$BRANCH" >&3 - if ! git -C "$DIR" switch "$BRANCH" >/dev/null; then - git -C "$DIR" switch -c "$BRANCH" >/dev/null + if ! git -C "$DIR" switch "$BRANCH" >&4; then + git -C "$DIR" switch -c "$BRANCH" >&4 fi fi else From fbf6709946f45635df25644c5c962dda7781ecda Mon Sep 17 00:00:00 2001 From: JSTayco Date: Fri, 31 Mar 2023 21:03:25 -0700 Subject: [PATCH 33/44] More safeties around git And more error messages in verbose mode --- setup.sh | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/setup.sh b/setup.sh index 5f58944..25b4d00 100755 --- a/setup.sh +++ b/setup.sh @@ -319,7 +319,9 @@ update_kohya_ss() { if [ "$SKIP_GIT_UPDATE" = false ]; then if command -v git >/dev/null; then # First, we make sure there are no changes that need to be made in git, so no work is lost. - if [ -z "$(git -C "$DIR" status --porcelain=v1 >&4)" ]; then + if [ -z "$(git -C "$DIR" status --porcelain=v1 >/dev/null)" ] && + echo "These files need to be committed or discarded: " >&4 && + git -C "$DIR" status >&4; then echo "There are changes that need to be committed or discarded in the repo in $DIR." echo "Commit those changes or run this script with -n to skip git operations entirely." exit 1 @@ -327,6 +329,7 @@ update_kohya_ss() { echo "Attempting to clone $GIT_REPO." if [ ! -d "$DIR/.git" ]; then + echo "Cloning and switching to $GIT_REPO:$BRANCH" >*4 git -C "$DIR" clone -b "$BRANCH" "$GIT_REPO" "$(basename "$DIR")" >&3 git -C "$DIR" switch "$BRANCH" >&4 else @@ -334,6 +337,7 @@ update_kohya_ss() { echo "Updating: $GIT_REPO" git -C "$DIR" pull "$GIT_REPO" "$BRANCH" >&3 if ! git -C "$DIR" switch "$BRANCH" >&4; then + echo "Branch $BRANCH did not exist. Creating it." >&4 git -C "$DIR" switch -c "$BRANCH" >&4 fi fi From a740fdb0064fe8672cfe2b22c0a57a887a27deab Mon Sep 17 00:00:00 2001 From: JSTayco Date: Fri, 31 Mar 2023 21:09:43 -0700 Subject: [PATCH 34/44] Update setup.sh --- setup.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.sh b/setup.sh index 25b4d00..a6b5277 100755 --- a/setup.sh +++ b/setup.sh @@ -319,7 +319,7 @@ update_kohya_ss() { if [ "$SKIP_GIT_UPDATE" = false ]; then if command -v git >/dev/null; then # First, we make sure there are no changes that need to be made in git, so no work is lost. - if [ -z "$(git -C "$DIR" status --porcelain=v1 >/dev/null)" ] && + if [ "$(git -C "$DIR" status --porcelain=v1 >/dev/null)" == "" ] && echo "These files need to be committed or discarded: " >&4 && git -C "$DIR" status >&4; then echo "There are changes that need to be committed or discarded in the repo in $DIR." From 2a24a8b6fc8e8ac39cfb24ac037cdaef151ba41f Mon Sep 17 00:00:00 2001 From: JSTayco Date: Fri, 31 Mar 2023 21:11:24 -0700 Subject: [PATCH 35/44] Update setup.sh --- setup.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.sh b/setup.sh index a6b5277..00fe910 100755 --- a/setup.sh +++ b/setup.sh @@ -319,7 +319,7 @@ update_kohya_ss() { if [ "$SKIP_GIT_UPDATE" = false ]; then if command -v git >/dev/null; then # First, we make sure there are no changes that need to be made in git, so no work is lost. - if [ "$(git -C "$DIR" status --porcelain=v1 >/dev/null)" == "" ] && + if [ "$(git -C "$DIR" status --porcelain=v1 2>/dev/null | wc -l)" -gt 0 ] && echo "These files need to be committed or discarded: " >&4 && git -C "$DIR" status >&4; then echo "There are changes that need to be committed or discarded in the repo in $DIR." From c92153e5467ad02b7296792f3bcf7ed1b5132c34 Mon Sep 17 00:00:00 2001 From: JSTayco Date: Fri, 31 Mar 2023 21:21:50 -0700 Subject: [PATCH 36/44] Hide pip output better with no verbosity --- setup.sh | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/setup.sh b/setup.sh index 00fe910..0bf901e 100755 --- a/setup.sh +++ b/setup.sh @@ -244,7 +244,11 @@ install_python_dependencies() { echo "Copying $DIR/requirements.txt to /tmp/requirements_tmp.txt" >&3 echo "Replacing the . for lib to our DIR variable in tmp/requirements_tmp.txt." >&3 awk -v dir="$DIR" '/#.*kohya_ss.*library/{print; getline; sub(/^\.$/, dir)}1' "$DIR/requirements.txt" >/tmp/requirements_tmp.txt - python -m pip install --use-pep517 --upgrade -r /tmp/requirements_tmp.txt >&3 + if [ $VERBOSITY == 2 ]; then + python -m pip install --quiet --use-pep517 --upgrade -r /tmp/requirements_tmp.txt >&3 + else + python -m pip install --use-pep517 --upgrade -r /tmp/requirements_tmp.txt >&3 + fi echo "Removing the temp requirements file." if [ -f /tmp/requirements_tmp.txt ]; then From b2e7d5f419f17fba277740db81252415c1d25b56 Mon Sep 17 00:00:00 2001 From: JSTayco Date: Fri, 31 Mar 2023 21:38:27 -0700 Subject: [PATCH 37/44] Minor help text formatting. --- README.md | 2 +- setup.sh | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 09c35e4..04018cc 100644 --- a/README.md +++ b/README.md @@ -84,7 +84,7 @@ Usage: Options: -b BRANCH, --branch=BRANCH Select which branch of kohya to check out on new installs. -d DIR, --dir=DIR The full path you want kohya_ss installed to. - -g, --git_repo You can optionally provide a git repo to check out for runpod installation. Useful for custom forks. + -g REPO, --git_repo=REPO You can optionally provide a git repo to check out for runpod installation. Useful for custom forks. -h, --help Show this screen. -i, --interactive Interactively configure accelerate instead of using default config file. -n, --no-update Do not update kohya_ss repo. No git pull or clone operations. diff --git a/setup.sh b/setup.sh index 0bf901e..bc98d6e 100755 --- a/setup.sh +++ b/setup.sh @@ -19,10 +19,10 @@ Usage: Options: -b BRANCH, --branch=BRANCH Select which branch of kohya to check out on new installs. -d DIR, --dir=DIR The full path you want kohya_ss installed to. - -g, --git_repo You can optionally provide a git repo to check out for runpod installation. Useful for custom forks. + -g REPO, --git_repo=REPO You can optionally provide a git repo to check out for runpod installation. Useful for custom forks. -h, --help Show this screen. -i, --interactive Interactively configure accelerate instead of using default config file. - -n, --no-git-update Do not update kohya_ss repo. No git pull or clone operations. + -n, --no-git-update Do not update kohya_ss repo. No git pull or clone operations. -p, --public Expose public URL in runpod mode. Won't have an effect in other modes. -r, --runpod Forces a runpod installation. Useful if detection fails for any reason. -s, --skip-space-check Skip the 10Gb minimum storage space check. From 97b004e756e3cec6195c87d313b98276f1a98a83 Mon Sep 17 00:00:00 2001 From: bmaltais Date: Sat, 1 Apr 2023 06:29:45 -0400 Subject: [PATCH 38/44] Revert "Merge pull request #466 from jstayco/consolidated_install_scripts" This reverts commit b7a719b51a56bb4094512c76f48ef118649874c6, reversing changes made to 538752ccab9032d9fd7035bfb5ba2583b30747ca. --- .gitignore | 3 +- README.md | 155 ++---- config_files/accelerate/default_config.yaml | 22 - gui.sh | 12 +- gui_macos.sh | 13 + macos_setup.sh | 38 ++ requirements.txt | 9 +- requirements_macos.txt | 32 ++ setup.sh | 578 -------------------- ubuntu_setup.sh | 12 + upgrade.sh | 16 + upgrade_macos.sh | 16 + 12 files changed, 184 insertions(+), 722 deletions(-) delete mode 100644 config_files/accelerate/default_config.yaml create mode 100755 gui_macos.sh create mode 100755 macos_setup.sh create mode 100644 requirements_macos.txt delete mode 100755 setup.sh create mode 100755 ubuntu_setup.sh create mode 100755 upgrade.sh create mode 100755 upgrade_macos.sh diff --git a/.gitignore b/.gitignore index 5ed6370..71fe116 100644 --- a/.gitignore +++ b/.gitignore @@ -8,5 +8,4 @@ wd14_tagger_model .DS_Store locon gui-user.bat -gui-user.ps1 -.idea \ No newline at end of file +gui-user.ps1 \ No newline at end of file diff --git a/README.md b/README.md index 04018cc..3614d94 100644 --- a/README.md +++ b/README.md @@ -6,30 +6,21 @@ If you run on Linux and would like to use the GUI, there is now a port of it as ### Table of Contents -- [Tutorials](#tutorials) -- [Required Dependencies](#required-dependencies) - - [Linux/macOS](#linux-and-macos-dependencies) -- [Installation](#installation) - - [Linux/macOS](#linux-and-macos) - - [Default Install Locations](#install-location) - - [Windows](#windows) - - [CUDNN 8.6](#optional--cudnn-86) -- [Upgrading](#upgrading) - - [Windows](#windows-upgrade) - - [Linux/macOS](#linux-and-macos-upgrade) -- [Launching the GUI](#starting-gui-service) - - [Windows](#launching-the-gui-on-windows) - - [Linux/macOS](#launching-the-gui-on-linux-and-macos) - - [Direct Launch via Python Script](#launching-the-gui-directly-using-kohyaguipy) -- [Dreambooth](#dreambooth) -- [Finetune](#finetune) -- [Train Network](#train-network) -- [LoRA](#lora) -- [Troubleshooting](#troubleshooting) - - [Page File Limit](#page-file-limit) - - [No module called tkinter](#no-module-called-tkinter) - - [FileNotFoundError](#filenotfounderror) -- [Change History](#change-history) +- [Tutorials](https://github.com/bmaltais/kohya_ss#tutorials) +- [Required Dependencies](https://github.com/bmaltais/kohya_ss#required-dependencies) +- [Installation](https://github.com/bmaltais/kohya_ss#installation) + - [CUDNN 8.6](https://github.com/bmaltais/kohya_ss#optional-cudnn-86) +- [Upgrading](https://github.com/bmaltais/kohya_ss#upgrading) +- [Launching the GUI](https://github.com/bmaltais/kohya_ss#launching-the-gui) +- [Dreambooth](https://github.com/bmaltais/kohya_ss#dreambooth) +- [Finetune](https://github.com/bmaltais/kohya_ss#finetune) +- [Train Network](https://github.com/bmaltais/kohya_ss#train-network) +- [LoRA](https://github.com/bmaltais/kohya_ss#lora) +- [Troubleshooting](https://github.com/bmaltais/kohya_ss#troubleshooting) + - [Page File Limit](https://github.com/bmaltais/kohya_ss#page-file-limit) + - [No module called tkinter](https://github.com/bmaltais/kohya_ss#no-module-called-tkinter) + - [FileNotFoundError](https://github.com/bmaltais/kohya_ss#filenotfounderror) +- [Change History](https://github.com/bmaltais/kohya_ss#change-history) ## Tutorials @@ -48,66 +39,35 @@ If you run on Linux and would like to use the GUI, there is now a port of it as - Install [Git](https://git-scm.com/download/win) - Install [Visual Studio 2015, 2017, 2019, and 2022 redistributable](https://aka.ms/vs/17/release/vc_redist.x64.exe) -### Linux and macOS dependencies - -These dependencies are taken care of via `setup.sh` in the installation section. No additional steps should be needed unless the scripts inform you otherwise. - ## Installation ### Runpod Follow the instructions found in this discussion: https://github.com/bmaltais/kohya_ss/discussions/379 -### Linux and macOS +### MacOS In the terminal, run ``` git clone https://github.com/bmaltais/kohya_ss.git cd kohya_ss -# May need to chmod +x ./setup.sh if you're on a machine with stricter security. -# There are additional options if needed for a runpod environment. -# Call 'setup.sh -h' or 'setup.sh --help' for more information. -./setup.sh +bash macos_setup.sh ``` -Setup.sh help included here: +During the accelerate config screen after running the script answer "This machine", "None", "No" for the remaining questions. -```bash -Kohya_SS Installation Script for POSIX operating systems. +### Ubuntu +In the terminal, run -The following options are useful in a runpod environment, -but will not affect a local machine install. - -Usage: - setup.sh -b dev -d /workspace/kohya_ss -g https://mycustom.repo.tld/custom_fork.git - setup.sh --branch=dev --dir=/workspace/kohya_ss --git-repo=https://mycustom.repo.tld/custom_fork.git - -Options: - -b BRANCH, --branch=BRANCH Select which branch of kohya to check out on new installs. - -d DIR, --dir=DIR The full path you want kohya_ss installed to. - -g REPO, --git_repo=REPO You can optionally provide a git repo to check out for runpod installation. Useful for custom forks. - -h, --help Show this screen. - -i, --interactive Interactively configure accelerate instead of using default config file. - -n, --no-update Do not update kohya_ss repo. No git pull or clone operations. - -p, --public Expose public URL in runpod mode. Won't have an effect in other modes. - -r, --runpod Forces a runpod installation. Useful if detection fails for any reason. - -s, --skip-space-check Skip the 10Gb minimum storage space check. - -v, --verbose Increase verbosity levels up to 3. +``` +git clone https://github.com/bmaltais/kohya_ss.git +cd kohya_ss +bash ubuntu_setup.sh ``` -#### Install location - -The default install location for Linux is where the script is located if a previous installation is detected that location. -Otherwise, it will fall to `/opt/kohya_ss`. If /opt is not writeable, the fallback is `$HOME/kohya_ss`. Lastly, if all else fails it will simply install to the current folder you are in (PWD). - -On macOS and other non-Linux machines, it will first try to detect an install where the script is run from and then run setup there if that's detected. -If a previous install isn't found at that location, then it will default install to `$HOME/kohya_ss` followed by where you're currently at if there's no access to $HOME. -You can override this behavior by specifying an install directory with the -d option. - -If you are using the interactive mode, our default values for the accelerate config screen after running the script answer "This machine", "None", "No" for the remaining questions. -These are the same answers as the Windows install. +then configure accelerate with the same answers as in the MacOS instructions when prompted. ### Windows -In the terminal, run: +In the terminal, run ``` git clone https://github.com/bmaltais/kohya_ss.git @@ -115,7 +75,7 @@ cd kohya_ss setup.bat ``` -Then configure accelerate with the same answers as in the MacOS instructions when prompted. +then configure accelerate with the same answers as in the MacOS instructions when prompted. ### Optional: CUDNN 8.6 @@ -133,58 +93,38 @@ Run the following commands to install: python .\tools\cudann_1.8_install.py ``` -Once the commands have completed successfully you should be ready to use the new version. MacOS support is not tested and has been mostly taken from https://gist.github.com/jstayco/9f5733f05b9dc29de95c4056a023d645 +## Upgrading MacOS -## Upgrading +When a new release comes out, you can upgrade your repo with the following commands in the root directory: -The following commands will work from the root directory of the project if you'd prefer to not run scripts. -These commands will work on any OS. ```bash -git pull - -.\venv\Scripts\activate - -pip install --use-pep517 --upgrade -r requirements.txt +upgrade_macos.sh ``` -### Windows Upgrade +Once the commands have completed successfully you should be ready to use the new version. MacOS support is not tested and has been mostly taken from https://gist.github.com/jstayco/9f5733f05b9dc29de95c4056a023d645 + +## Upgrading Windows + When a new release comes out, you can upgrade your repo with the following commands in the root directory: ```powershell upgrade.bat ``` -### Linux and macOS Upgrade -You can cd into the root directory and simply run - -```bash -# Refresh and update everything -./setup.sh - -# This will refresh everything, but NOT clone or pull the git repo. -./setup.sh --no-git-update -``` - Once the commands have completed successfully you should be ready to use the new version. -# Starting GUI Service +## Launching the GUI using gui.bat or gui.ps1 + +The script can be run with several optional command line arguments: -The following command line arguments can be passed to the scripts on any OS to configure the underlying service. -``` --listen: the IP address to listen on for connections to Gradio. ---username: a username for authentication. ---password: a password for authentication. ---server_port: the port to run the server listener on. ---inbrowser: opens the Gradio UI in a web browser. +--username: a username for authentication. +--password: a password for authentication. +--server_port: the port to run the server listener on. +--inbrowser: opens the Gradio UI in a web browser. --share: shares the Gradio UI. -``` -### Launching the GUI on Windows - -The two scripts to launch the GUI on Windows are gui.ps1 and gui.bat in the root directory. -You can use whichever script you prefer. - -To launch the Gradio UI, run the script in a terminal with the desired command line arguments, for example: +These command line arguments can be passed to the UI function as keyword arguments. To launch the Gradio UI, run the script in a terminal with the desired command line arguments, for example: `gui.ps1 --listen 127.0.0.1 --server_port 7860 --inbrowser --share` @@ -192,19 +132,14 @@ or `gui.bat --listen 127.0.0.1 --server_port 7860 --inbrowser --share` -## Launching the GUI on Linux and macOS +## Launching the GUI using kohya_gui.py -Run the launcher script with the desired command line arguments similar to Windows. -`gui.sh --listen 127.0.0.1 --server_port 7860 --inbrowser --share` - -## Launching the GUI directly using kohya_gui.py - -To run the GUI directly bypassing the wrapper scripts, simply use this command from the root project directory: +To run the GUI, simply use this command: ``` .\venv\Scripts\activate -python .\kohya_gui.py +python.exe .\kohya_gui.py ``` ## Dreambooth diff --git a/config_files/accelerate/default_config.yaml b/config_files/accelerate/default_config.yaml deleted file mode 100644 index a31ddd0..0000000 --- a/config_files/accelerate/default_config.yaml +++ /dev/null @@ -1,22 +0,0 @@ -command_file: null -commands: null -compute_environment: LOCAL_MACHINE -deepspeed_config: {} -distributed_type: 'NO' -downcast_bf16: 'no' -dynamo_backend: 'NO' -fsdp_config: {} -gpu_ids: all -machine_rank: 0 -main_process_ip: null -main_process_port: null -main_training_function: main -megatron_lm_config: {} -mixed_precision: 'no' -num_machines: 1 -num_processes: 1 -rdzv_backend: static -same_network: true -tpu_name: null -tpu_zone: null -use_cpu: false diff --git a/gui.sh b/gui.sh index 4fa2e35..e4eca6f 100755 --- a/gui.sh +++ b/gui.sh @@ -1,9 +1,13 @@ -#!/usr/bin/env bash +#!/bin/bash # Activate the virtual environment -source ./venv/bin/activate +source venv/bin/activate -# If the requirements are validated, run the kohya_gui.py script with the command-line arguments -if python tools/validate_requirements.py; then +# Validate the requirements and store the exit code +python tools/validate_requirements.py +exit_code=$? + +# If the exit code is 0, run the kohya_gui.py script with the command-line arguments +if [ $exit_code -eq 0 ]; then python kohya_gui.py "$@" fi diff --git a/gui_macos.sh b/gui_macos.sh new file mode 100755 index 0000000..4a0bfb8 --- /dev/null +++ b/gui_macos.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +# Activate the virtual environment +source venv/bin/activate + +# Validate the requirements and store the exit code +python tools/validate_requirements.py --requirements requirements_macos.txt +exit_code=$? + +# If the exit code is 0, run the kohya_gui.py script with the command-line arguments +if [ $exit_code -eq 0 ]; then + python kohya_gui.py "$@" +fi diff --git a/macos_setup.sh b/macos_setup.sh new file mode 100755 index 0000000..4de8417 --- /dev/null +++ b/macos_setup.sh @@ -0,0 +1,38 @@ +#!/bin/bash +# The initial setup script to prep the environment on macOS +# xformers has been omitted as that is for Nvidia GPUs only + +if ! command -v brew >/dev/null; then + echo "Please install homebrew first. This is a requirement for the remaining setup." + echo "You can find that here: https://brew.sh" + exit 1 +fi + +# Install base python packages +echo "Installing Python 3.10 if not found." +brew ls --versions python@3.10 >/dev/null || brew install python@3.10 +echo "Installing Python-TK 3.10 if not found." +brew ls --versions python-tk@3.10 >/dev/null || brew install python-tk@3.10 + +if command -v python3.10 >/dev/null; then + python3.10 -m venv venv + source venv/bin/activate + + # DEBUG ONLY + #pip install pydevd-pycharm~=223.8836.43 + + # Tensorflow installation + if wget https://github.com/apple/tensorflow_macos/releases/download/v0.1alpha3/tensorflow_macos-0.1a3-cp38-cp38-macosx_11_0_arm64.whl /tmp; then + python -m pip install tensorflow==0.1a3 -f https://github.com/apple/tensorflow_macos/releases/download/v0.1alpha3/tensorflow_macos-0.1a3-cp38-cp38-macosx_11_0_arm64.whl + rm -f /tmp/tensorflow_macos-0.1a3-cp38-cp38-macosx_11_0_arm64.whl + fi + + pip install torch==2.0.0 torchvision==0.15.1 -f https://download.pytorch.org/whl/cpu/torch_stable.html + python -m pip install --use-pep517 --upgrade -r requirements_macos.txt + accelerate config + echo -e "Setup finished! Run ./gui_macos.sh to start." +else + echo "Python not found. Please ensure you install Python." + echo "The brew command for Python 3.10 is: brew install python@3.10" + exit 1 +fi \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 01f6503..5881d6e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,14 +7,12 @@ diffusers[torch]==0.10.2 easygui==0.98.3 einops==0.6.0 ftfy==6.1.1 -gradio==3.19.1; sys_platform != 'darwin' -gradio==3.23.0; sys_platform == 'darwin' +gradio==3.19.1 lion-pytorch==0.0.6 opencv-python==4.7.0.68 pytorch-lightning==1.9.0 safetensors==0.2.6 -tensorboard==2.10.1 ; sys_platform != 'darwin' -tensorboard==2.12.1 ; sys_platform == 'darwin' +tensorboard==2.10.1 tk==0.1.0 toml==0.10.2 transformers==4.26.0 @@ -25,8 +23,7 @@ requests==2.28.2 timm==0.6.12 # tensorflow<2.11 huggingface-hub==0.13.0 -tensorflow==2.10.1; sys_platform != 'darwin' -tensorflow-macos==2.12.0; sys_platform == 'darwin' +tensorflow==2.10.1 # For locon support lycoris-lora @ git+https://github.com/KohakuBlueleaf/LyCORIS.git@c3d925421209a22a60d863ffa3de0b3e7e89f047 # lycoris_lora==0.1.4 diff --git a/requirements_macos.txt b/requirements_macos.txt new file mode 100644 index 0000000..4ee4eec --- /dev/null +++ b/requirements_macos.txt @@ -0,0 +1,32 @@ +accelerate==0.15.0 +albumentations==1.3.0 +altair==4.2.2 +bitsandbytes==0.35.0 +dadaptation==1.5 +diffusers[torch]==0.10.2 +easygui==0.98.3 +einops==0.6.0 +ftfy==6.1.1 +gradio==3.19.1; sys_platform != 'darwin' +gradio==3.23.0; sys_platform == 'darwin' +lion-pytorch==0.0.6 +opencv-python==4.7.0.68 +pytorch-lightning==1.9.0 +safetensors==0.2.6 +tensorboard==2.10.1 +tk==0.1.0 +toml==0.10.2 +transformers==4.26.0 +voluptuous==0.13.1 +# for BLIP captioning +fairscale==0.4.13 +requests==2.28.2 +timm==0.6.12 +# tensorflow<2.11 +huggingface-hub==0.12.0; sys_platform != 'darwin' +huggingface-hub==0.13.0; sys_platform == 'darwin' +tensorflow==2.10.1; sys_platform != 'darwin' +# For locon support +lycoris_lora==0.1.2 +# for kohya_ss library +. \ No newline at end of file diff --git a/setup.sh b/setup.sh deleted file mode 100755 index bc98d6e..0000000 --- a/setup.sh +++ /dev/null @@ -1,578 +0,0 @@ -#!/usr/bin/env bash - -# This file will be the host environment setup file for all operating systems other than base Windows. - -display_help() { - cat <&2" #Don't change anything higher than the maximum verbosity allowed. -done - -for v in $( #From the verbosity level one higher than requested, through the maximum; - seq $((VERBOSITY + 1)) $MAXVERBOSITY -); do - (("$v" > "2")) && eval exec "$v>/dev/null" #Redirect these to bitbucket, provided that they don't match stdout and stderr. -done - -# Example of how to use the verbosity levels. -# printf "%s\n" "This message is seen at verbosity level 1 and above." >&3 -# printf "%s\n" "This message is seen at verbosity level 2 and above." >&4 -# printf "%s\n" "This message is seen at verbosity level 3 and above." >&5 - -# Debug variable dump at max verbosity -echo "BRANCH: $BRANCH -DIR: $DIR -GIT_REPO: $GIT_REPO -INTERACTIVE: $INTERACTIVE -PUBLIC: $PUBLIC -RUNPOD: $RUNPOD -SKIP_SPACE_CHECK: $SKIP_SPACE_CHECK -VERBOSITY: $VERBOSITY -Script directory is ${SCRIPT_DIR}." >&5 - -# This must be set after the getopts loop to account for $DIR changes. -PARENT_DIR="$(dirname "${DIR}")" -VENV_DIR="$DIR/venv" - -# Shared functions -# This checks for free space on the installation drive and returns that in Gb. -size_available() { - local folder - if [ -d "$DIR" ]; then - folder="$DIR" - elif [ -d "$PARENT_DIR" ]; then - folder="$PARENT_DIR" - elif [ -d "$(echo "$DIR" | cut -d "/" -f2)" ]; then - folder="$(echo "$DIR" | cut -d "/" -f2)" - else - echo "We are assuming a root drive install for space-checking purposes." - folder='/' - fi - - local FREESPACEINKB - FREESPACEINKB="$(df -Pk "$folder" | sed 1d | grep -v used | awk '{ print $4 "\t" }')" - echo "Detected available space in Kb: $FREESPACEINKB" >&5 - local FREESPACEINGB - FREESPACEINGB=$((FREESPACEINKB / 1024 / 1024)) - echo "$FREESPACEINGB" -} - -# The expected usage is create_symlinks symlink target_file -create_symlinks() { - echo "Checking symlinks now." - # Next line checks for valid symlink - if [ -L "$1" ]; then - # Check if the linked file exists and points to the expected file - if [ -e "$1" ] && [ "$(readlink "$1")" == "$2" ]; then - echo "$(basename "$1") symlink looks fine. Skipping." - else - if [ -f "$2" ]; then - echo "Broken symlink detected. Recreating $(basename "$1")." - rm "$1" && - ln -s "$2" "$1" - else - echo "$2 does not exist. Nothing to link." - fi - fi - else - echo "Linking $(basename "$1")." - ln -s "$2" "$1" - fi -} - -install_python_dependencies() { - # Switch to local virtual env - echo "Switching to virtual Python environment." - if command -v python3 >/dev/null; then - python3 -m venv "$DIR/venv" - elif command -v python3.10 >/dev/null; then - python3.10 -m venv "$DIR/venv" - else - echo "Valid python3 or python3.10 binary not found." - echo "Cannot proceed with the python steps." - return 1 - fi - - # Activate the virtual environment - source "$DIR/venv/bin/activate" - - # Updating pip if there is one - echo "Checking for pip updates before Python operations." - python3 -m pip install --upgrade pip >&3 - - echo "Installing python dependencies. This could take a few minutes as it downloads files." - echo "If this operation ever runs too long, you can rerun this script in verbose mode to check." - case "$OSTYPE" in - "linux-gnu"*) pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 \ - --extra-index-url https://download.pytorch.org/whl/cu116 >&3 && - pip install -U -I --no-deps \ - https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/downloadlinux/xformers-0.0.14.dev0-cp310-cp310-linux_x86_64.whl >&3 ;; - "darwin"*) pip install torch==2.0.0 torchvision==0.15.1 \ - -f https://download.pytorch.org/whl/cpu/torch_stable.html >&3 ;; - "cygwin") - : - ;; - "msys") - : - ;; - esac - - if [ "$RUNPOD" = true ]; then - echo "Installing tenssort." - pip install tensorrt >&3 - fi - - # DEBUG ONLY (Update this version number to whatever PyCharm recommends) - # pip install pydevd-pycharm~=223.8836.43 - - #This will copy our requirements.txt file out, make the khoya_ss lib a dynamic location then cleanup. - echo "Copying $DIR/requirements.txt to /tmp/requirements_tmp.txt" >&3 - echo "Replacing the . for lib to our DIR variable in tmp/requirements_tmp.txt." >&3 - awk -v dir="$DIR" '/#.*kohya_ss.*library/{print; getline; sub(/^\.$/, dir)}1' "$DIR/requirements.txt" >/tmp/requirements_tmp.txt - if [ $VERBOSITY == 2 ]; then - python -m pip install --quiet --use-pep517 --upgrade -r /tmp/requirements_tmp.txt >&3 - else - python -m pip install --use-pep517 --upgrade -r /tmp/requirements_tmp.txt >&3 - fi - - echo "Removing the temp requirements file." - if [ -f /tmp/requirements_tmp.txt ]; then - rm /tmp/requirements_tmp.txt - fi - - if [ -n "$VIRTUAL_ENV" ]; then - if command -v deactivate >/dev/null; then - echo "Exiting Python virtual environment." - deactivate - else - echo "deactivate command not found. Could still be in the Python virtual environment." - fi - fi -} - -# Attempt to non-interactively install a default accelerate config file unless specified otherwise. -# Documentation for order of precedence locations for configuration file for automated installation: -# https://huggingface.co/docs/accelerate/basic_tutorials/launch#custom-configurations -configure_accelerate() { - echo "Source accelerate config location: $DIR/config_files/accelerate/default_config.yaml" >&3 - if [ "$INTERACTIVE" = true ]; then - accelerate config - else - if env_var_exists HF_HOME; then - if [ ! -f "$HF_HOME/accelerate/default_config.yaml" ]; then - mkdir -p "$HF_HOME/accelerate/" && - echo "Target accelerate config location: $HF_HOME/accelerate/default_config.yaml" >&3 - cp "$DIR/config_files/accelerate/default_config.yaml" "$HF_HOME/accelerate/default_config.yaml" && - echo "Copied accelerate config file to: $HF_HOME/accelerate/default_config.yaml" - fi - elif env_var_exists XDG_CACHE_HOME; then - if [ ! -f "$XDG_CACHE_HOME/huggingface/accelerate" ]; then - mkdir -p "$XDG_CACHE_HOME/huggingface/accelerate" && - echo "Target accelerate config location: $XDG_CACHE_HOME/accelerate/default_config.yaml" >&3 - cp "$DIR/config_files/accelerate/default_config.yaml" "$XDG_CACHE_HOME/huggingface/accelerate/default_config.yaml" && - echo "Copied accelerate config file to: $XDG_CACHE_HOME/huggingface/accelerate/default_config.yaml" - fi - elif env_var_exists HOME; then - if [ ! -f "$HOME/.cache/huggingface/accelerate" ]; then - mkdir -p "$HOME/.cache/huggingface/accelerate" && - echo "Target accelerate config location: $HOME/accelerate/default_config.yaml" >&3 - cp "$DIR/config_files/accelerate/default_config.yaml" "$HOME/.cache/huggingface/accelerate/default_config.yaml" && - echo "Copying accelerate config file to: $HOME/.cache/huggingface/accelerate/default_config.yaml" - fi - else - echo "Could not place the accelerate configuration file. Please configure manually." - sleep 2 - accelerate config - fi - fi -} - -# Offer a warning and opportunity to cancel the installation if < 10Gb of Free Space detected -check_storage_space() { - if [ "$SKIP_SPACE_CHECK" = false ]; then - if [ "$(size_available)" -lt 10 ]; then - echo "You have less than 10Gb of free space. This installation may fail." - MSGTIMEOUT=10 # In seconds - MESSAGE="Continuing in..." - echo "Press control-c to cancel the installation." - for ((i = MSGTIMEOUT; i >= 0; i--)); do - printf "\r${MESSAGE} %ss. " "${i}" - sleep 1 - done - fi - fi -} - -# These are the git operations that will run to update or clone the repo -update_kohya_ss() { - if [ "$SKIP_GIT_UPDATE" = false ]; then - if command -v git >/dev/null; then - # First, we make sure there are no changes that need to be made in git, so no work is lost. - if [ "$(git -C "$DIR" status --porcelain=v1 2>/dev/null | wc -l)" -gt 0 ] && - echo "These files need to be committed or discarded: " >&4 && - git -C "$DIR" status >&4; then - echo "There are changes that need to be committed or discarded in the repo in $DIR." - echo "Commit those changes or run this script with -n to skip git operations entirely." - exit 1 - fi - - echo "Attempting to clone $GIT_REPO." - if [ ! -d "$DIR/.git" ]; then - echo "Cloning and switching to $GIT_REPO:$BRANCH" >*4 - git -C "$DIR" clone -b "$BRANCH" "$GIT_REPO" "$(basename "$DIR")" >&3 - git -C "$DIR" switch "$BRANCH" >&4 - else - echo "git repo detected. Attempting to update repository instead." - echo "Updating: $GIT_REPO" - git -C "$DIR" pull "$GIT_REPO" "$BRANCH" >&3 - if ! git -C "$DIR" switch "$BRANCH" >&4; then - echo "Branch $BRANCH did not exist. Creating it." >&4 - git -C "$DIR" switch -c "$BRANCH" >&4 - fi - fi - else - echo "You need to install git." - echo "Rerun this after installing git or run this script with -n to skip the git operations." - fi - else - echo "Skipping git operations." - fi -} - -# Start OS-specific detection and work -if [[ "$OSTYPE" == "linux-gnu"* ]]; then - # Check if root or sudo - root=false - if [ "$EUID" = 0 ]; then - root=true - elif command -v id >/dev/null && [ "$(id -u)" = 0 ]; then - root=true - elif [ "$UID" = 0 ]; then - root=true - fi - - get_distro_name() { - local line - if [ -f /etc/os-release ]; then - # We search for the line starting with ID= - # Then we remove the ID= prefix to get the name itself - line="$(grep -Ei '^ID=' /etc/os-release)" - echo "Raw detected os-release distro line: $line" >&5 - line=${line##*=} - echo "$line" - return 0 - elif command -v python >/dev/null; then - line="$(python -mplatform)" - echo "$line" - return 0 - elif command -v python3 >/dev/null; then - line="$(python3 -mplatform)" - echo "$line" - return 0 - else - line="None" - echo "$line" - return 1 - fi - } - - # We search for the line starting with ID_LIKE= - # Then we remove the ID_LIKE= prefix to get the name itself - # This is the "type" of distro. For example, Ubuntu returns "debian". - get_distro_family() { - local line - if [ -f /etc/os-release ]; then - if grep -Eiq '^ID_LIKE=' /etc/os-release >/dev/null; then - line="$(grep -Ei '^ID_LIKE=' /etc/os-release)" - echo "Raw detected os-release distro family line: $line" >&5 - line=${line##*=} - echo "$line" - return 0 - else - line="None" - echo "$line" - return 1 - fi - else - line="None" - echo "$line" - return 1 - fi - } - - check_storage_space - - # This is the pre-install work for a kohya installation on a runpod - if [ "$RUNPOD" = true ]; then - if [ -d "$VENV_DIR" ]; then - echo "Pre-existing installation on a runpod detected." - export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:"$VENV_DIR"/lib/python3.10/site-packages/tensorrt/ - export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:"$VENV_DIR"/lib/python3.10/site-packages/nvidia/cuda_runtime/lib/ - cd "$DIR" || exit 1 - fi - fi - - update_kohya_ss - - distro=get_distro_name - family=get_distro_family - echo "Raw detected distro string: $distro" >&4 - echo "Raw detected distro family string: $family" >&4 - - echo "Installing Python TK if not found on the system." - if "$distro" | grep -qi "Ubuntu" || "$family" | grep -qi "Ubuntu"; then - echo "Ubuntu detected." - if [ $(dpkg-query -W -f='${Status}' python3-tk 2>/dev/null | grep -c "ok installed") = 0 ]; then - if [ "$root" = true ]; then - apt update -y >&3 && apt install -y python3-tk >&3 - else - echo "This script needs to be run as root or via sudo to install packages." - exit 1 - fi - else - echo "Python TK found! Skipping install!" - fi - elif "$distro" | grep -Eqi "Fedora|CentOS|Redhat"; then - echo "Redhat or Redhat base detected." - if ! rpm -qa | grep -qi python3-tkinter; then - if [ "$root" = true ]; then - dnf install python3-tkinter -y >&3 - else - echo "This script needs to be run as root or via sudo to install packages." - exit 1 - fi - fi - elif "$distro" | grep -Eqi "arch" || "$family" | grep -qi "arch"; then - echo "Arch Linux or Arch base detected." - if ! pacman -Qi tk >/dev/null; then - if [ "$root" = true ]; then - pacman --noconfirm -S tk >&3 - else - echo "This script needs to be run as root or via sudo to install packages." - exit 1 - fi - fi - elif "$distro" | grep -Eqi "opensuse" || "$family" | grep -qi "opensuse"; then - echo "OpenSUSE detected." - if ! rpm -qa | grep -qi python-tk; then - if [ "$root" = true ]; then - zypper install -y python-tk >&3 - else - echo "This script needs to be run as root or via sudo to install packages." - exit 1 - fi - fi - elif [ "$distro" = "None" ] || [ "$family" = "None" ]; then - if [ "$distro" = "None" ]; then - echo "We could not detect your distribution of Linux. Please file a bug report on github with the contents of your /etc/os-release file." - fi - - if [ "$family" = "None" ]; then - echo "We could not detect the family of your Linux distribution. Please file a bug report on github with the contents of your /etc/os-release file." - fi - fi - - install_python_dependencies - - # We need just a little bit more setup for non-interactive environments - if [ "$RUNPOD" = true ]; then - # Symlink paths - libnvinfer_plugin_symlink="$VENV_DIR/lib/python3.10/site-packages/tensorrt/libnvinfer_plugin.so.7" - libnvinfer_symlink="$VENV_DIR/lib/python3.10/site-packages/tensorrt/libnvinfer.so.7" - libcudart_symlink="$VENV_DIR/lib/python3.10/site-packages/nvidia/cuda_runtime/lib/libcudart.so.11.0" - - #Target file paths - libnvinfer_plugin_target="$VENV_DIR/lib/python3.10/site-packages/tensorrt/libnvinfer_plugin.so.8" - libnvinfer_target="$VENV_DIR/lib/python3.10/site-packages/tensorrt/libnvinfer.so.8" - libcudart_target="$VENV_DIR/lib/python3.10/site-packages/nvidia/cuda_runtime/lib/libcudart.so.12" - - echo "Checking symlinks now." - create_symlinks "$libnvinfer_plugin_symlink" "$libnvinfer_plugin_target" - create_symlinks "$libnvinfer_symlink" "$libnvinfer_target" - create_symlinks "$libcudart_symlink" "$libcudart_target" - - export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$VENV_DIR/lib/python3.10/site-packages/tensorrt/" - export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$VENV_DIR/lib/python3.10/site-packages/nvidia/cuda_runtime/lib/" - - configure_accelerate - - # This is a non-interactive environment, so just directly call gui.sh after all setup steps are complete. - if command -v bash >/dev/null; then - if [ "$PUBLIC" = false ]; then - bash "$DIR"/gui.sh - else - bash "$DIR"/gui.sh --share - fi - else - # This shouldn't happen, but we're going to try to help. - if [ "$PUBLIC" = false ]; then - sh "$DIR"/gui.sh - else - sh "$DIR"/gui.sh --share - fi - fi - fi - - echo -e "Setup finished! Run \e[0;92m./gui.sh\e[0m to start." - echo "Please note if you'd like to expose your public server you need to run ./gui.sh --share" -elif [[ "$OSTYPE" == "darwin"* ]]; then - # The initial setup script to prep the environment on macOS - # xformers has been omitted as that is for Nvidia GPUs only - - if ! command -v brew >/dev/null; then - echo "Please install homebrew first. This is a requirement for the remaining setup." - echo "You can find that here: https://brew.sh" - #shellcheck disable=SC2016 - echo 'The "brew" command should be in $PATH to be detected.' - exit 1 - fi - - check_storage_space - - # Install base python packages - echo "Installing Python 3.10 if not found." - if ! brew ls --versions python@3.10 >/dev/null; then - echo "Installing Python 3.10." - brew install python@3.10 >&3 - else - echo "Python 3.10 found!" - fi - echo "Installing Python-TK 3.10 if not found." - if ! brew ls --versions python-tk@3.10 >/dev/null; then - echo "Installing Python TK 3.10." - brew install python-tk@3.10 >&3 - else - echo "Python Tkinter 3.10 found!" - fi - - update_kohya_ss - - if ! install_python_dependencies; then - echo "You may need to install Python. The command for this is brew install python@3.10." - fi - - configure_accelerate - echo -e "Setup finished! Run ./gui.sh to start." -elif [[ "$OSTYPE" == "cygwin" ]]; then - # Cygwin is a standalone suite of Linux utilies on Windows - echo "This hasn't been validated on cygwin yet." -elif [[ "$OSTYPE" == "msys" ]]; then - # MinGW has the msys environment which is a standalone suite of Linux utilies on Windows - # "git bash" on Windows may also be detected as msys. - echo "This hasn't been validated in msys (mingw) on Windows yet." -fi diff --git a/ubuntu_setup.sh b/ubuntu_setup.sh new file mode 100755 index 0000000..1431155 --- /dev/null +++ b/ubuntu_setup.sh @@ -0,0 +1,12 @@ +#!/bin/bash +echo installing tk +sudo apt install python3-tk +python3 -m venv venv +source venv/bin/activate +pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 +pip install --use-pep517 --upgrade -r requirements.txt +pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/linux/xformers-0.0.14.dev0-cp310-cp310-linux_x86_64.whl + +accelerate config + +echo -e "setup finished! run \e[0;92m./gui.sh\e[0m to start" diff --git a/upgrade.sh b/upgrade.sh new file mode 100755 index 0000000..f01e7b7 --- /dev/null +++ b/upgrade.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +# Check if there are any changes that need to be committed +if [[ -n $(git status --short) ]]; then + echo "There are changes that need to be committed. Please stash or undo your changes before running this script." >&2 + exit 1 +fi + +# Pull the latest changes from the remote repository +git pull + +# Activate the virtual environment +source venv/bin/activate + +# Upgrade the required packages +pip install --upgrade -r requirements.txt diff --git a/upgrade_macos.sh b/upgrade_macos.sh new file mode 100755 index 0000000..2e26c55 --- /dev/null +++ b/upgrade_macos.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +# Check if there are any changes that need to be committed +if [[ -n $(git status --short) ]]; then + echo "There are changes that need to be committed. Please stash or undo your changes before running this script." >&2 + exit 1 +fi + +# Pull the latest changes from the remote repository +git pull + +# Activate the virtual environment +source venv/bin/activate + +# Upgrade the required packages +pip install --upgrade -r requirements_macos.txt From 9069ee26be390d2f7fe0266d55f37df61f5113c1 Mon Sep 17 00:00:00 2001 From: bmaltais Date: Sat, 1 Apr 2023 06:34:03 -0400 Subject: [PATCH 39/44] Revert "Consolidated Install/Launch Scripts and Improve README" From d37aa6efada32dff9bfd83f315eb4df769f85da3 Mon Sep 17 00:00:00 2001 From: bmaltais Date: Sat, 1 Apr 2023 06:59:59 -0400 Subject: [PATCH 40/44] v21.3.9 --- .gitignore | 3 +- README.md | 2 + .../tag_images_by_wd14_tagger_bmaltais.py | 217 ++++++++++++++++++ library/__init__.py | 0 library/wd14_caption_gui.py | 2 +- setup.bat | 8 +- 6 files changed, 228 insertions(+), 4 deletions(-) create mode 100644 finetune/tag_images_by_wd14_tagger_bmaltais.py delete mode 100644 library/__init__.py diff --git a/.gitignore b/.gitignore index 71fe116..bc80c48 100644 --- a/.gitignore +++ b/.gitignore @@ -8,4 +8,5 @@ wd14_tagger_model .DS_Store locon gui-user.bat -gui-user.ps1 \ No newline at end of file +gui-user.ps1 +library/__init__.py diff --git a/README.md b/README.md index 3614d94..7a35f2b 100644 --- a/README.md +++ b/README.md @@ -192,6 +192,8 @@ This will store your a backup file with your current locally installed pip packa ## Change History +* 2023/04/01 (v21.3.9) + - Update how setup is done on Windows by introducing a setup.bat script. This will make it easier to install/re-install on Windows if needed. Many thanks to @missionfloyd for his PR: https://github.com/bmaltais/kohya_ss/pull/496 * 2023/03/30 (v21.3.8) - Fix issue with LyCORIS version not being found: https://github.com/bmaltais/kohya_ss/issues/481 * 2023/03/29 (v21.3.7) diff --git a/finetune/tag_images_by_wd14_tagger_bmaltais.py b/finetune/tag_images_by_wd14_tagger_bmaltais.py new file mode 100644 index 0000000..503d7df --- /dev/null +++ b/finetune/tag_images_by_wd14_tagger_bmaltais.py @@ -0,0 +1,217 @@ +import argparse +import csv +import glob +import os + +from PIL import Image +import cv2 +from tqdm import tqdm +import numpy as np +from tensorflow.keras.models import load_model +from huggingface_hub import hf_hub_download +import torch + +# import library.train_util as train_util + +# from wd14 tagger +IMAGE_SIZE = 448 +IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"] + +# wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2/ wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2 +DEFAULT_WD14_TAGGER_REPO = 'SmilingWolf/wd-v1-4-convnext-tagger-v2' +FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"] +SUB_DIR = "variables" +SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"] +CSV_FILE = FILES[-1] + +def glob_images(directory, base="*"): + img_paths = [] + for ext in IMAGE_EXTENSIONS: + if base == "*": + img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext))) + else: + img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext)))) + img_paths = list(set(img_paths)) # 重複を排除 + img_paths.sort() + return img_paths + +def preprocess_image(image): + image = np.array(image) + image = image[:, :, ::-1] # RGB->BGR + + # pad to square + size = max(image.shape[0:2]) + pad_x = size - image.shape[1] + pad_y = size - image.shape[0] + pad_l = pad_x // 2 + pad_t = pad_y // 2 + image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode='constant', constant_values=255) + + interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4 + image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp) + + image = image.astype(np.float32) + return image + + +class ImageLoadingPrepDataset(torch.utils.data.Dataset): + def __init__(self, image_paths): + self.images = image_paths + + def __len__(self): + return len(self.images) + + def __getitem__(self, idx): + img_path = self.images[idx] + + try: + image = Image.open(img_path).convert("RGB") + image = preprocess_image(image) + tensor = torch.tensor(image) + except Exception as e: + print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") + return None + + return (tensor, img_path) + + +def collate_fn_remove_corrupted(batch): + """Collate function that allows to remove corrupted examples in the + dataloader. It expects that the dataloader returns 'None' when that occurs. + The 'None's in the batch are removed. + """ + # Filter out all the Nones (corrupted examples) + batch = list(filter(lambda x: x is not None, batch)) + return batch + + +def main(args): + # hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする + # depreacatedの警告が出るけどなくなったらその時 + # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22 + if not os.path.exists(args.model_dir) or args.force_download: + print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}") + for file in FILES: + hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file) + for file in SUB_DIR_FILES: + hf_hub_download(args.repo_id, file, subfolder=SUB_DIR, cache_dir=os.path.join( + args.model_dir, SUB_DIR), force_download=True, force_filename=file) + else: + print("using existing wd14 tagger model") + + # 画像を読み込む + image_paths = glob_images(args.train_data_dir) + print(f"found {len(image_paths)} images.") + + print("loading model and labels") + model = load_model(args.model_dir) + + # label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv") + # 依存ライブラリを増やしたくないので自力で読むよ + with open(os.path.join(args.model_dir, CSV_FILE), "r", encoding="utf-8") as f: + reader = csv.reader(f) + l = [row for row in reader] + header = l[0] # tag_id,name,category,count + rows = l[1:] + assert header[0] == 'tag_id' and header[1] == 'name' and header[2] == 'category', f"unexpected csv format: {header}" + + tags = [row[1] for row in rows[1:] if row[2] == '0'] # categoryが0、つまり通常のタグのみ + + # 推論する + def run_batch(path_imgs): + imgs = np.array([im for _, im in path_imgs]) + + probs = model(imgs, training=False) + probs = probs.numpy() + + for (image_path, _), prob in zip(path_imgs, probs): + # 最初の4つはratingなので無視する + # # First 4 labels are actually ratings: pick one with argmax + # ratings_names = label_names[:4] + # rating_index = ratings_names["probs"].argmax() + # found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]] + + # それ以降はタグなのでconfidenceがthresholdより高いものを追加する + # Everything else is tags: pick any where prediction confidence > threshold + tag_text = "" + for i, p in enumerate(prob[4:]): # numpyとか使うのが良いけど、まあそれほど数も多くないのでループで + if p >= args.thresh and i < len(tags): + tag_text += ", " + tags[i] + + if len(tag_text) > 0: + tag_text = tag_text[2:] # 最初の ", " を消す + + with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f: + f.write(tag_text + '\n') + if args.debug: + print(image_path, tag_text) + + # 読み込みの高速化のためにDataLoaderを使うオプション + if args.max_data_loader_n_workers is not None: + dataset = ImageLoadingPrepDataset(image_paths) + data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, + num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False) + else: + data = [[(None, ip)] for ip in image_paths] + + b_imgs = [] + for data_entry in tqdm(data, smoothing=0.0): + for data in data_entry: + if data is None: + continue + + image, image_path = data + if image is not None: + image = image.detach().numpy() + else: + try: + image = Image.open(image_path) + if image.mode != 'RGB': + image = image.convert("RGB") + image = preprocess_image(image) + except Exception as e: + print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") + continue + b_imgs.append((image_path, image)) + + if len(b_imgs) >= args.batch_size: + run_batch(b_imgs) + b_imgs.clear() + + if len(b_imgs) > 0: + run_batch(b_imgs) + + print("done!") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") + parser.add_argument("--repo_id", type=str, default=DEFAULT_WD14_TAGGER_REPO, + help="repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポジトリID") + parser.add_argument("--model_dir", type=str, default="wd14_tagger_model", + help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ") + parser.add_argument("--force_download", action='store_true', + help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします") + parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値") + parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") + parser.add_argument("--max_data_loader_n_workers", type=int, default=None, + help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)") + parser.add_argument("--caption_extention", type=str, default=None, + help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)") + parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子") + parser.add_argument("--debug", action="store_true", help="debug mode") + + return parser + + +if __name__ == '__main__': + parser = setup_parser() + + args = parser.parse_args() + + # スペルミスしていたオプションを復元する + if args.caption_extention is not None: + args.caption_extension = args.caption_extention + + main(args) diff --git a/library/__init__.py b/library/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/library/wd14_caption_gui.py b/library/wd14_caption_gui.py index 1970849..2299feb 100644 --- a/library/wd14_caption_gui.py +++ b/library/wd14_caption_gui.py @@ -33,7 +33,7 @@ def caption_images( return print(f'Captioning files in {train_data_dir}...') - run_cmd = f'accelerate launch "./finetune/tag_images_by_wd14_tagger.py"' + run_cmd = f'accelerate launch "./finetune/tag_images_by_wd14_tagger_bmaltais.py"' run_cmd += f' --batch_size="{int(batch_size)}"' run_cmd += f' --thresh="{thresh}"' run_cmd += f' --caption_extension="{caption_extension}"' diff --git a/setup.bat b/setup.bat index 2c84356..528f1d4 100644 --- a/setup.bat +++ b/setup.bat @@ -1,5 +1,9 @@ @echo off -python -m venv venv +IF NOT EXIST venv ( + python -m venv venv +) ELSE ( + echo venv folder already exists, skipping creation... +) call .\venv\Scripts\activate.bat pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 @@ -10,4 +14,4 @@ copy /y .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\ copy /y .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py copy /y .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py -accelerate config \ No newline at end of file +accelerate config From 2eddd64b90c8b30636f5b1b7f6c934653df277d9 Mon Sep 17 00:00:00 2001 From: bmaltais Date: Sat, 1 Apr 2023 07:14:25 -0400 Subject: [PATCH 41/44] Merge latest sd-script updates --- README.md | 12 + XTI_hijack.py | 209 ++++ gen_img_diffusers.py | 184 +++- library/model_util.py | 1781 ++++++++++++++++---------------- library/train_util.py | 32 +- networks/lora.py | 679 ++++++------ networks/merge_lora.py | 363 ++++--- tools/merge_lycoris.py | 80 ++ train_network.py | 31 +- train_textual_inversion_XTI.py | 644 ++++++++++++ 10 files changed, 2607 insertions(+), 1408 deletions(-) create mode 100644 XTI_hijack.py create mode 100644 tools/merge_lycoris.py create mode 100644 train_textual_inversion_XTI.py diff --git a/README.md b/README.md index 7a35f2b..ae913ad 100644 --- a/README.md +++ b/README.md @@ -192,8 +192,20 @@ This will store your a backup file with your current locally installed pip packa ## Change History +* 2023/04/01 (v21.4.0) + - Fix an issue that `merge_lora.py` does not work with the latest version. + - Fix an issue that `merge_lora.py` does not merge Conv2d3x3 weights. + - Fix an issue that the VRAM usage temporarily increases when loading a model in `train_network.py`. + - Fix an issue that an error occurs when loading a `.safetensors` model in `train_network.py`. [#354](https://github.com/kohya-ss/sd-scripts/issues/354) + - Support [P+](https://prompt-plus.github.io/) training. Thank you jakaline-dev! + - See [#327](https://github.com/kohya-ss/sd-scripts/pull/327) for details. + - Use `train_textual_inversion_XTI.py` for training. The usage is almost the same as `train_textual_inversion.py`. However, sample image generation during training is not supported. + - Use `gen_img_diffusers.py` for image generation (I think Web UI is not supported). Specify the embedding with `--XTI_embeddings` option. + - Reduce RAM usage at startup in `train_network.py`. [#332](https://github.com/kohya-ss/sd-scripts/pull/332) Thank you guaneec! + - Support pre-merge for LoRA in `gen_img_diffusers.py`. Specify `--network_merge` option. Note that the `--am` option of the prompt option is no longer available with this option. * 2023/04/01 (v21.3.9) - Update how setup is done on Windows by introducing a setup.bat script. This will make it easier to install/re-install on Windows if needed. Many thanks to @missionfloyd for his PR: https://github.com/bmaltais/kohya_ss/pull/496 + - Fix issue with WD14 caption script by applying a custom fix to kohya_ss code. * 2023/03/30 (v21.3.8) - Fix issue with LyCORIS version not being found: https://github.com/bmaltais/kohya_ss/issues/481 * 2023/03/29 (v21.3.7) diff --git a/XTI_hijack.py b/XTI_hijack.py new file mode 100644 index 0000000..f39cc8e --- /dev/null +++ b/XTI_hijack.py @@ -0,0 +1,209 @@ +import torch +from typing import Union, List, Optional, Dict, Any, Tuple +from diffusers.models.unet_2d_condition import UNet2DConditionOutput + +def unet_forward_XTI(self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNet2DConditionOutput, Tuple]: + r""" + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + emb = self.time_embedding(t_emb) + + if self.config.num_class_embeds is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + + # 2. pre-process + sample = self.conv_in(sample) + + # 3. down + down_block_res_samples = (sample,) + down_i = 0 + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states[down_i:down_i+2], + ) + down_i += 2 + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 4. mid + sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states[6]) + + # 5. up + up_i = 7 + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states[up_i:up_i+3], + upsample_size=upsample_size, + ) + up_i += 3 + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + # 6. post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample,) + + return UNet2DConditionOutput(sample=sample) + +def downblock_forward_XTI( + self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None +): + output_states = () + i = 0 + + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states[i] + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample + + output_states += (hidden_states,) + i += 1 + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + +def upblock_forward_XTI( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + upsample_size=None, +): + i = 0 + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states[i] + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample + + i += 1 + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states \ No newline at end of file diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 690d111..225de33 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -95,6 +95,8 @@ import library.train_util as train_util import tools.original_control_net as original_control_net from tools.original_control_net import ControlNetInfo +from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI + # Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う TOKENIZER_PATH = "openai/clip-vit-large-patch14" V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う @@ -491,6 +493,9 @@ class PipelineLike: # Textual Inversion self.token_replacements = {} + # XTI + self.token_replacements_XTI = {} + # CLIP guidance self.clip_guidance_scale = clip_guidance_scale self.clip_image_guidance_scale = clip_image_guidance_scale @@ -514,15 +519,26 @@ class PipelineLike: def add_token_replacement(self, target_token_id, rep_token_ids): self.token_replacements[target_token_id] = rep_token_ids - def replace_token(self, tokens): + def replace_token(self, tokens, layer=None): new_tokens = [] for token in tokens: if token in self.token_replacements: - new_tokens.extend(self.token_replacements[token]) + replacer_ = self.token_replacements[token] + if layer: + replacer = [] + for r in replacer_: + if r in self.token_replacements_XTI: + replacer.append(self.token_replacements_XTI[r][layer]) + else: + replacer = replacer_ + new_tokens.extend(replacer) else: new_tokens.append(token) return new_tokens + def add_token_replacement_XTI(self, target_token_id, rep_token_ids): + self.token_replacements_XTI[target_token_id] = rep_token_ids + def set_control_nets(self, ctrl_nets): self.control_nets = ctrl_nets @@ -744,14 +760,15 @@ class PipelineLike: " the batch size of `prompt`." ) - text_embeddings, uncond_embeddings, prompt_tokens = get_weighted_text_embeddings( - pipe=self, - prompt=prompt, - uncond_prompt=negative_prompt if do_classifier_free_guidance else None, - max_embeddings_multiples=max_embeddings_multiples, - clip_skip=self.clip_skip, - **kwargs, - ) + if not self.token_replacements_XTI: + text_embeddings, uncond_embeddings, prompt_tokens = get_weighted_text_embeddings( + pipe=self, + prompt=prompt, + uncond_prompt=negative_prompt if do_classifier_free_guidance else None, + max_embeddings_multiples=max_embeddings_multiples, + clip_skip=self.clip_skip, + **kwargs, + ) if negative_scale is not None: _, real_uncond_embeddings, _ = get_weighted_text_embeddings( @@ -763,11 +780,47 @@ class PipelineLike: **kwargs, ) - if do_classifier_free_guidance: - if negative_scale is None: - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) - else: - text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]) + if self.token_replacements_XTI: + text_embeddings_concat = [] + for layer in [ + "IN01", + "IN02", + "IN04", + "IN05", + "IN07", + "IN08", + "MID", + "OUT03", + "OUT04", + "OUT05", + "OUT06", + "OUT07", + "OUT08", + "OUT09", + "OUT10", + "OUT11", + ]: + text_embeddings, uncond_embeddings, prompt_tokens = get_weighted_text_embeddings( + pipe=self, + prompt=prompt, + uncond_prompt=negative_prompt if do_classifier_free_guidance else None, + max_embeddings_multiples=max_embeddings_multiples, + clip_skip=self.clip_skip, + layer=layer, + **kwargs, + ) + if do_classifier_free_guidance: + if negative_scale is None: + text_embeddings_concat.append(torch.cat([uncond_embeddings, text_embeddings])) + else: + text_embeddings_concat.append(torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings])) + text_embeddings = torch.stack(text_embeddings_concat) + else: + if do_classifier_free_guidance: + if negative_scale is None: + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + else: + text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]) # CLIP guidanceで使用するembeddingsを取得する if self.clip_guidance_scale > 0: @@ -1675,7 +1728,7 @@ def parse_prompt_attention(text): return res -def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length: int): +def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length: int, layer=None): r""" Tokenize a list of prompts and return its tokens with weights of each token. No padding, starting or ending token is included. @@ -1691,7 +1744,7 @@ def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length: # tokenize and discard the starting and the ending token token = pipe.tokenizer(word).input_ids[1:-1] - token = pipe.replace_token(token) + token = pipe.replace_token(token, layer=layer) text_token += token # copy the weight by length of token @@ -1807,6 +1860,7 @@ def get_weighted_text_embeddings( skip_parsing: Optional[bool] = False, skip_weighting: Optional[bool] = False, clip_skip=None, + layer=None, **kwargs, ): r""" @@ -1837,11 +1891,11 @@ def get_weighted_text_embeddings( prompt = [prompt] if not skip_parsing: - prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2) + prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2, layer=layer) if uncond_prompt is not None: if isinstance(uncond_prompt, str): uncond_prompt = [uncond_prompt] - uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2) + uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2, layer=layer) else: prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids] prompt_weights = [[1.0] * len(token) for token in prompt_tokens] @@ -2229,13 +2283,17 @@ def main(args): if network is None: return - network.apply_to(text_encoder, unet) + if not args.network_merge: + network.apply_to(text_encoder, unet) - if args.opt_channels_last: - network.to(memory_format=torch.channels_last) - network.to(dtype).to(device) + if args.opt_channels_last: + network.to(memory_format=torch.channels_last) + network.to(dtype).to(device) + + networks.append(network) + else: + network.merge_to(text_encoder, unet, dtype, device) - networks.append(network) else: networks = [] @@ -2289,6 +2347,11 @@ def main(args): if args.diffusers_xformers: pipe.enable_xformers_memory_efficient_attention() + if args.XTI_embeddings: + diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI + diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D.forward = downblock_forward_XTI + diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D.forward = upblock_forward_XTI + # Textual Inversionを処理する if args.textual_inversion_embeddings: token_ids_embeds = [] @@ -2335,6 +2398,71 @@ def main(args): for token_id, embed in zip(token_ids, embeds): token_embeds[token_id] = embed + if args.XTI_embeddings: + XTI_layers = [ + "IN01", + "IN02", + "IN04", + "IN05", + "IN07", + "IN08", + "MID", + "OUT03", + "OUT04", + "OUT05", + "OUT06", + "OUT07", + "OUT08", + "OUT09", + "OUT10", + "OUT11", + ] + token_ids_embeds_XTI = [] + for embeds_file in args.XTI_embeddings: + if model_util.is_safetensors(embeds_file): + from safetensors.torch import load_file + + data = load_file(embeds_file) + else: + data = torch.load(embeds_file, map_location="cpu") + if set(data.keys()) != set(XTI_layers): + raise ValueError("NOT XTI") + embeds = torch.concat(list(data.values())) + num_vectors_per_token = data["MID"].size()[0] + + token_string = os.path.splitext(os.path.basename(embeds_file))[0] + token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)] + + # add new word to tokenizer, count is num_vectors_per_token + num_added_tokens = tokenizer.add_tokens(token_strings) + assert ( + num_added_tokens == num_vectors_per_token + ), f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}" + + token_ids = tokenizer.convert_tokens_to_ids(token_strings) + print(f"XTI embeddings `{token_string}` loaded. Tokens are added: {token_ids}") + + # if num_vectors_per_token > 1: + pipe.add_token_replacement(token_ids[0], token_ids) + + token_strings_XTI = [] + for layer_name in XTI_layers: + token_strings_XTI += [f"{t}_{layer_name}" for t in token_strings] + tokenizer.add_tokens(token_strings_XTI) + token_ids_XTI = tokenizer.convert_tokens_to_ids(token_strings_XTI) + token_ids_embeds_XTI.append((token_ids_XTI, embeds)) + for t in token_ids: + t_XTI_dic = {} + for i, layer_name in enumerate(XTI_layers): + t_XTI_dic[layer_name] = t + (i + 1) * num_added_tokens + pipe.add_token_replacement_XTI(t, t_XTI_dic) + + text_encoder.resize_token_embeddings(len(tokenizer)) + token_embeds = text_encoder.get_input_embeddings().weight.data + for token_ids, embeds in token_ids_embeds_XTI: + for token_id, embed in zip(token_ids, embeds): + token_embeds[token_id] = embed + # promptを取得する if args.from_file is not None: print(f"reading prompts from {args.from_file}") @@ -2983,6 +3111,7 @@ def setup_parser() -> argparse.ArgumentParser: "--network_args", type=str, default=None, nargs="*", help="additional argmuments for network (key=value) / ネットワークへの追加の引数" ) parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する") + parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする") parser.add_argument( "--textual_inversion_embeddings", type=str, @@ -2990,6 +3119,13 @@ def setup_parser() -> argparse.ArgumentParser: nargs="*", help="Embeddings files of Textual Inversion / Textual Inversionのembeddings", ) + parser.add_argument( + "--XTI_embeddings", + type=str, + default=None, + nargs="*", + help="Embeddings files of Extended Textual Inversion / Extended Textual Inversionのembeddings", + ) parser.add_argument("--clip_skip", type=int, default=None, help="layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う") parser.add_argument( "--max_embeddings_multiples", diff --git a/library/model_util.py b/library/model_util.py index 3d8e753..35b0b6a 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -45,596 +45,574 @@ DIFFUSERS_REF_MODEL_ID_V2 = "stabilityai/stable-diffusion-2-1" def shave_segments(path, n_shave_prefix_segments=1): - """ - Removes segments. Positive values shave the first segments, negative shave the last segments. - """ - if n_shave_prefix_segments >= 0: - return ".".join(path.split(".")[n_shave_prefix_segments:]) - else: - return ".".join(path.split(".")[:n_shave_prefix_segments]) + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. + """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) def renew_resnet_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside resnets to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item.replace("in_layers.0", "norm1") - new_item = new_item.replace("in_layers.2", "conv1") + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") - new_item = new_item.replace("out_layers.0", "norm2") - new_item = new_item.replace("out_layers.3", "conv2") + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") - new_item = new_item.replace("emb_layers.1", "time_emb_proj") - new_item = new_item.replace("skip_connection", "conv_shortcut") + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - mapping.append({"old": old_item, "new": new_item}) + mapping.append({"old": old_item, "new": new_item}) - return mapping + return mapping def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside resnets to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item - new_item = new_item.replace("nin_shortcut", "conv_shortcut") - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + new_item = new_item.replace("nin_shortcut", "conv_shortcut") + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - mapping.append({"old": old_item, "new": new_item}) + mapping.append({"old": old_item, "new": new_item}) - return mapping + return mapping def renew_attention_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside attentions to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item - # new_item = new_item.replace('norm.weight', 'group_norm.weight') - # new_item = new_item.replace('norm.bias', 'group_norm.bias') + # new_item = new_item.replace('norm.weight', 'group_norm.weight') + # new_item = new_item.replace('norm.bias', 'group_norm.bias') - # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') - # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') + # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') + # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') - # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - mapping.append({"old": old_item, "new": new_item}) + mapping.append({"old": old_item, "new": new_item}) - return mapping + return mapping def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside attentions to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item - new_item = new_item.replace("norm.weight", "group_norm.weight") - new_item = new_item.replace("norm.bias", "group_norm.bias") + new_item = new_item.replace("norm.weight", "group_norm.weight") + new_item = new_item.replace("norm.bias", "group_norm.bias") - new_item = new_item.replace("q.weight", "query.weight") - new_item = new_item.replace("q.bias", "query.bias") + new_item = new_item.replace("q.weight", "query.weight") + new_item = new_item.replace("q.bias", "query.bias") - new_item = new_item.replace("k.weight", "key.weight") - new_item = new_item.replace("k.bias", "key.bias") + new_item = new_item.replace("k.weight", "key.weight") + new_item = new_item.replace("k.bias", "key.bias") - new_item = new_item.replace("v.weight", "value.weight") - new_item = new_item.replace("v.bias", "value.bias") + new_item = new_item.replace("v.weight", "value.weight") + new_item = new_item.replace("v.bias", "value.bias") - new_item = new_item.replace("proj_out.weight", "proj_attn.weight") - new_item = new_item.replace("proj_out.bias", "proj_attn.bias") + new_item = new_item.replace("proj_out.weight", "proj_attn.weight") + new_item = new_item.replace("proj_out.bias", "proj_attn.bias") - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - mapping.append({"old": old_item, "new": new_item}) + mapping.append({"old": old_item, "new": new_item}) - return mapping + return mapping def assign_to_checkpoint( paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None ): - """ - This does the final conversion step: take locally converted weights and apply a global renaming - to them. It splits attention layers, and takes into account additional replacements - that may arise. + """ + This does the final conversion step: take locally converted weights and apply a global renaming + to them. It splits attention layers, and takes into account additional replacements + that may arise. - Assigns the weights to the new checkpoint. - """ - assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." - # Splits the attention layers into three variables. - if attention_paths_to_split is not None: - for path, path_map in attention_paths_to_split.items(): - old_tensor = old_checkpoint[path] - channels = old_tensor.shape[0] // 3 + # Splits the attention layers into three variables. + if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 - target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) - num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 - old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) - query, key, value = old_tensor.split(channels // num_heads, dim=1) + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) - checkpoint[path_map["query"]] = query.reshape(target_shape) - checkpoint[path_map["key"]] = key.reshape(target_shape) - checkpoint[path_map["value"]] = value.reshape(target_shape) + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) - for path in paths: - new_path = path["new"] + for path in paths: + new_path = path["new"] - # These have already been assigned - if attention_paths_to_split is not None and new_path in attention_paths_to_split: - continue + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue - # Global renaming happens here - new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") - new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") - new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") + # Global renaming happens here + new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") - if additional_replacements is not None: - for replacement in additional_replacements: - new_path = new_path.replace(replacement["old"], replacement["new"]) + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) - # proj_attn.weight has to be converted from conv 1D to linear - if "proj_attn.weight" in new_path: - checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] - else: - checkpoint[new_path] = old_checkpoint[path["old"]] + # proj_attn.weight has to be converted from conv 1D to linear + if "proj_attn.weight" in new_path: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + else: + checkpoint[new_path] = old_checkpoint[path["old"]] def conv_attn_to_linear(checkpoint): - keys = list(checkpoint.keys()) - attn_keys = ["query.weight", "key.weight", "value.weight"] - for key in keys: - if ".".join(key.split(".")[-2:]) in attn_keys: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0, 0] - elif "proj_attn.weight" in key: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0] + keys = list(checkpoint.keys()) + attn_keys = ["query.weight", "key.weight", "value.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in attn_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + elif "proj_attn.weight" in key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0] def linear_transformer_to_conv(checkpoint): - keys = list(checkpoint.keys()) - tf_keys = ["proj_in.weight", "proj_out.weight"] - for key in keys: - if ".".join(key.split(".")[-2:]) in tf_keys: - if checkpoint[key].ndim == 2: - checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2) + keys = list(checkpoint.keys()) + tf_keys = ["proj_in.weight", "proj_out.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in tf_keys: + if checkpoint[key].ndim == 2: + checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2) def convert_ldm_unet_checkpoint(v2, checkpoint, config): - """ - Takes a state dict and a config, and returns a converted checkpoint. - """ + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ - # extract state_dict for UNet - unet_state_dict = {} - unet_key = "model.diffusion_model." - keys = list(checkpoint.keys()) - for key in keys: - if key.startswith(unet_key): - unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + # extract state_dict for UNet + unet_state_dict = {} + unet_key = "model.diffusion_model." + keys = list(checkpoint.keys()) + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) - new_checkpoint = {} + new_checkpoint = {} - new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] - new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] - new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] - new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] - new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] - new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] - new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] - new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] - new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] - new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] - # Retrieves the keys for the input blocks only - num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) - input_blocks = { - layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] - for layer_id in range(num_input_blocks) - } + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] for layer_id in range(num_input_blocks) + } - # Retrieves the keys for the middle blocks only - num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) - middle_blocks = { - layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] - for layer_id in range(num_middle_blocks) - } + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] for layer_id in range(num_middle_blocks) + } - # Retrieves the keys for the output blocks only - num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) - output_blocks = { - layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key] - for layer_id in range(num_output_blocks) - } + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key] for layer_id in range(num_output_blocks) + } - for i in range(1, num_input_blocks): - block_id = (i - 1) // (config["layers_per_block"] + 1) - layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) - resnets = [ - key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key - ] - attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] - if f"input_blocks.{i}.0.op.weight" in unet_state_dict: - new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( - f"input_blocks.{i}.0.op.weight" - ) - new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( - f"input_blocks.{i}.0.op.bias" - ) + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(f"input_blocks.{i}.0.op.bias") - paths = renew_resnet_paths(resnets) - meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) + paths = renew_resnet_paths(resnets) + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) - if len(attentions): - paths = renew_attention_paths(attentions) - meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} + assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) - resnet_0 = middle_blocks[0] - attentions = middle_blocks[1] - resnet_1 = middle_blocks[2] + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] - resnet_0_paths = renew_resnet_paths(resnet_0) - assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) + resnet_0_paths = renew_resnet_paths(resnet_0) + assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) - resnet_1_paths = renew_resnet_paths(resnet_1) - assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) + resnet_1_paths = renew_resnet_paths(resnet_1) + assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) - attentions_paths = renew_attention_paths(attentions) - meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} - assign_to_checkpoint( - attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) - for i in range(num_output_blocks): - block_id = i // (config["layers_per_block"] + 1) - layer_in_block_id = i % (config["layers_per_block"] + 1) - output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] - output_block_list = {} + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} - for layer in output_block_layers: - layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) - if layer_id in output_block_list: - output_block_list[layer_id].append(layer_name) - else: - output_block_list[layer_id] = [layer_name] + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] - if len(output_block_list) > 1: - resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] - attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] - resnet_0_paths = renew_resnet_paths(resnets) - paths = renew_resnet_paths(resnets) + resnet_0_paths = renew_resnet_paths(resnets) + paths = renew_resnet_paths(resnets) - meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) + meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) - # オリジナル: - # if ["conv.weight", "conv.bias"] in output_block_list.values(): - # index = list(output_block_list.values()).index(["conv.weight", "conv.bias"]) + # オリジナル: + # if ["conv.weight", "conv.bias"] in output_block_list.values(): + # index = list(output_block_list.values()).index(["conv.weight", "conv.bias"]) - # biasとweightの順番に依存しないようにする:もっといいやり方がありそうだが - for l in output_block_list.values(): - l.sort() + # biasとweightの順番に依存しないようにする:もっといいやり方がありそうだが + for l in output_block_list.values(): + l.sort() - if ["conv.bias", "conv.weight"] in output_block_list.values(): - index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) - new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ - f"output_blocks.{i}.{index}.conv.bias" - ] - new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ - f"output_blocks.{i}.{index}.conv.weight" - ] + if ["conv.bias", "conv.weight"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.bias" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.weight" + ] - # Clear attentions as they have been attributed above. - if len(attentions) == 2: - attentions = [] + # Clear attentions as they have been attributed above. + if len(attentions) == 2: + attentions = [] - if len(attentions): - paths = renew_attention_paths(attentions) - meta_path = { - "old": f"output_blocks.{i}.1", - "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", - } - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - else: - resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) - for path in resnet_0_paths: - old_path = ".".join(["output_blocks", str(i), path["old"]]) - new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = { + "old": f"output_blocks.{i}.1", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", + } + assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) + else: + resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) - new_checkpoint[new_path] = unet_state_dict[old_path] + new_checkpoint[new_path] = unet_state_dict[old_path] - # SDのv2では1*1のconv2dがlinearに変わっているので、linear->convに変換する - if v2: - linear_transformer_to_conv(new_checkpoint) + # SDのv2では1*1のconv2dがlinearに変わっているので、linear->convに変換する + if v2: + linear_transformer_to_conv(new_checkpoint) - return new_checkpoint + return new_checkpoint def convert_ldm_vae_checkpoint(checkpoint, config): - # extract state dict for VAE - vae_state_dict = {} - vae_key = "first_stage_model." - keys = list(checkpoint.keys()) - for key in keys: - if key.startswith(vae_key): - vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) - # if len(vae_state_dict) == 0: - # # 渡されたcheckpointは.ckptから読み込んだcheckpointではなくvaeのstate_dict - # vae_state_dict = checkpoint + # extract state dict for VAE + vae_state_dict = {} + vae_key = "first_stage_model." + keys = list(checkpoint.keys()) + for key in keys: + if key.startswith(vae_key): + vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) + # if len(vae_state_dict) == 0: + # # 渡されたcheckpointは.ckptから読み込んだcheckpointではなくvaeのstate_dict + # vae_state_dict = checkpoint - new_checkpoint = {} + new_checkpoint = {} - new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] - new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] - new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] - new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] - new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] - new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] - new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] - new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] - new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] - new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] - new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] - new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] - new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] - new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] - new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] - new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] + new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] - # Retrieves the keys for the encoder down blocks only - num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) - down_blocks = { - layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) - } + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + down_blocks = {layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)} - # Retrieves the keys for the decoder up blocks only - num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) - up_blocks = { - layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) - } + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + up_blocks = {layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)} - for i in range(num_down_blocks): - resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] - if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: - new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( - f"encoder.down.{i}.downsample.conv.weight" - ) - new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( - f"encoder.down.{i}.downsample.conv.bias" - ) + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) - mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] - num_mid_res_blocks = 2 - for i in range(1, num_mid_res_blocks + 1): - resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key] - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] - paths = renew_vae_attention_paths(mid_attentions) - meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - conv_attn_to_linear(new_checkpoint) - - for i in range(num_up_blocks): - block_id = num_up_blocks - 1 - i - resnets = [ - key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key - ] - - if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: - new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ - f"decoder.up.{block_id}.upsample.conv.weight" - ] - new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ - f"decoder.up.{block_id}.upsample.conv.bias" - ] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] - num_mid_res_blocks = 2 - for i in range(1, num_mid_res_blocks + 1): - resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] - paths = renew_vae_attention_paths(mid_attentions) - meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - conv_attn_to_linear(new_checkpoint) - return new_checkpoint + conv_attn_to_linear(new_checkpoint) + return new_checkpoint def create_unet_diffusers_config(v2): - """ - Creates a config for the diffusers based on the config of the LDM model. - """ - # unet_params = original_config.model.params.unet_config.params + """ + Creates a config for the diffusers based on the config of the LDM model. + """ + # unet_params = original_config.model.params.unet_config.params - block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT] + block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT] - down_block_types = [] - resolution = 1 - for i in range(len(block_out_channels)): - block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D" - down_block_types.append(block_type) - if i != len(block_out_channels) - 1: - resolution *= 2 + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D" + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 - up_block_types = [] - for i in range(len(block_out_channels)): - block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D" - up_block_types.append(block_type) - resolution //= 2 + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D" + up_block_types.append(block_type) + resolution //= 2 - config = dict( - sample_size=UNET_PARAMS_IMAGE_SIZE, - in_channels=UNET_PARAMS_IN_CHANNELS, - out_channels=UNET_PARAMS_OUT_CHANNELS, - down_block_types=tuple(down_block_types), - up_block_types=tuple(up_block_types), - block_out_channels=tuple(block_out_channels), - layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS, - cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM, - attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM, - ) + config = dict( + sample_size=UNET_PARAMS_IMAGE_SIZE, + in_channels=UNET_PARAMS_IN_CHANNELS, + out_channels=UNET_PARAMS_OUT_CHANNELS, + down_block_types=tuple(down_block_types), + up_block_types=tuple(up_block_types), + block_out_channels=tuple(block_out_channels), + layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS, + cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM, + attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM, + ) - return config + return config def create_vae_diffusers_config(): - """ - Creates a config for the diffusers based on the config of the LDM model. - """ - # vae_params = original_config.model.params.first_stage_config.params.ddconfig - # _ = original_config.model.params.first_stage_config.params.embed_dim - block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT] - down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) - up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) + """ + Creates a config for the diffusers based on the config of the LDM model. + """ + # vae_params = original_config.model.params.first_stage_config.params.ddconfig + # _ = original_config.model.params.first_stage_config.params.embed_dim + block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT] + down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) + up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) - config = dict( - sample_size=VAE_PARAMS_RESOLUTION, - in_channels=VAE_PARAMS_IN_CHANNELS, - out_channels=VAE_PARAMS_OUT_CH, - down_block_types=tuple(down_block_types), - up_block_types=tuple(up_block_types), - block_out_channels=tuple(block_out_channels), - latent_channels=VAE_PARAMS_Z_CHANNELS, - layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS, - ) - return config + config = dict( + sample_size=VAE_PARAMS_RESOLUTION, + in_channels=VAE_PARAMS_IN_CHANNELS, + out_channels=VAE_PARAMS_OUT_CH, + down_block_types=tuple(down_block_types), + up_block_types=tuple(up_block_types), + block_out_channels=tuple(block_out_channels), + latent_channels=VAE_PARAMS_Z_CHANNELS, + layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS, + ) + return config def convert_ldm_clip_checkpoint_v1(checkpoint): - keys = list(checkpoint.keys()) - text_model_dict = {} - for key in keys: - if key.startswith("cond_stage_model.transformer"): - text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key] - return text_model_dict + keys = list(checkpoint.keys()) + text_model_dict = {} + for key in keys: + if key.startswith("cond_stage_model.transformer"): + text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] + return text_model_dict def convert_ldm_clip_checkpoint_v2(checkpoint, max_length): - # 嫌になるくらい違うぞ! - def convert_key(key): - if not key.startswith("cond_stage_model"): - return None + # 嫌になるくらい違うぞ! + def convert_key(key): + if not key.startswith("cond_stage_model"): + return None - # common conversion - key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.") - key = key.replace("cond_stage_model.model.", "text_model.") + # common conversion + key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.") + key = key.replace("cond_stage_model.model.", "text_model.") - if "resblocks" in key: - # resblocks conversion - key = key.replace(".resblocks.", ".layers.") - if ".ln_" in key: - key = key.replace(".ln_", ".layer_norm") - elif ".mlp." in key: - key = key.replace(".c_fc.", ".fc1.") - key = key.replace(".c_proj.", ".fc2.") - elif '.attn.out_proj' in key: - key = key.replace(".attn.out_proj.", ".self_attn.out_proj.") - elif '.attn.in_proj' in key: - key = None # 特殊なので後で処理する - else: - raise ValueError(f"unexpected key in SD: {key}") - elif '.positional_embedding' in key: - key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight") - elif '.text_projection' in key: - key = None # 使われない??? - elif '.logit_scale' in key: - key = None # 使われない??? - elif '.token_embedding' in key: - key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight") - elif '.ln_final' in key: - key = key.replace(".ln_final", ".final_layer_norm") - return key + if "resblocks" in key: + # resblocks conversion + key = key.replace(".resblocks.", ".layers.") + if ".ln_" in key: + key = key.replace(".ln_", ".layer_norm") + elif ".mlp." in key: + key = key.replace(".c_fc.", ".fc1.") + key = key.replace(".c_proj.", ".fc2.") + elif ".attn.out_proj" in key: + key = key.replace(".attn.out_proj.", ".self_attn.out_proj.") + elif ".attn.in_proj" in key: + key = None # 特殊なので後で処理する + else: + raise ValueError(f"unexpected key in SD: {key}") + elif ".positional_embedding" in key: + key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight") + elif ".text_projection" in key: + key = None # 使われない??? + elif ".logit_scale" in key: + key = None # 使われない??? + elif ".token_embedding" in key: + key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight") + elif ".ln_final" in key: + key = key.replace(".ln_final", ".final_layer_norm") + return key - keys = list(checkpoint.keys()) - new_sd = {} - for key in keys: - # remove resblocks 23 - if '.resblocks.23.' in key: - continue - new_key = convert_key(key) - if new_key is None: - continue - new_sd[new_key] = checkpoint[key] + keys = list(checkpoint.keys()) + new_sd = {} + for key in keys: + # remove resblocks 23 + if ".resblocks.23." in key: + continue + new_key = convert_key(key) + if new_key is None: + continue + new_sd[new_key] = checkpoint[key] - # attnの変換 - for key in keys: - if '.resblocks.23.' in key: - continue - if '.resblocks' in key and '.attn.in_proj_' in key: - # 三つに分割 - values = torch.chunk(checkpoint[key], 3) + # attnの変換 + for key in keys: + if ".resblocks.23." in key: + continue + if ".resblocks" in key and ".attn.in_proj_" in key: + # 三つに分割 + values = torch.chunk(checkpoint[key], 3) - key_suffix = ".weight" if "weight" in key else ".bias" - key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.") - key_pfx = key_pfx.replace("_weight", "") - key_pfx = key_pfx.replace("_bias", "") - key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.") - new_sd[key_pfx + "q_proj" + key_suffix] = values[0] - new_sd[key_pfx + "k_proj" + key_suffix] = values[1] - new_sd[key_pfx + "v_proj" + key_suffix] = values[2] + key_suffix = ".weight" if "weight" in key else ".bias" + key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.") + key_pfx = key_pfx.replace("_weight", "") + key_pfx = key_pfx.replace("_bias", "") + key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.") + new_sd[key_pfx + "q_proj" + key_suffix] = values[0] + new_sd[key_pfx + "k_proj" + key_suffix] = values[1] + new_sd[key_pfx + "v_proj" + key_suffix] = values[2] - # rename or add position_ids - ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids" - if ANOTHER_POSITION_IDS_KEY in new_sd: - # waifu diffusion v1.4 - position_ids = new_sd[ANOTHER_POSITION_IDS_KEY] - del new_sd[ANOTHER_POSITION_IDS_KEY] - else: - position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64) + # rename or add position_ids + ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids" + if ANOTHER_POSITION_IDS_KEY in new_sd: + # waifu diffusion v1.4 + position_ids = new_sd[ANOTHER_POSITION_IDS_KEY] + del new_sd[ANOTHER_POSITION_IDS_KEY] + else: + position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64) + + new_sd["text_model.embeddings.position_ids"] = position_ids + return new_sd - new_sd["text_model.embeddings.position_ids"] = position_ids - return new_sd # endregion @@ -642,547 +620,546 @@ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length): # region Diffusers->StableDiffusion の変換コード # convert_diffusers_to_original_stable_diffusion をコピーして修正している(ASL 2.0) + def conv_transformer_to_linear(checkpoint): - keys = list(checkpoint.keys()) - tf_keys = ["proj_in.weight", "proj_out.weight"] - for key in keys: - if ".".join(key.split(".")[-2:]) in tf_keys: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0, 0] + keys = list(checkpoint.keys()) + tf_keys = ["proj_in.weight", "proj_out.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in tf_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] def convert_unet_state_dict_to_sd(v2, unet_state_dict): - unet_conversion_map = [ - # (stable-diffusion, HF Diffusers) - ("time_embed.0.weight", "time_embedding.linear_1.weight"), - ("time_embed.0.bias", "time_embedding.linear_1.bias"), - ("time_embed.2.weight", "time_embedding.linear_2.weight"), - ("time_embed.2.bias", "time_embedding.linear_2.bias"), - ("input_blocks.0.0.weight", "conv_in.weight"), - ("input_blocks.0.0.bias", "conv_in.bias"), - ("out.0.weight", "conv_norm_out.weight"), - ("out.0.bias", "conv_norm_out.bias"), - ("out.2.weight", "conv_out.weight"), - ("out.2.bias", "conv_out.bias"), - ] + unet_conversion_map = [ + # (stable-diffusion, HF Diffusers) + ("time_embed.0.weight", "time_embedding.linear_1.weight"), + ("time_embed.0.bias", "time_embedding.linear_1.bias"), + ("time_embed.2.weight", "time_embedding.linear_2.weight"), + ("time_embed.2.bias", "time_embedding.linear_2.bias"), + ("input_blocks.0.0.weight", "conv_in.weight"), + ("input_blocks.0.0.bias", "conv_in.bias"), + ("out.0.weight", "conv_norm_out.weight"), + ("out.0.bias", "conv_norm_out.bias"), + ("out.2.weight", "conv_out.weight"), + ("out.2.bias", "conv_out.bias"), + ] - unet_conversion_map_resnet = [ - # (stable-diffusion, HF Diffusers) - ("in_layers.0", "norm1"), - ("in_layers.2", "conv1"), - ("out_layers.0", "norm2"), - ("out_layers.3", "conv2"), - ("emb_layers.1", "time_emb_proj"), - ("skip_connection", "conv_shortcut"), - ] + unet_conversion_map_resnet = [ + # (stable-diffusion, HF Diffusers) + ("in_layers.0", "norm1"), + ("in_layers.2", "conv1"), + ("out_layers.0", "norm2"), + ("out_layers.3", "conv2"), + ("emb_layers.1", "time_emb_proj"), + ("skip_connection", "conv_shortcut"), + ] - unet_conversion_map_layer = [] - for i in range(4): - # loop over downblocks/upblocks + unet_conversion_map_layer = [] + for i in range(4): + # loop over downblocks/upblocks + + for j in range(2): + # loop over resnets/attentions for downblocks + hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." + sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." + unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) + + if i < 3: + # no attention layers in down_blocks.3 + hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." + sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." + unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) + + for j in range(3): + # loop over resnets/attentions for upblocks + hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." + sd_up_res_prefix = f"output_blocks.{3*i + j}.0." + unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) + + if i > 0: + # no attention layers in up_blocks.0 + hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." + sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." + unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) + + if i < 3: + # no downsample in down_blocks.3 + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." + sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." + unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) + + # no upsample in up_blocks.3 + hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." + sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}." + unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) + + hf_mid_atn_prefix = "mid_block.attentions.0." + sd_mid_atn_prefix = "middle_block.1." + unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) for j in range(2): - # loop over resnets/attentions for downblocks - hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." - sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." - unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) + hf_mid_res_prefix = f"mid_block.resnets.{j}." + sd_mid_res_prefix = f"middle_block.{2*j}." + unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) - if i < 3: - # no attention layers in down_blocks.3 - hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." - sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." - unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) + # buyer beware: this is a *brittle* function, + # and correct output requires that all of these pieces interact in + # the exact order in which I have arranged them. + mapping = {k: k for k in unet_state_dict.keys()} + for sd_name, hf_name in unet_conversion_map: + mapping[hf_name] = sd_name + for k, v in mapping.items(): + if "resnets" in k: + for sd_part, hf_part in unet_conversion_map_resnet: + v = v.replace(hf_part, sd_part) + mapping[k] = v + for k, v in mapping.items(): + for sd_part, hf_part in unet_conversion_map_layer: + v = v.replace(hf_part, sd_part) + mapping[k] = v + new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()} - for j in range(3): - # loop over resnets/attentions for upblocks - hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." - sd_up_res_prefix = f"output_blocks.{3*i + j}.0." - unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) + if v2: + conv_transformer_to_linear(new_state_dict) - if i > 0: - # no attention layers in up_blocks.0 - hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." - sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." - unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) - - if i < 3: - # no downsample in down_blocks.3 - hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." - sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." - unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) - - # no upsample in up_blocks.3 - hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." - sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}." - unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) - - hf_mid_atn_prefix = "mid_block.attentions.0." - sd_mid_atn_prefix = "middle_block.1." - unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) - - for j in range(2): - hf_mid_res_prefix = f"mid_block.resnets.{j}." - sd_mid_res_prefix = f"middle_block.{2*j}." - unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) - - # buyer beware: this is a *brittle* function, - # and correct output requires that all of these pieces interact in - # the exact order in which I have arranged them. - mapping = {k: k for k in unet_state_dict.keys()} - for sd_name, hf_name in unet_conversion_map: - mapping[hf_name] = sd_name - for k, v in mapping.items(): - if "resnets" in k: - for sd_part, hf_part in unet_conversion_map_resnet: - v = v.replace(hf_part, sd_part) - mapping[k] = v - for k, v in mapping.items(): - for sd_part, hf_part in unet_conversion_map_layer: - v = v.replace(hf_part, sd_part) - mapping[k] = v - new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()} - - if v2: - conv_transformer_to_linear(new_state_dict) - - return new_state_dict + return new_state_dict # ================# # VAE Conversion # # ================# + def reshape_weight_for_sd(w): # convert HF linear weights to SD conv2d weights - return w.reshape(*w.shape, 1, 1) + return w.reshape(*w.shape, 1, 1) def convert_vae_state_dict(vae_state_dict): - vae_conversion_map = [ - # (stable-diffusion, HF Diffusers) - ("nin_shortcut", "conv_shortcut"), - ("norm_out", "conv_norm_out"), - ("mid.attn_1.", "mid_block.attentions.0."), - ] + vae_conversion_map = [ + # (stable-diffusion, HF Diffusers) + ("nin_shortcut", "conv_shortcut"), + ("norm_out", "conv_norm_out"), + ("mid.attn_1.", "mid_block.attentions.0."), + ] - for i in range(4): - # down_blocks have two resnets - for j in range(2): - hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}." - sd_down_prefix = f"encoder.down.{i}.block.{j}." - vae_conversion_map.append((sd_down_prefix, hf_down_prefix)) + for i in range(4): + # down_blocks have two resnets + for j in range(2): + hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}." + sd_down_prefix = f"encoder.down.{i}.block.{j}." + vae_conversion_map.append((sd_down_prefix, hf_down_prefix)) - if i < 3: - hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0." - sd_downsample_prefix = f"down.{i}.downsample." - vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix)) + if i < 3: + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0." + sd_downsample_prefix = f"down.{i}.downsample." + vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix)) - hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." - sd_upsample_prefix = f"up.{3-i}.upsample." - vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) + hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." + sd_upsample_prefix = f"up.{3-i}.upsample." + vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) - # up_blocks have three resnets - # also, up blocks in hf are numbered in reverse from sd - for j in range(3): - hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}." - sd_up_prefix = f"decoder.up.{3-i}.block.{j}." - vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) + # up_blocks have three resnets + # also, up blocks in hf are numbered in reverse from sd + for j in range(3): + hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}." + sd_up_prefix = f"decoder.up.{3-i}.block.{j}." + vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) - # this part accounts for mid blocks in both the encoder and the decoder - for i in range(2): - hf_mid_res_prefix = f"mid_block.resnets.{i}." - sd_mid_res_prefix = f"mid.block_{i+1}." - vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) + # this part accounts for mid blocks in both the encoder and the decoder + for i in range(2): + hf_mid_res_prefix = f"mid_block.resnets.{i}." + sd_mid_res_prefix = f"mid.block_{i+1}." + vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) - vae_conversion_map_attn = [ - # (stable-diffusion, HF Diffusers) - ("norm.", "group_norm."), - ("q.", "query."), - ("k.", "key."), - ("v.", "value."), - ("proj_out.", "proj_attn."), - ] + vae_conversion_map_attn = [ + # (stable-diffusion, HF Diffusers) + ("norm.", "group_norm."), + ("q.", "query."), + ("k.", "key."), + ("v.", "value."), + ("proj_out.", "proj_attn."), + ] - mapping = {k: k for k in vae_state_dict.keys()} - for k, v in mapping.items(): - for sd_part, hf_part in vae_conversion_map: - v = v.replace(hf_part, sd_part) - mapping[k] = v - for k, v in mapping.items(): - if "attentions" in k: - for sd_part, hf_part in vae_conversion_map_attn: - v = v.replace(hf_part, sd_part) - mapping[k] = v - new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()} - weights_to_convert = ["q", "k", "v", "proj_out"] - for k, v in new_state_dict.items(): - for weight_name in weights_to_convert: - if f"mid.attn_1.{weight_name}.weight" in k: - # print(f"Reshaping {k} for SD format") - new_state_dict[k] = reshape_weight_for_sd(v) + mapping = {k: k for k in vae_state_dict.keys()} + for k, v in mapping.items(): + for sd_part, hf_part in vae_conversion_map: + v = v.replace(hf_part, sd_part) + mapping[k] = v + for k, v in mapping.items(): + if "attentions" in k: + for sd_part, hf_part in vae_conversion_map_attn: + v = v.replace(hf_part, sd_part) + mapping[k] = v + new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()} + weights_to_convert = ["q", "k", "v", "proj_out"] + for k, v in new_state_dict.items(): + for weight_name in weights_to_convert: + if f"mid.attn_1.{weight_name}.weight" in k: + # print(f"Reshaping {k} for SD format") + new_state_dict[k] = reshape_weight_for_sd(v) - return new_state_dict + return new_state_dict # endregion # region 自作のモデル読み書きなど + def is_safetensors(path): - return os.path.splitext(path)[1].lower() == '.safetensors' + return os.path.splitext(path)[1].lower() == ".safetensors" -def load_checkpoint_with_text_encoder_conversion(ckpt_path): - # text encoderの格納形式が違うモデルに対応する ('text_model'がない) - TEXT_ENCODER_KEY_REPLACEMENTS = [ - ('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'), - ('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'), - ('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.') - ] +def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"): + # text encoderの格納形式が違うモデルに対応する ('text_model'がない) + TEXT_ENCODER_KEY_REPLACEMENTS = [ + ("cond_stage_model.transformer.embeddings.", "cond_stage_model.transformer.text_model.embeddings."), + ("cond_stage_model.transformer.encoder.", "cond_stage_model.transformer.text_model.encoder."), + ("cond_stage_model.transformer.final_layer_norm.", "cond_stage_model.transformer.text_model.final_layer_norm."), + ] - if is_safetensors(ckpt_path): - checkpoint = None - state_dict = load_file(ckpt_path, "cpu") - else: - checkpoint = torch.load(ckpt_path, map_location="cpu") - if "state_dict" in checkpoint: - state_dict = checkpoint["state_dict"] + if is_safetensors(ckpt_path): + checkpoint = None + state_dict = load_file(ckpt_path) # , device) # may causes error else: - state_dict = checkpoint - checkpoint = None + checkpoint = torch.load(ckpt_path, map_location=device) + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + else: + state_dict = checkpoint + checkpoint = None - key_reps = [] - for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS: - for key in state_dict.keys(): - if key.startswith(rep_from): - new_key = rep_to + key[len(rep_from):] - key_reps.append((key, new_key)) + key_reps = [] + for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS: + for key in state_dict.keys(): + if key.startswith(rep_from): + new_key = rep_to + key[len(rep_from) :] + key_reps.append((key, new_key)) - for key, new_key in key_reps: - state_dict[new_key] = state_dict[key] - del state_dict[key] + for key, new_key in key_reps: + state_dict[new_key] = state_dict[key] + del state_dict[key] - return checkpoint, state_dict + return checkpoint, state_dict # TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認 -def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None): - _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path) - if dtype is not None: - for k, v in state_dict.items(): - if type(v) is torch.Tensor: - state_dict[k] = v.to(dtype) +def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None): + _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device) - # Convert the UNet2DConditionModel model. - unet_config = create_unet_diffusers_config(v2) - converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config) + # Convert the UNet2DConditionModel model. + unet_config = create_unet_diffusers_config(v2) + converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config) - unet = UNet2DConditionModel(**unet_config) - info = unet.load_state_dict(converted_unet_checkpoint) - print("loading u-net:", info) + unet = UNet2DConditionModel(**unet_config).to(device) + info = unet.load_state_dict(converted_unet_checkpoint) + print("loading u-net:", info) - # Convert the VAE model. - vae_config = create_vae_diffusers_config() - converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config) + # Convert the VAE model. + vae_config = create_vae_diffusers_config() + converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config) - vae = AutoencoderKL(**vae_config) - info = vae.load_state_dict(converted_vae_checkpoint) - print("loading vae:", info) + vae = AutoencoderKL(**vae_config).to(device) + info = vae.load_state_dict(converted_vae_checkpoint) + print("loading vae:", info) - # convert text_model - if v2: - converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77) - cfg = CLIPTextConfig( - vocab_size=49408, - hidden_size=1024, - intermediate_size=4096, - num_hidden_layers=23, - num_attention_heads=16, - max_position_embeddings=77, - hidden_act="gelu", - layer_norm_eps=1e-05, - dropout=0.0, - attention_dropout=0.0, - initializer_range=0.02, - initializer_factor=1.0, - pad_token_id=1, - bos_token_id=0, - eos_token_id=2, - model_type="clip_text_model", - projection_dim=512, - torch_dtype="float32", - transformers_version="4.25.0.dev0", - ) - text_model = CLIPTextModel._from_config(cfg) - info = text_model.load_state_dict(converted_text_encoder_checkpoint) - else: - converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict) + # convert text_model + if v2: + converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77) + cfg = CLIPTextConfig( + vocab_size=49408, + hidden_size=1024, + intermediate_size=4096, + num_hidden_layers=23, + num_attention_heads=16, + max_position_embeddings=77, + hidden_act="gelu", + layer_norm_eps=1e-05, + dropout=0.0, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + model_type="clip_text_model", + projection_dim=512, + torch_dtype="float32", + transformers_version="4.25.0.dev0", + ) + text_model = CLIPTextModel._from_config(cfg) + info = text_model.load_state_dict(converted_text_encoder_checkpoint) + else: + converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict) - logging.set_verbosity_error() # don't show annoying warning - text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") - logging.set_verbosity_warning() + logging.set_verbosity_error() # don't show annoying warning + text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device) + logging.set_verbosity_warning() - info = text_model.load_state_dict(converted_text_encoder_checkpoint) - print("loading text encoder:", info) + info = text_model.load_state_dict(converted_text_encoder_checkpoint) + print("loading text encoder:", info) - return text_model, vae, unet + return text_model, vae, unet def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False): - def convert_key(key): - # position_idsの除去 - if ".position_ids" in key: - return None + def convert_key(key): + # position_idsの除去 + if ".position_ids" in key: + return None - # common - key = key.replace("text_model.encoder.", "transformer.") - key = key.replace("text_model.", "") - if "layers" in key: - # resblocks conversion - key = key.replace(".layers.", ".resblocks.") - if ".layer_norm" in key: - key = key.replace(".layer_norm", ".ln_") - elif ".mlp." in key: - key = key.replace(".fc1.", ".c_fc.") - key = key.replace(".fc2.", ".c_proj.") - elif '.self_attn.out_proj' in key: - key = key.replace(".self_attn.out_proj.", ".attn.out_proj.") - elif '.self_attn.' in key: - key = None # 特殊なので後で処理する - else: - raise ValueError(f"unexpected key in DiffUsers model: {key}") - elif '.position_embedding' in key: - key = key.replace("embeddings.position_embedding.weight", "positional_embedding") - elif '.token_embedding' in key: - key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight") - elif 'final_layer_norm' in key: - key = key.replace("final_layer_norm", "ln_final") - return key + # common + key = key.replace("text_model.encoder.", "transformer.") + key = key.replace("text_model.", "") + if "layers" in key: + # resblocks conversion + key = key.replace(".layers.", ".resblocks.") + if ".layer_norm" in key: + key = key.replace(".layer_norm", ".ln_") + elif ".mlp." in key: + key = key.replace(".fc1.", ".c_fc.") + key = key.replace(".fc2.", ".c_proj.") + elif ".self_attn.out_proj" in key: + key = key.replace(".self_attn.out_proj.", ".attn.out_proj.") + elif ".self_attn." in key: + key = None # 特殊なので後で処理する + else: + raise ValueError(f"unexpected key in DiffUsers model: {key}") + elif ".position_embedding" in key: + key = key.replace("embeddings.position_embedding.weight", "positional_embedding") + elif ".token_embedding" in key: + key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight") + elif "final_layer_norm" in key: + key = key.replace("final_layer_norm", "ln_final") + return key - keys = list(checkpoint.keys()) - new_sd = {} - for key in keys: - new_key = convert_key(key) - if new_key is None: - continue - new_sd[new_key] = checkpoint[key] - - # attnの変換 - for key in keys: - if 'layers' in key and 'q_proj' in key: - # 三つを結合 - key_q = key - key_k = key.replace("q_proj", "k_proj") - key_v = key.replace("q_proj", "v_proj") - - value_q = checkpoint[key_q] - value_k = checkpoint[key_k] - value_v = checkpoint[key_v] - value = torch.cat([value_q, value_k, value_v]) - - new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.") - new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_") - new_sd[new_key] = value - - # 最後の層などを捏造するか - if make_dummy_weights: - print("make dummy weights for resblock.23, text_projection and logit scale.") - keys = list(new_sd.keys()) + keys = list(checkpoint.keys()) + new_sd = {} for key in keys: - if key.startswith("transformer.resblocks.22."): - new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # copyしないとsafetensorsの保存で落ちる + new_key = convert_key(key) + if new_key is None: + continue + new_sd[new_key] = checkpoint[key] - # Diffusersに含まれない重みを作っておく - new_sd['text_projection'] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device) - new_sd['logit_scale'] = torch.tensor(1) + # attnの変換 + for key in keys: + if "layers" in key and "q_proj" in key: + # 三つを結合 + key_q = key + key_k = key.replace("q_proj", "k_proj") + key_v = key.replace("q_proj", "v_proj") - return new_sd + value_q = checkpoint[key_q] + value_k = checkpoint[key_k] + value_v = checkpoint[key_v] + value = torch.cat([value_q, value_k, value_v]) + + new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.") + new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_") + new_sd[new_key] = value + + # 最後の層などを捏造するか + if make_dummy_weights: + print("make dummy weights for resblock.23, text_projection and logit scale.") + keys = list(new_sd.keys()) + for key in keys: + if key.startswith("transformer.resblocks.22."): + new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # copyしないとsafetensorsの保存で落ちる + + # Diffusersに含まれない重みを作っておく + new_sd["text_projection"] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device) + new_sd["logit_scale"] = torch.tensor(1) + + return new_sd def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None): - if ckpt_path is not None: - # epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む - checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path) - if checkpoint is None: # safetensors または state_dictのckpt - checkpoint = {} - strict = False + if ckpt_path is not None: + # epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む + checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path) + if checkpoint is None: # safetensors または state_dictのckpt + checkpoint = {} + strict = False + else: + strict = True + if "state_dict" in state_dict: + del state_dict["state_dict"] else: - strict = True - if "state_dict" in state_dict: - del state_dict["state_dict"] - else: - # 新しく作る - assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint" - checkpoint = {} - state_dict = {} - strict = False + # 新しく作る + assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint" + checkpoint = {} + state_dict = {} + strict = False - def update_sd(prefix, sd): - for k, v in sd.items(): - key = prefix + k - assert not strict or key in state_dict, f"Illegal key in save SD: {key}" - if save_dtype is not None: - v = v.detach().clone().to("cpu").to(save_dtype) - state_dict[key] = v + def update_sd(prefix, sd): + for k, v in sd.items(): + key = prefix + k + assert not strict or key in state_dict, f"Illegal key in save SD: {key}" + if save_dtype is not None: + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v - # Convert the UNet model - unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict()) - update_sd("model.diffusion_model.", unet_state_dict) + # Convert the UNet model + unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict()) + update_sd("model.diffusion_model.", unet_state_dict) - # Convert the text encoder model - if v2: - make_dummy = ckpt_path is None # 参照元のcheckpointがない場合は最後の層を前の層から複製して作るなどダミーの重みを入れる - text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy) - update_sd("cond_stage_model.model.", text_enc_dict) - else: - text_enc_dict = text_encoder.state_dict() - update_sd("cond_stage_model.transformer.", text_enc_dict) + # Convert the text encoder model + if v2: + make_dummy = ckpt_path is None # 参照元のcheckpointがない場合は最後の層を前の層から複製して作るなどダミーの重みを入れる + text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy) + update_sd("cond_stage_model.model.", text_enc_dict) + else: + text_enc_dict = text_encoder.state_dict() + update_sd("cond_stage_model.transformer.", text_enc_dict) - # Convert the VAE - if vae is not None: - vae_dict = convert_vae_state_dict(vae.state_dict()) - update_sd("first_stage_model.", vae_dict) + # Convert the VAE + if vae is not None: + vae_dict = convert_vae_state_dict(vae.state_dict()) + update_sd("first_stage_model.", vae_dict) - # Put together new checkpoint - key_count = len(state_dict.keys()) - new_ckpt = {'state_dict': state_dict} + # Put together new checkpoint + key_count = len(state_dict.keys()) + new_ckpt = {"state_dict": state_dict} - # epoch and global_step are sometimes not int - try: - if 'epoch' in checkpoint: - epochs += checkpoint['epoch'] - if 'global_step' in checkpoint: - steps += checkpoint['global_step'] - except: - pass + # epoch and global_step are sometimes not int + try: + if "epoch" in checkpoint: + epochs += checkpoint["epoch"] + if "global_step" in checkpoint: + steps += checkpoint["global_step"] + except: + pass - new_ckpt['epoch'] = epochs - new_ckpt['global_step'] = steps + new_ckpt["epoch"] = epochs + new_ckpt["global_step"] = steps - if is_safetensors(output_file): - # TODO Tensor以外のdictの値を削除したほうがいいか - save_file(state_dict, output_file) - else: - torch.save(new_ckpt, output_file) + if is_safetensors(output_file): + # TODO Tensor以外のdictの値を削除したほうがいいか + save_file(state_dict, output_file) + else: + torch.save(new_ckpt, output_file) - return key_count + return key_count def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False): - if pretrained_model_name_or_path is None: - # load default settings for v1/v2 - if v2: - pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2 - else: - pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1 + if pretrained_model_name_or_path is None: + # load default settings for v1/v2 + if v2: + pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2 + else: + pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1 - scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler") - tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer") - if vae is None: - vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae") + scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler") + tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer") + if vae is None: + vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae") - pipeline = StableDiffusionPipeline( - unet=unet, - text_encoder=text_encoder, - vae=vae, - scheduler=scheduler, - tokenizer=tokenizer, - safety_checker=None, - feature_extractor=None, - requires_safety_checker=None, - ) - pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors) + pipeline = StableDiffusionPipeline( + unet=unet, + text_encoder=text_encoder, + vae=vae, + scheduler=scheduler, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=None, + ) + pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors) VAE_PREFIX = "first_stage_model." def load_vae(vae_id, dtype): - print(f"load VAE: {vae_id}") - if os.path.isdir(vae_id) or not os.path.isfile(vae_id): - # Diffusers local/remote - try: - vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype) - except EnvironmentError as e: - print(f"exception occurs in loading vae: {e}") - print("retry with subfolder='vae'") - vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype) + print(f"load VAE: {vae_id}") + if os.path.isdir(vae_id) or not os.path.isfile(vae_id): + # Diffusers local/remote + try: + vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype) + except EnvironmentError as e: + print(f"exception occurs in loading vae: {e}") + print("retry with subfolder='vae'") + vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype) + return vae + + # local + vae_config = create_vae_diffusers_config() + + if vae_id.endswith(".bin"): + # SD 1.5 VAE on Huggingface + converted_vae_checkpoint = torch.load(vae_id, map_location="cpu") + else: + # StableDiffusion + vae_model = load_file(vae_id, "cpu") if is_safetensors(vae_id) else torch.load(vae_id, map_location="cpu") + vae_sd = vae_model["state_dict"] if "state_dict" in vae_model else vae_model + + # vae only or full model + full_model = False + for vae_key in vae_sd: + if vae_key.startswith(VAE_PREFIX): + full_model = True + break + if not full_model: + sd = {} + for key, value in vae_sd.items(): + sd[VAE_PREFIX + key] = value + vae_sd = sd + del sd + + # Convert the VAE model. + converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config) + + vae = AutoencoderKL(**vae_config) + vae.load_state_dict(converted_vae_checkpoint) return vae - # local - vae_config = create_vae_diffusers_config() - - if vae_id.endswith(".bin"): - # SD 1.5 VAE on Huggingface - converted_vae_checkpoint = torch.load(vae_id, map_location="cpu") - else: - # StableDiffusion - vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id) - else torch.load(vae_id, map_location="cpu")) - vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model - - # vae only or full model - full_model = False - for vae_key in vae_sd: - if vae_key.startswith(VAE_PREFIX): - full_model = True - break - if not full_model: - sd = {} - for key, value in vae_sd.items(): - sd[VAE_PREFIX + key] = value - vae_sd = sd - del sd - - # Convert the VAE model. - converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config) - - vae = AutoencoderKL(**vae_config) - vae.load_state_dict(converted_vae_checkpoint) - return vae # endregion def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64): - max_width, max_height = max_reso - max_area = (max_width // divisible) * (max_height // divisible) + max_width, max_height = max_reso + max_area = (max_width // divisible) * (max_height // divisible) - resos = set() + resos = set() - size = int(math.sqrt(max_area)) * divisible - resos.add((size, size)) + size = int(math.sqrt(max_area)) * divisible + resos.add((size, size)) - size = min_size - while size <= max_size: - width = size - height = min(max_size, (max_area // (width // divisible)) * divisible) - resos.add((width, height)) - resos.add((height, width)) + size = min_size + while size <= max_size: + width = size + height = min(max_size, (max_area // (width // divisible)) * divisible) + resos.add((width, height)) + resos.add((height, width)) - # # make additional resos - # if width >= height and width - divisible >= min_size: - # resos.add((width - divisible, height)) - # resos.add((height, width - divisible)) - # if height >= width and height - divisible >= min_size: - # resos.add((width, height - divisible)) - # resos.add((height - divisible, width)) + # # make additional resos + # if width >= height and width - divisible >= min_size: + # resos.add((width - divisible, height)) + # resos.add((height, width - divisible)) + # if height >= width and height - divisible >= min_size: + # resos.add((width, height - divisible)) + # resos.add((height - divisible, width)) - size += divisible + size += divisible - resos = list(resos) - resos.sort() - return resos + resos = list(resos) + resos.sort() + return resos -if __name__ == '__main__': - resos = make_bucket_resolutions((512, 768)) - print(len(resos)) - print(resos) - aspect_ratios = [w / h for w, h in resos] - print(aspect_ratios) +if __name__ == "__main__": + resos = make_bucket_resolutions((512, 768)) + print(len(resos)) + print(resos) + aspect_ratios = [w / h for w, h in resos] + print(aspect_ratios) - ars = set() - for ar in aspect_ratios: - if ar in ars: - print("error! duplicate ar:", ar) - ars.add(ar) + ars = set() + for ar in aspect_ratios: + if ar in ars: + print("error! duplicate ar:", ar) + ars.add(ar) diff --git a/library/train_util.py b/library/train_util.py index e1a8e92..59dbc44 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -404,6 +404,8 @@ class BaseDataset(torch.utils.data.Dataset): self.token_padding_disabled = False self.tag_frequency = {} + self.XTI_layers = None + self.token_strings = None self.enable_bucket = False self.bucket_manager: BucketManager = None # not initialized @@ -464,6 +466,10 @@ class BaseDataset(torch.utils.data.Dataset): def disable_token_padding(self): self.token_padding_disabled = True + def enable_XTI(self, layers=None, token_strings=None): + self.XTI_layers = layers + self.token_strings = token_strings + def add_replacement(self, str_from, str_to): self.replacements[str_from] = str_to @@ -909,9 +915,22 @@ class BaseDataset(torch.utils.data.Dataset): latents_list.append(latents) caption = self.process_caption(subset, image_info.caption) - captions.append(caption) + if self.XTI_layers: + caption_layer = [] + for layer in self.XTI_layers: + token_strings_from = " ".join(self.token_strings) + token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings]) + caption_ = caption.replace(token_strings_from, token_strings_to) + caption_layer.append(caption_) + captions.append(caption_layer) + else: + captions.append(caption) if not self.token_padding_disabled: # this option might be omitted in future - input_ids_list.append(self.get_input_ids(caption)) + if self.XTI_layers: + token_caption = self.get_input_ids(caption_layer) + else: + token_caption = self.get_input_ids(caption) + input_ids_list.append(token_caption) example = {} example["loss_weights"] = torch.FloatTensor(loss_weights) @@ -1314,6 +1333,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset): # for dataset in self.datasets: # dataset.make_buckets() + def enable_XTI(self, *args, **kwargs): + for dataset in self.datasets: + dataset.enable_XTI(*args, **kwargs) + def cache_latents(self, vae, vae_batch_size=1): for i, dataset in enumerate(self.datasets): print(f"[Dataset {i}]") @@ -2617,14 +2640,15 @@ def prepare_dtype(args: argparse.Namespace): return weight_dtype, save_dtype -def load_target_model(args: argparse.Namespace, weight_dtype): +def load_target_model(args: argparse.Namespace, weight_dtype, device='cpu'): name_or_path = args.pretrained_model_name_or_path name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers if load_stable_diffusion_format: print("load StableDiffusion checkpoint") - text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path) + text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path, device) else: + # Diffusers model is loaded to CPU print("load Diffusers pretrained models") try: pipe = StableDiffusionPipeline.from_pretrained(name_or_path, tokenizer=None, safety_checker=None) diff --git a/networks/lora.py b/networks/lora.py index 6d3875d..2bf7851 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -13,386 +13,471 @@ from library import train_util class LoRAModule(torch.nn.Module): - """ - replaces forward method of the original Linear, instead of replacing the original Linear module. - """ + """ + replaces forward method of the original Linear, instead of replacing the original Linear module. + """ - def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1): - """ if alpha == 0 or None, alpha is rank (no scaling). """ - super().__init__() - self.lora_name = lora_name + def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1): + """if alpha == 0 or None, alpha is rank (no scaling).""" + super().__init__() + self.lora_name = lora_name - if org_module.__class__.__name__ == 'Conv2d': - in_dim = org_module.in_channels - out_dim = org_module.out_channels - else: - in_dim = org_module.in_features - out_dim = org_module.out_features + if org_module.__class__.__name__ == "Conv2d": + in_dim = org_module.in_channels + out_dim = org_module.out_channels + else: + in_dim = org_module.in_features + out_dim = org_module.out_features - # if limit_rank: - # self.lora_dim = min(lora_dim, in_dim, out_dim) - # if self.lora_dim != lora_dim: - # print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") - # else: - self.lora_dim = lora_dim + # if limit_rank: + # self.lora_dim = min(lora_dim, in_dim, out_dim) + # if self.lora_dim != lora_dim: + # print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") + # else: + self.lora_dim = lora_dim - if org_module.__class__.__name__ == 'Conv2d': - kernel_size = org_module.kernel_size - stride = org_module.stride - padding = org_module.padding - self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) - self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) - else: - self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) - self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) + if org_module.__class__.__name__ == "Conv2d": + kernel_size = org_module.kernel_size + stride = org_module.stride + padding = org_module.padding + self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) + self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) + else: + self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) + self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) - if type(alpha) == torch.Tensor: - alpha = alpha.detach().float().numpy() # without casting, bf16 causes error - alpha = self.lora_dim if alpha is None or alpha == 0 else alpha - self.scale = alpha / self.lora_dim - self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える + if type(alpha) == torch.Tensor: + alpha = alpha.detach().float().numpy() # without casting, bf16 causes error + alpha = self.lora_dim if alpha is None or alpha == 0 else alpha + self.scale = alpha / self.lora_dim + self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える - # same as microsoft's - torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) - torch.nn.init.zeros_(self.lora_up.weight) + # same as microsoft's + torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) + torch.nn.init.zeros_(self.lora_up.weight) - self.multiplier = multiplier - self.org_module = org_module # remove in applying - self.region = None - self.region_mask = None + self.multiplier = multiplier + self.org_module = org_module # remove in applying + self.region = None + self.region_mask = None - def apply_to(self): - self.org_forward = self.org_module.forward - self.org_module.forward = self.forward - del self.org_module + def apply_to(self): + self.org_forward = self.org_module.forward + self.org_module.forward = self.forward + del self.org_module - def set_region(self, region): - self.region = region - self.region_mask = None + def merge_to(self, sd, dtype, device): + # get up/down weight + up_weight = sd["lora_up.weight"].to(torch.float).to(device) + down_weight = sd["lora_down.weight"].to(torch.float).to(device) - def forward(self, x): - if self.region is None: - return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale + # extract weight from org_module + org_sd = self.org_module.state_dict() + weight = org_sd["weight"].to(torch.float) - # regional LoRA FIXME same as additional-network extension - if x.size()[1] % 77 == 0: - # print(f"LoRA for context: {self.lora_name}") - self.region = None - return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale + # merge weight + if len(weight.size()) == 2: + # linear + weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + weight + + self.multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * self.scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # print(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + self.multiplier * conved * self.scale - # calculate region mask first time - if self.region_mask is None: - if len(x.size()) == 4: - h, w = x.size()[2:4] - else: - seq_len = x.size()[1] - ratio = math.sqrt((self.region.size()[0] * self.region.size()[1]) / seq_len) - h = int(self.region.size()[0] / ratio + .5) - w = seq_len // h + # set weight to org_module + org_sd["weight"] = weight.to(dtype) + self.org_module.load_state_dict(org_sd) - r = self.region.to(x.device) - if r.dtype == torch.bfloat16: - r = r.to(torch.float) - r = r.unsqueeze(0).unsqueeze(1) - # print(self.lora_name, self.region.size(), x.size(), r.size(), h, w) - r = torch.nn.functional.interpolate(r, (h, w), mode='bilinear') - r = r.to(x.dtype) + def set_region(self, region): + self.region = region + self.region_mask = None - if len(x.size()) == 3: - r = torch.reshape(r, (1, x.size()[1], -1)) + def forward(self, x): + if self.region is None: + return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale - self.region_mask = r + # regional LoRA FIXME same as additional-network extension + if x.size()[1] % 77 == 0: + # print(f"LoRA for context: {self.lora_name}") + self.region = None + return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale - return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale * self.region_mask + # calculate region mask first time + if self.region_mask is None: + if len(x.size()) == 4: + h, w = x.size()[2:4] + else: + seq_len = x.size()[1] + ratio = math.sqrt((self.region.size()[0] * self.region.size()[1]) / seq_len) + h = int(self.region.size()[0] / ratio + 0.5) + w = seq_len // h + + r = self.region.to(x.device) + if r.dtype == torch.bfloat16: + r = r.to(torch.float) + r = r.unsqueeze(0).unsqueeze(1) + # print(self.lora_name, self.region.size(), x.size(), r.size(), h, w) + r = torch.nn.functional.interpolate(r, (h, w), mode="bilinear") + r = r.to(x.dtype) + + if len(x.size()) == 3: + r = torch.reshape(r, (1, x.size()[1], -1)) + + self.region_mask = r + + return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale * self.region_mask def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs): - if network_dim is None: - network_dim = 4 # default + if network_dim is None: + network_dim = 4 # default - # extract dim/alpha for conv2d, and block dim - conv_dim = kwargs.get('conv_dim', None) - conv_alpha = kwargs.get('conv_alpha', None) - if conv_dim is not None: - conv_dim = int(conv_dim) - if conv_alpha is None: - conv_alpha = 1.0 - else: - conv_alpha = float(conv_alpha) + # extract dim/alpha for conv2d, and block dim + conv_dim = kwargs.get("conv_dim", None) + conv_alpha = kwargs.get("conv_alpha", None) + if conv_dim is not None: + conv_dim = int(conv_dim) + if conv_alpha is None: + conv_alpha = 1.0 + else: + conv_alpha = float(conv_alpha) - """ - block_dims = kwargs.get("block_dims") - block_alphas = None + """ + block_dims = kwargs.get("block_dims") + block_alphas = None - if block_dims is not None: + if block_dims is not None: block_dims = [int(d) for d in block_dims.split(',')] assert len(block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}" block_alphas = kwargs.get("block_alphas") if block_alphas is None: - block_alphas = [1] * len(block_dims) + block_alphas = [1] * len(block_dims) else: - block_alphas = [int(a) for a in block_alphas(',')] + block_alphas = [int(a) for a in block_alphas(',')] assert len(block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}" - conv_block_dims = kwargs.get("conv_block_dims") - conv_block_alphas = None + conv_block_dims = kwargs.get("conv_block_dims") + conv_block_alphas = None - if conv_block_dims is not None: + if conv_block_dims is not None: conv_block_dims = [int(d) for d in conv_block_dims.split(',')] assert len(conv_block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}" conv_block_alphas = kwargs.get("conv_block_alphas") if conv_block_alphas is None: - conv_block_alphas = [1] * len(conv_block_dims) + conv_block_alphas = [1] * len(conv_block_dims) else: - conv_block_alphas = [int(a) for a in conv_block_alphas(',')] + conv_block_alphas = [int(a) for a in conv_block_alphas(',')] assert len(conv_block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}" """ - network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, - alpha=network_alpha, conv_lora_dim=conv_dim, conv_alpha=conv_alpha) - return network + network = LoRANetwork( + text_encoder, + unet, + multiplier=multiplier, + lora_dim=network_dim, + alpha=network_alpha, + conv_lora_dim=conv_dim, + conv_alpha=conv_alpha, + ) + return network def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, **kwargs): - if weights_sd is None: - if os.path.splitext(file)[1] == '.safetensors': - from safetensors.torch import load_file, safe_open - weights_sd = load_file(file) - else: - weights_sd = torch.load(file, map_location='cpu') + if weights_sd is None: + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file, safe_open - # get dim/alpha mapping - modules_dim = {} - modules_alpha = {} - for key, value in weights_sd.items(): - if '.' not in key: - continue + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") - lora_name = key.split('.')[0] - if 'alpha' in key: - modules_alpha[lora_name] = value - elif 'lora_down' in key: - dim = value.size()[0] - modules_dim[lora_name] = dim - # print(lora_name, value.size(), dim) + # get dim/alpha mapping + modules_dim = {} + modules_alpha = {} + for key, value in weights_sd.items(): + if "." not in key: + continue - # support old LoRA without alpha - for key in modules_dim.keys(): - if key not in modules_alpha: - modules_alpha = modules_dim[key] + lora_name = key.split(".")[0] + if "alpha" in key: + modules_alpha[lora_name] = value + elif "lora_down" in key: + dim = value.size()[0] + modules_dim[lora_name] = dim + # print(lora_name, value.size(), dim) - network = LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha) - network.weights_sd = weights_sd - return network + # support old LoRA without alpha + for key in modules_dim.keys(): + if key not in modules_alpha: + modules_alpha = modules_dim[key] + + network = LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha) + network.weights_sd = weights_sd + return network class LoRANetwork(torch.nn.Module): - # is it possible to apply conv_in and conv_out? - UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"] - UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] - TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] - LORA_PREFIX_UNET = 'lora_unet' - LORA_PREFIX_TEXT_ENCODER = 'lora_te' + # is it possible to apply conv_in and conv_out? + UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"] + UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] + LORA_PREFIX_UNET = "lora_unet" + LORA_PREFIX_TEXT_ENCODER = "lora_te" - def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1, conv_lora_dim=None, conv_alpha=None, modules_dim=None, modules_alpha=None) -> None: - super().__init__() - self.multiplier = multiplier + def __init__( + self, + text_encoder, + unet, + multiplier=1.0, + lora_dim=4, + alpha=1, + conv_lora_dim=None, + conv_alpha=None, + modules_dim=None, + modules_alpha=None, + ) -> None: + super().__init__() + self.multiplier = multiplier - self.lora_dim = lora_dim - self.alpha = alpha - self.conv_lora_dim = conv_lora_dim - self.conv_alpha = conv_alpha + self.lora_dim = lora_dim + self.alpha = alpha + self.conv_lora_dim = conv_lora_dim + self.conv_alpha = conv_alpha - if modules_dim is not None: - print(f"create LoRA network from weights") - else: - print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") + if modules_dim is not None: + print(f"create LoRA network from weights") + else: + print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") - self.apply_to_conv2d_3x3 = self.conv_lora_dim is not None - if self.apply_to_conv2d_3x3: - if self.conv_alpha is None: - self.conv_alpha = self.alpha - print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") + self.apply_to_conv2d_3x3 = self.conv_lora_dim is not None + if self.apply_to_conv2d_3x3: + if self.conv_alpha is None: + self.conv_alpha = self.alpha + print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") - # create module instances - def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]: - loras = [] - for name, module in root_module.named_modules(): - if module.__class__.__name__ in target_replace_modules: - # TODO get block index here - for child_name, child_module in module.named_modules(): - is_linear = child_module.__class__.__name__ == "Linear" - is_conv2d = child_module.__class__.__name__ == "Conv2d" - is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) - if is_linear or is_conv2d: - lora_name = prefix + '.' + name + '.' + child_name - lora_name = lora_name.replace('.', '_') + # create module instances + def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]: + loras = [] + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + # TODO get block index here + for child_name, child_module in module.named_modules(): + is_linear = child_module.__class__.__name__ == "Linear" + is_conv2d = child_module.__class__.__name__ == "Conv2d" + is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) + if is_linear or is_conv2d: + lora_name = prefix + "." + name + "." + child_name + lora_name = lora_name.replace(".", "_") - if modules_dim is not None: - if lora_name not in modules_dim: - continue # no LoRA module in this weights file - dim = modules_dim[lora_name] - alpha = modules_alpha[lora_name] - else: - if is_linear or is_conv2d_1x1: - dim = self.lora_dim - alpha = self.alpha - elif self.apply_to_conv2d_3x3: - dim = self.conv_lora_dim - alpha = self.conv_alpha - else: - continue + if modules_dim is not None: + if lora_name not in modules_dim: + continue # no LoRA module in this weights file + dim = modules_dim[lora_name] + alpha = modules_alpha[lora_name] + else: + if is_linear or is_conv2d_1x1: + dim = self.lora_dim + alpha = self.alpha + elif self.apply_to_conv2d_3x3: + dim = self.conv_lora_dim + alpha = self.conv_alpha + else: + continue - lora = LoRAModule(lora_name, child_module, self.multiplier, dim, alpha) - loras.append(lora) - return loras + lora = LoRAModule(lora_name, child_module, self.multiplier, dim, alpha) + loras.append(lora) + return loras - self.text_encoder_loras = create_modules(LoRANetwork.LORA_PREFIX_TEXT_ENCODER, - text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) - print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") + self.text_encoder_loras = create_modules( + LoRANetwork.LORA_PREFIX_TEXT_ENCODER, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE + ) + print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") - # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights - target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE - if modules_dim is not None or self.conv_lora_dim is not None: - target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 + # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights + target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + if modules_dim is not None or self.conv_lora_dim is not None: + target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 - self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, target_modules) - print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, target_modules) + print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") - self.weights_sd = None + self.weights_sd = None - # assertion - names = set() - for lora in self.text_encoder_loras + self.unet_loras: - assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" - names.add(lora.lora_name) + # assertion + names = set() + for lora in self.text_encoder_loras + self.unet_loras: + assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" + names.add(lora.lora_name) - def set_multiplier(self, multiplier): - self.multiplier = multiplier - for lora in self.text_encoder_loras + self.unet_loras: - lora.multiplier = self.multiplier + def set_multiplier(self, multiplier): + self.multiplier = multiplier + for lora in self.text_encoder_loras + self.unet_loras: + lora.multiplier = self.multiplier - def load_weights(self, file): - if os.path.splitext(file)[1] == '.safetensors': - from safetensors.torch import load_file, safe_open - self.weights_sd = load_file(file) - else: - self.weights_sd = torch.load(file, map_location='cpu') + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file, safe_open - def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None): - if self.weights_sd: - weights_has_text_encoder = weights_has_unet = False - for key in self.weights_sd.keys(): - if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER): - weights_has_text_encoder = True - elif key.startswith(LoRANetwork.LORA_PREFIX_UNET): - weights_has_unet = True + self.weights_sd = load_file(file) + else: + self.weights_sd = torch.load(file, map_location="cpu") - if apply_text_encoder is None: - apply_text_encoder = weights_has_text_encoder - else: - assert apply_text_encoder == weights_has_text_encoder, f"text encoder weights: {weights_has_text_encoder} but text encoder flag: {apply_text_encoder} / 重みとText Encoderのフラグが矛盾しています" + def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None): + if self.weights_sd: + weights_has_text_encoder = weights_has_unet = False + for key in self.weights_sd.keys(): + if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER): + weights_has_text_encoder = True + elif key.startswith(LoRANetwork.LORA_PREFIX_UNET): + weights_has_unet = True - if apply_unet is None: - apply_unet = weights_has_unet - else: - assert apply_unet == weights_has_unet, f"u-net weights: {weights_has_unet} but u-net flag: {apply_unet} / 重みとU-Netのフラグが矛盾しています" - else: - assert apply_text_encoder is not None and apply_unet is not None, f"internal error: flag not set" + if apply_text_encoder is None: + apply_text_encoder = weights_has_text_encoder + else: + assert ( + apply_text_encoder == weights_has_text_encoder + ), f"text encoder weights: {weights_has_text_encoder} but text encoder flag: {apply_text_encoder} / 重みとText Encoderのフラグが矛盾しています" - if apply_text_encoder: - print("enable LoRA for text encoder") - else: - self.text_encoder_loras = [] + if apply_unet is None: + apply_unet = weights_has_unet + else: + assert ( + apply_unet == weights_has_unet + ), f"u-net weights: {weights_has_unet} but u-net flag: {apply_unet} / 重みとU-Netのフラグが矛盾しています" + else: + assert apply_text_encoder is not None and apply_unet is not None, f"internal error: flag not set" - if apply_unet: - print("enable LoRA for U-Net") - else: - self.unet_loras = [] + if apply_text_encoder: + print("enable LoRA for text encoder") + else: + self.text_encoder_loras = [] - for lora in self.text_encoder_loras + self.unet_loras: - lora.apply_to() - self.add_module(lora.lora_name, lora) + if apply_unet: + print("enable LoRA for U-Net") + else: + self.unet_loras = [] - if self.weights_sd: - # if some weights are not in state dict, it is ok because initial LoRA does nothing (lora_up is initialized by zeros) - info = self.load_state_dict(self.weights_sd, False) - print(f"weights are loaded: {info}") + for lora in self.text_encoder_loras + self.unet_loras: + lora.apply_to() + self.add_module(lora.lora_name, lora) - def enable_gradient_checkpointing(self): - # not supported - pass + if self.weights_sd: + # if some weights are not in state dict, it is ok because initial LoRA does nothing (lora_up is initialized by zeros) + info = self.load_state_dict(self.weights_sd, False) + print(f"weights are loaded: {info}") - def prepare_optimizer_params(self, text_encoder_lr, unet_lr): - def enumerate_params(loras): - params = [] - for lora in loras: - params.extend(lora.parameters()) - return params + # TODO refactor to common function with apply_to + def merge_to(self, text_encoder, unet, dtype, device): + assert self.weights_sd is not None, "weights are not loaded" - self.requires_grad_(True) - all_params = [] + apply_text_encoder = apply_unet = False + for key in self.weights_sd.keys(): + if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER): + apply_text_encoder = True + elif key.startswith(LoRANetwork.LORA_PREFIX_UNET): + apply_unet = True - if self.text_encoder_loras: - param_data = {'params': enumerate_params(self.text_encoder_loras)} - if text_encoder_lr is not None: - param_data['lr'] = text_encoder_lr - all_params.append(param_data) + if apply_text_encoder: + print("enable LoRA for text encoder") + else: + self.text_encoder_loras = [] - if self.unet_loras: - param_data = {'params': enumerate_params(self.unet_loras)} - if unet_lr is not None: - param_data['lr'] = unet_lr - all_params.append(param_data) + if apply_unet: + print("enable LoRA for U-Net") + else: + self.unet_loras = [] - return all_params + for lora in self.text_encoder_loras + self.unet_loras: + sd_for_lora = {} + for key in self.weights_sd.keys(): + if key.startswith(lora.lora_name): + sd_for_lora[key[len(lora.lora_name) + 1 :]] = self.weights_sd[key] + lora.merge_to(sd_for_lora, dtype, device) + print(f"weights are merged") - def prepare_grad_etc(self, text_encoder, unet): - self.requires_grad_(True) + def enable_gradient_checkpointing(self): + # not supported + pass - def on_epoch_start(self, text_encoder, unet): - self.train() + def prepare_optimizer_params(self, text_encoder_lr, unet_lr): + def enumerate_params(loras): + params = [] + for lora in loras: + params.extend(lora.parameters()) + return params - def get_trainable_params(self): - return self.parameters() + self.requires_grad_(True) + all_params = [] - def save_weights(self, file, dtype, metadata): - if metadata is not None and len(metadata) == 0: - metadata = None + if self.text_encoder_loras: + param_data = {"params": enumerate_params(self.text_encoder_loras)} + if text_encoder_lr is not None: + param_data["lr"] = text_encoder_lr + all_params.append(param_data) - state_dict = self.state_dict() + if self.unet_loras: + param_data = {"params": enumerate_params(self.unet_loras)} + if unet_lr is not None: + param_data["lr"] = unet_lr + all_params.append(param_data) - if dtype is not None: - for key in list(state_dict.keys()): - v = state_dict[key] - v = v.detach().clone().to("cpu").to(dtype) - state_dict[key] = v + return all_params - if os.path.splitext(file)[1] == '.safetensors': - from safetensors.torch import save_file + def prepare_grad_etc(self, text_encoder, unet): + self.requires_grad_(True) - # Precalculate model hashes to save time on indexing - if metadata is None: - metadata = {} - model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) - metadata["sshs_model_hash"] = model_hash - metadata["sshs_legacy_hash"] = legacy_hash + def on_epoch_start(self, text_encoder, unet): + self.train() - save_file(state_dict, file, metadata) - else: - torch.save(state_dict, file) + def get_trainable_params(self): + return self.parameters() - @ staticmethod - def set_regions(networks, image): - image = image.astype(np.float32) / 255.0 - for i, network in enumerate(networks[:3]): - # NOTE: consider averaging overwrapping area - region = image[:, :, i] - if region.max() == 0: - continue - region = torch.tensor(region) - network.set_region(region) + def save_weights(self, file, dtype, metadata): + if metadata is not None and len(metadata) == 0: + metadata = None - def set_region(self, region): - for lora in self.unet_loras: - lora.set_region(region) + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + + # Precalculate model hashes to save time on indexing + if metadata is None: + metadata = {} + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) + + @staticmethod + def set_regions(networks, image): + image = image.astype(np.float32) / 255.0 + for i, network in enumerate(networks[:3]): + # NOTE: consider averaging overwrapping area + region = image[:, :, i] + if region.max() == 0: + continue + region = torch.tensor(region) + network.set_region(region) + + def set_region(self, region): + for lora in self.unet_loras: + lora.set_region(region) diff --git a/networks/merge_lora.py b/networks/merge_lora.py index 8d97392..2fa8861 100644 --- a/networks/merge_lora.py +++ b/networks/merge_lora.py @@ -1,4 +1,3 @@ - import math import argparse import os @@ -9,216 +8,236 @@ import lora def load_state_dict(file_name, dtype): - if os.path.splitext(file_name)[1] == '.safetensors': - sd = load_file(file_name) - else: - sd = torch.load(file_name, map_location='cpu') - for key in list(sd.keys()): - if type(sd[key]) == torch.Tensor: - sd[key] = sd[key].to(dtype) - return sd + if os.path.splitext(file_name)[1] == ".safetensors": + sd = load_file(file_name) + else: + sd = torch.load(file_name, map_location="cpu") + for key in list(sd.keys()): + if type(sd[key]) == torch.Tensor: + sd[key] = sd[key].to(dtype) + return sd def save_to_file(file_name, model, state_dict, dtype): - if dtype is not None: - for key in list(state_dict.keys()): - if type(state_dict[key]) == torch.Tensor: - state_dict[key] = state_dict[key].to(dtype) + if dtype is not None: + for key in list(state_dict.keys()): + if type(state_dict[key]) == torch.Tensor: + state_dict[key] = state_dict[key].to(dtype) - if os.path.splitext(file_name)[1] == '.safetensors': - save_file(model, file_name) - else: - torch.save(model, file_name) + if os.path.splitext(file_name)[1] == ".safetensors": + save_file(model, file_name) + else: + torch.save(model, file_name) def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): - text_encoder.to(merge_dtype) - unet.to(merge_dtype) + text_encoder.to(merge_dtype) + unet.to(merge_dtype) - # create module map - name_to_module = {} - for i, root_module in enumerate([text_encoder, unet]): - if i == 0: - prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER - target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE - else: - prefix = lora.LoRANetwork.LORA_PREFIX_UNET - target_replace_modules = lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE - - for name, module in root_module.named_modules(): - if module.__class__.__name__ in target_replace_modules: - for child_name, child_module in module.named_modules(): - if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d": - lora_name = prefix + '.' + name + '.' + child_name - lora_name = lora_name.replace('.', '_') - name_to_module[lora_name] = child_module - - for model, ratio in zip(models, ratios): - print(f"loading: {model}") - lora_sd = load_state_dict(model, merge_dtype) - - print(f"merging...") - for key in lora_sd.keys(): - if "lora_down" in key: - up_key = key.replace("lora_down", "lora_up") - alpha_key = key[:key.index("lora_down")] + 'alpha' - - # find original module for this lora - module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight" - if module_name not in name_to_module: - print(f"no module found for LoRA weight: {key}") - continue - module = name_to_module[module_name] - # print(f"apply {key} to {module}") - - down_weight = lora_sd[key] - up_weight = lora_sd[up_key] - - dim = down_weight.size()[0] - alpha = lora_sd.get(alpha_key, dim) - scale = alpha / dim - - # W <- W + U * D - weight = module.weight - # print(module_name, down_weight.size(), up_weight.size()) - if len(weight.size()) == 2: - # linear - weight = weight + ratio * (up_weight @ down_weight) * scale - elif down_weight.size()[2:4] == (1, 1): - # conv2d 1x1 - weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2) - ).unsqueeze(2).unsqueeze(3) * scale + # create module map + name_to_module = {} + for i, root_module in enumerate([text_encoder, unet]): + if i == 0: + prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER + target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE else: - # conv2d 3x3 - conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) - # print(conved.size(), weight.size(), module.stride, module.padding) - weight = weight + ratio * conved * scale + prefix = lora.LoRANetwork.LORA_PREFIX_UNET + target_replace_modules = ( + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 + ) - module.weight = torch.nn.Parameter(weight) + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d": + lora_name = prefix + "." + name + "." + child_name + lora_name = lora_name.replace(".", "_") + name_to_module[lora_name] = child_module + + for model, ratio in zip(models, ratios): + print(f"loading: {model}") + lora_sd = load_state_dict(model, merge_dtype) + + print(f"merging...") + for key in lora_sd.keys(): + if "lora_down" in key: + up_key = key.replace("lora_down", "lora_up") + alpha_key = key[: key.index("lora_down")] + "alpha" + + # find original module for this lora + module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight" + if module_name not in name_to_module: + print(f"no module found for LoRA weight: {key}") + continue + module = name_to_module[module_name] + # print(f"apply {key} to {module}") + + down_weight = lora_sd[key] + up_weight = lora_sd[up_key] + + dim = down_weight.size()[0] + alpha = lora_sd.get(alpha_key, dim) + scale = alpha / dim + + # W <- W + U * D + weight = module.weight + # print(module_name, down_weight.size(), up_weight.size()) + if len(weight.size()) == 2: + # linear + weight = weight + ratio * (up_weight @ down_weight) * scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + weight + + ratio + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # print(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + ratio * conved * scale + + module.weight = torch.nn.Parameter(weight) def merge_lora_models(models, ratios, merge_dtype): - base_alphas = {} # alpha for merged model - base_dims = {} + base_alphas = {} # alpha for merged model + base_dims = {} - merged_sd = {} - for model, ratio in zip(models, ratios): - print(f"loading: {model}") - lora_sd = load_state_dict(model, merge_dtype) + merged_sd = {} + for model, ratio in zip(models, ratios): + print(f"loading: {model}") + lora_sd = load_state_dict(model, merge_dtype) - # get alpha and dim - alphas = {} # alpha for current model - dims = {} # dims for current model - for key in lora_sd.keys(): - if 'alpha' in key: - lora_module_name = key[:key.rfind(".alpha")] - alpha = float(lora_sd[key].detach().numpy()) - alphas[lora_module_name] = alpha - if lora_module_name not in base_alphas: - base_alphas[lora_module_name] = alpha - elif "lora_down" in key: - lora_module_name = key[:key.rfind(".lora_down")] - dim = lora_sd[key].size()[0] - dims[lora_module_name] = dim - if lora_module_name not in base_dims: - base_dims[lora_module_name] = dim + # get alpha and dim + alphas = {} # alpha for current model + dims = {} # dims for current model + for key in lora_sd.keys(): + if "alpha" in key: + lora_module_name = key[: key.rfind(".alpha")] + alpha = float(lora_sd[key].detach().numpy()) + alphas[lora_module_name] = alpha + if lora_module_name not in base_alphas: + base_alphas[lora_module_name] = alpha + elif "lora_down" in key: + lora_module_name = key[: key.rfind(".lora_down")] + dim = lora_sd[key].size()[0] + dims[lora_module_name] = dim + if lora_module_name not in base_dims: + base_dims[lora_module_name] = dim - for lora_module_name in dims.keys(): - if lora_module_name not in alphas: - alpha = dims[lora_module_name] - alphas[lora_module_name] = alpha - if lora_module_name not in base_alphas: - base_alphas[lora_module_name] = alpha + for lora_module_name in dims.keys(): + if lora_module_name not in alphas: + alpha = dims[lora_module_name] + alphas[lora_module_name] = alpha + if lora_module_name not in base_alphas: + base_alphas[lora_module_name] = alpha - print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") + print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") - # merge - print(f"merging...") - for key in lora_sd.keys(): - if 'alpha' in key: - continue + # merge + print(f"merging...") + for key in lora_sd.keys(): + if "alpha" in key: + continue - lora_module_name = key[:key.rfind(".lora_")] + lora_module_name = key[: key.rfind(".lora_")] - base_alpha = base_alphas[lora_module_name] - alpha = alphas[lora_module_name] + base_alpha = base_alphas[lora_module_name] + alpha = alphas[lora_module_name] - scale = math.sqrt(alpha / base_alpha) * ratio + scale = math.sqrt(alpha / base_alpha) * ratio - if key in merged_sd: - assert merged_sd[key].size() == lora_sd[key].size( - ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません" - merged_sd[key] = merged_sd[key] + lora_sd[key] * scale - else: - merged_sd[key] = lora_sd[key] * scale + if key in merged_sd: + assert ( + merged_sd[key].size() == lora_sd[key].size() + ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません" + merged_sd[key] = merged_sd[key] + lora_sd[key] * scale + else: + merged_sd[key] = lora_sd[key] * scale - # set alpha to sd - for lora_module_name, alpha in base_alphas.items(): - key = lora_module_name + ".alpha" - merged_sd[key] = torch.tensor(alpha) + # set alpha to sd + for lora_module_name, alpha in base_alphas.items(): + key = lora_module_name + ".alpha" + merged_sd[key] = torch.tensor(alpha) - print("merged model") - print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") + print("merged model") + print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") - return merged_sd + return merged_sd def merge(args): - assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" + assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" - def str_to_dtype(p): - if p == 'float': - return torch.float - if p == 'fp16': - return torch.float16 - if p == 'bf16': - return torch.bfloat16 - return None + def str_to_dtype(p): + if p == "float": + return torch.float + if p == "fp16": + return torch.float16 + if p == "bf16": + return torch.bfloat16 + return None - merge_dtype = str_to_dtype(args.precision) - save_dtype = str_to_dtype(args.save_precision) - if save_dtype is None: - save_dtype = merge_dtype + merge_dtype = str_to_dtype(args.precision) + save_dtype = str_to_dtype(args.save_precision) + if save_dtype is None: + save_dtype = merge_dtype - if args.sd_model is not None: - print(f"loading SD model: {args.sd_model}") + if args.sd_model is not None: + print(f"loading SD model: {args.sd_model}") - text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model) + text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model) - merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype) + merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype) - print(f"saving SD model to: {args.save_to}") - model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet, - args.sd_model, 0, 0, save_dtype, vae) - else: - state_dict = merge_lora_models(args.models, args.ratios, merge_dtype) + print(f"saving SD model to: {args.save_to}") + model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, save_dtype, vae) + else: + state_dict = merge_lora_models(args.models, args.ratios, merge_dtype) - print(f"saving model to: {args.save_to}") - save_to_file(args.save_to, state_dict, state_dict, save_dtype) + print(f"saving model to: {args.save_to}") + save_to_file(args.save_to, state_dict, state_dict, save_dtype) def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - parser.add_argument("--v2", action='store_true', - help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む') - parser.add_argument("--save_precision", type=str, default=None, - choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ") - parser.add_argument("--precision", type=str, default="float", - choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)") - parser.add_argument("--sd_model", type=str, default=None, - help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする") - parser.add_argument("--save_to", type=str, default=None, - help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors") - parser.add_argument("--models", type=str, nargs='*', - help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors") - parser.add_argument("--ratios", type=float, nargs='*', - help="ratios for each model / それぞれのLoRAモデルの比率") + parser = argparse.ArgumentParser() + parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む") + parser.add_argument( + "--save_precision", + type=str, + default=None, + choices=[None, "float", "fp16", "bf16"], + help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ", + ) + parser.add_argument( + "--precision", + type=str, + default="float", + choices=["float", "fp16", "bf16"], + help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)", + ) + parser.add_argument( + "--sd_model", + type=str, + default=None, + help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする", + ) + parser.add_argument( + "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors" + ) + parser.add_argument( + "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors" + ) + parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") - return parser + return parser -if __name__ == '__main__': - parser = setup_parser() +if __name__ == "__main__": + parser = setup_parser() - args = parser.parse_args() - merge(args) + args = parser.parse_args() + merge(args) diff --git a/tools/merge_lycoris.py b/tools/merge_lycoris.py new file mode 100644 index 0000000..570fa2b --- /dev/null +++ b/tools/merge_lycoris.py @@ -0,0 +1,80 @@ +import os +import sys +import argparse +import torch +from lycoris.utils import merge_loha, merge_locon +from lycoris.kohya_model_utils import ( + load_models_from_stable_diffusion_checkpoint, + save_stable_diffusion_checkpoint, + load_file +) +import gradio as gr + + +def merge_models(base_model, lycoris_model, output_name, is_v2, device, dtype, weight): + base = load_models_from_stable_diffusion_checkpoint(is_v2, base_model) + if lycoris_model.rsplit('.', 1)[-1] == 'safetensors': + lyco = load_file(lycoris_model) + else: + lyco = torch.load(lycoris_model) + + algo = None + for key in lyco: + if 'hada' in key: + algo = 'loha' + break + elif 'lora_up' in key: + algo = 'lora' + break + else: + raise NotImplementedError('Cannot find the algo for this lycoris model file.') + + dtype_str = dtype.replace('fp', 'float').replace('bf', 'bfloat') + dtype = { + 'float': torch.float, + 'float16': torch.float16, + 'float32': torch.float32, + 'float64': torch.float64, + 'bfloat': torch.bfloat16, + 'bfloat16': torch.bfloat16, + }.get(dtype_str, None) + if dtype is None: + raise ValueError(f'Cannot Find the dtype "{dtype}"') + + if algo == 'loha': + merge_loha(base, lyco, weight, device) + elif algo == 'lora': + merge_locon(base, lyco, weight, device) + + save_stable_diffusion_checkpoint( + is_v2, output_name, + base[0], base[2], + None, 0, 0, dtype, + base[1] + ) + + return output_name + + +def main(): + iface = gr.Interface( + fn=merge_models, + inputs=[ + gr.inputs.Textbox(label="Base Model Path"), + gr.inputs.Textbox(label="Lycoris Model Path"), + gr.inputs.Textbox(label="Output Model Path", default='./out.pt'), + gr.inputs.Checkbox(label="Is base model SD V2?", default=False), + gr.inputs.Textbox(label="Device", default='cpu'), + gr.inputs.Dropdown(choices=['float', 'float16', 'float32', 'float64', 'bfloat', 'bfloat16'], label="Dtype", default='float'), + gr.inputs.Number(label="Weight", default=1.0) + ], + outputs=gr.outputs.Textbox(label="Merged Model Path"), + title="Model Merger", + description="Merge Lycoris and Stable Diffusion models", + ) + + iface.launch() + + +if __name__ == '__main__': + main() diff --git a/train_network.py b/train_network.py index 423649e..476f76d 100644 --- a/train_network.py +++ b/train_network.py @@ -25,7 +25,7 @@ from library.config_util import ( BlueprintGenerator, ) import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import apply_snr_weight +from library.custom_train_functions import apply_snr_weight # TODO 他のスクリプトと共通化する @@ -127,12 +127,25 @@ def train(args): weight_dtype, save_dtype = train_util.prepare_dtype(args) # モデルを読み込む - text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype) + for pi in range(accelerator.state.num_processes): + # TODO: modify other training scripts as well + if pi == accelerator.state.local_process_index: + print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}") + + text_encoder, vae, unet, _ = train_util.load_target_model( + args, weight_dtype, accelerator.device if args.lowram else "cpu" + ) + + # work on low-ram device + if args.lowram: + text_encoder.to(accelerator.device) + unet.to(accelerator.device) + vae.to(accelerator.device) + + gc.collect() + torch.cuda.empty_cache() + accelerator.wait_for_everyone() - # work on low-ram device - if args.lowram: - text_encoder.to("cuda") - unet.to("cuda") # モデルに xformers とか memory efficient attention を組み込む train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) @@ -189,7 +202,7 @@ def train(args): # dataloaderを準備する # DataLoaderのプロセス数:0はメインプロセスになる n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで - + train_dataloader = torch.utils.data.DataLoader( train_dataset_group, batch_size=1, @@ -556,9 +569,9 @@ def train(args): loss_weights = batch["loss_weights"] # 各sampleごとのweight loss = loss * loss_weights - + if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py new file mode 100644 index 0000000..74e9bc2 --- /dev/null +++ b/train_textual_inversion_XTI.py @@ -0,0 +1,644 @@ +import importlib +import argparse +import gc +import math +import os +import toml +from multiprocessing import Value + +from tqdm import tqdm +import torch +from accelerate.utils import set_seed +import diffusers +from diffusers import DDPMScheduler + +import library.train_util as train_util +import library.config_util as config_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import apply_snr_weight +from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI + +imagenet_templates_small = [ + "a photo of a {}", + "a rendering of a {}", + "a cropped photo of the {}", + "the photo of a {}", + "a photo of a clean {}", + "a photo of a dirty {}", + "a dark photo of the {}", + "a photo of my {}", + "a photo of the cool {}", + "a close-up photo of a {}", + "a bright photo of the {}", + "a cropped photo of a {}", + "a photo of the {}", + "a good photo of the {}", + "a photo of one {}", + "a close-up photo of the {}", + "a rendition of the {}", + "a photo of the clean {}", + "a rendition of a {}", + "a photo of a nice {}", + "a good photo of a {}", + "a photo of the nice {}", + "a photo of the small {}", + "a photo of the weird {}", + "a photo of the large {}", + "a photo of a cool {}", + "a photo of a small {}", +] + +imagenet_style_templates_small = [ + "a painting in the style of {}", + "a rendering in the style of {}", + "a cropped painting in the style of {}", + "the painting in the style of {}", + "a clean painting in the style of {}", + "a dirty painting in the style of {}", + "a dark painting in the style of {}", + "a picture in the style of {}", + "a cool painting in the style of {}", + "a close-up painting in the style of {}", + "a bright painting in the style of {}", + "a cropped painting in the style of {}", + "a good painting in the style of {}", + "a close-up painting in the style of {}", + "a rendition in the style of {}", + "a nice painting in the style of {}", + "a small painting in the style of {}", + "a weird painting in the style of {}", + "a large painting in the style of {}", +] + + +def train(args): + if args.output_name is None: + args.output_name = args.token_string + use_template = args.use_object_template or args.use_style_template + + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + + if args.sample_every_n_steps is not None or args.sample_every_n_epochs is not None: + print( + "sample_every_n_steps and sample_every_n_epochs are not supported in this script currently / sample_every_n_stepsとsample_every_n_epochsは現在このスクリプトではサポートされていません" + ) + + cache_latents = args.cache_latents + + if args.seed is not None: + set_seed(args.seed) + + tokenizer = train_util.load_tokenizer(args) + + # acceleratorを準備する + print("prepare accelerator") + accelerator, unwrap_model = train_util.prepare_accelerator(args) + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + + # モデルを読み込む + text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype) + + # Convert the init_word to token_id + if args.init_word is not None: + init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False) + if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token: + print( + f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}" + ) + else: + init_token_ids = None + + # add new word to tokenizer, count is num_vectors_per_token + token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)] + num_added_tokens = tokenizer.add_tokens(token_strings) + assert ( + num_added_tokens == args.num_vectors_per_token + ), f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}" + + token_ids = tokenizer.convert_tokens_to_ids(token_strings) + print(f"tokens are added: {token_ids}") + assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered" + assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}" + + token_strings_XTI = [] + XTI_layers = [ + "IN01", + "IN02", + "IN04", + "IN05", + "IN07", + "IN08", + "MID", + "OUT03", + "OUT04", + "OUT05", + "OUT06", + "OUT07", + "OUT08", + "OUT09", + "OUT10", + "OUT11", + ] + for layer_name in XTI_layers: + token_strings_XTI += [f"{t}_{layer_name}" for t in token_strings] + + tokenizer.add_tokens(token_strings_XTI) + token_ids_XTI = tokenizer.convert_tokens_to_ids(token_strings_XTI) + print(f"tokens are added (XTI): {token_ids_XTI}") + # Resize the token embeddings as we are adding new special tokens to the tokenizer + text_encoder.resize_token_embeddings(len(tokenizer)) + + # Initialise the newly added placeholder token with the embeddings of the initializer token + token_embeds = text_encoder.get_input_embeddings().weight.data + if init_token_ids is not None: + for i, token_id in enumerate(token_ids_XTI): + token_embeds[token_id] = token_embeds[init_token_ids[(i // 16) % len(init_token_ids)]] + # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) + + # load weights + if args.weights is not None: + embeddings = load_weights(args.weights) + assert len(token_ids) == len( + embeddings + ), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}" + # print(token_ids, embeddings.size()) + for token_id, embedding in zip(token_ids_XTI, embeddings): + token_embeds[token_id] = embedding + # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) + print(f"weighs loaded") + + print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}") + + # データセットを準備する + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False)) + if args.dataset_config is not None: + print(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "reg_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + print( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + use_dreambooth_method = args.in_json is None + if use_dreambooth_method: + print("Use DreamBooth method.") + user_config = { + "datasets": [ + {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)} + ] + } + else: + print("Train with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group.enable_XTI(XTI_layers, token_strings=token_strings) + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collater = train_util.collater_class(current_epoch, current_step, ds_for_collater) + + # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装 + if use_template: + print("use template for training captions. is object: {args.use_object_template}") + templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small + replace_to = " ".join(token_strings) + captions = [] + for tmpl in templates: + captions.append(tmpl.format(replace_to)) + train_dataset_group.add_replacement("", captions) + + if args.num_vectors_per_token > 1: + prompt_replacement = (args.token_string, replace_to) + else: + prompt_replacement = None + else: + if args.num_vectors_per_token > 1: + replace_to = " ".join(token_strings) + train_dataset_group.add_replacement(args.token_string, replace_to) + prompt_replacement = (args.token_string, replace_to) + else: + prompt_replacement = None + + if args.debug_dataset: + train_util.debug_dataset(train_dataset_group, show_input_ids=True) + return + if len(train_dataset_group) == 0: + print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください") + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + + # モデルに xformers とか memory efficient attention を組み込む + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) + diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI + diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D.forward = downblock_forward_XTI + diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D.forward = upblock_forward_XTI + + # 学習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=weight_dtype) + vae.requires_grad_(False) + vae.eval() + with torch.no_grad(): + train_dataset_group.cache_latents(vae, args.vae_batch_size) + vae.to("cpu") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + text_encoder.gradient_checkpointing_enable() + + # 学習に必要なクラスを準備する + print("prepare optimizer, data loader etc.") + trainable_params = text_encoder.get_input_embeddings().parameters() + _, _, optimizer = train_util.get_optimizer(args, trainable_params) + + # dataloaderを準備する + # DataLoaderのプロセス数:0はメインプロセスになる + n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collater, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを用意する + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # acceleratorがなんかよろしくやってくれるらしい + text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + text_encoder, optimizer, train_dataloader, lr_scheduler + ) + + index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0] + # print(len(index_no_updates), torch.sum(index_no_updates)) + orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone() + + # Freeze all parameters except for the token embeddings in text encoder + text_encoder.requires_grad_(True) + text_encoder.text_model.encoder.requires_grad_(False) + text_encoder.text_model.final_layer_norm.requires_grad_(False) + text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) + # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True) + + unet.requires_grad_(False) + unet.to(accelerator.device, dtype=weight_dtype) + if args.gradient_checkpointing: # according to TI example in Diffusers, train is required + unet.train() + else: + unet.eval() + + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=weight_dtype) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) + text_encoder.to(weight_dtype) + + # resumeする + if args.resume is not None: + print(f"resume training from state: {args.resume}") + accelerator.load_state(args.resume) + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + print("running training / 学習開始") + print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") + print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + print(f" num epochs / epoch数: {num_train_epochs}") + print(f" batch size per device / バッチサイズ: {args.train_batch_size}") + print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + noise_scheduler = DDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False + ) + + if accelerator.is_main_process: + accelerator.init_trackers("textual_inversion") + + for epoch in range(num_train_epochs): + print(f"epoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + text_encoder.train() + + loss_total = 0 + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + with accelerator.accumulate(text_encoder): + with torch.no_grad(): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device) + else: + # latentに変換 + latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * 0.18215 + b_size = latents.shape[0] + + # Get the text embedding for conditioning + input_ids = batch["input_ids"].to(accelerator.device) + # weight_dtype) use float instead of fp16/bf16 because text encoder is float + encoder_hidden_states = torch.stack( + [ + train_util.get_hidden_states(args, s, tokenizer, text_encoder, weight_dtype) + for s in torch.split(input_ids, 1, dim=1) + ] + ) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents, device=latents.device) + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) + + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Predict the noise residual + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + + accelerator.backward(loss) + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = text_encoder.get_input_embeddings().parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Let's make sure we don't update any embedding weights besides the newly added token + with torch.no_grad(): + unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[ + index_no_updates + ] + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + # TODO: fix sample_images + # train_util.sample_images( + # accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement + # ) + + current_loss = loss.detach().item() + if args.logging_dir is not None: + logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} + if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value + logs["lr/d*lr"] = ( + lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] + ) + accelerator.log(logs, step=global_step) + + loss_total += current_loss + avr_loss = loss_total / (step + 1) + logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if args.logging_dir is not None: + logs = {"loss/epoch": loss_total / len(train_dataloader)} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids_XTI].data.detach().clone() + + if args.save_every_n_epochs is not None: + model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name + + def save_func(): + ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + "." + args.save_model_as + ckpt_file = os.path.join(args.output_dir, ckpt_name) + print(f"saving checkpoint: {ckpt_file}") + save_weights(ckpt_file, updated_embs, save_dtype) + + def remove_old_func(old_epoch_no): + old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as + old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) + if os.path.exists(old_ckpt_file): + print(f"removing old checkpoint: {old_ckpt_file}") + os.remove(old_ckpt_file) + + saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs) + if saving and args.save_state: + train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1) + + # TODO: fix sample_images + # train_util.sample_images( + # accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement + # ) + + # end of epoch + + is_main_process = accelerator.is_main_process + if is_main_process: + text_encoder = unwrap_model(text_encoder) + + accelerator.end_training() + + if args.save_state: + train_util.save_state_on_train_end(args, accelerator) + + updated_embs = text_encoder.get_input_embeddings().weight[token_ids_XTI].data.detach().clone() + + del accelerator # この後メモリを使うのでこれは消す + + if is_main_process: + os.makedirs(args.output_dir, exist_ok=True) + + model_name = train_util.DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name + ckpt_name = model_name + "." + args.save_model_as + ckpt_file = os.path.join(args.output_dir, ckpt_name) + + print(f"save trained model to {ckpt_file}") + save_weights(ckpt_file, updated_embs, save_dtype) + print("model saved.") + + +def save_weights(file, updated_embs, save_dtype): + updated_embs = updated_embs.reshape(16, -1, updated_embs.shape[-1]) + updated_embs = updated_embs.chunk(16) + XTI_layers = [ + "IN01", + "IN02", + "IN04", + "IN05", + "IN07", + "IN08", + "MID", + "OUT03", + "OUT04", + "OUT05", + "OUT06", + "OUT07", + "OUT08", + "OUT09", + "OUT10", + "OUT11", + ] + state_dict = {} + for i, layer_name in enumerate(XTI_layers): + state_dict[layer_name] = updated_embs[i].squeeze(0).detach().clone().to("cpu").to(save_dtype) + + # if save_dtype is not None: + # for key in list(state_dict.keys()): + # v = state_dict[key] + # v = v.detach().clone().to("cpu").to(save_dtype) + # state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + + save_file(state_dict, file) + else: + torch.save(state_dict, file) # can be loaded in Web UI + + +def load_weights(file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + data = load_file(file) + else: + raise ValueError(f"NOT XTI: {file}") + + if len(data.values()) != 16: + raise ValueError(f"NOT XTI: {file}") + + emb = torch.concat([x for x in data.values()]) + + return emb + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, True, True, False) + train_util.add_training_arguments(parser, True) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) + + parser.add_argument( + "--save_model_as", + type=str, + default="pt", + choices=[None, "ckpt", "pt", "safetensors"], + help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)", + ) + + parser.add_argument("--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み") + parser.add_argument( + "--num_vectors_per_token", type=int, default=1, help="number of vectors per token / トークンに割り当てるembeddingsの要素数" + ) + parser.add_argument( + "--token_string", + type=str, + default=None, + help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること", + ) + parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可") + parser.add_argument( + "--use_object_template", + action="store_true", + help="ignore caption and use default templates for object / キャプションは使わずデフォルトの物体用テンプレートで学習する", + ) + parser.add_argument( + "--use_style_template", + action="store_true", + help="ignore caption and use default templates for stype / キャプションは使わずデフォルトのスタイル用テンプレートで学習する", + ) + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + args = train_util.read_config_from_file(args, parser) + + train(args) From aed9f937da05eb8db670cee48a3b563f37359e44 Mon Sep 17 00:00:00 2001 From: JSTayco Date: Sat, 1 Apr 2023 08:52:45 -0700 Subject: [PATCH 42/44] Removed -v test as older versions of bash don't support that option Made script compatible with very old version of bash, added some guards for library linking, and removed redundant library linking. --- setup.sh | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/setup.sh b/setup.sh index bc98d6e..7155f68 100755 --- a/setup.sh +++ b/setup.sh @@ -33,10 +33,10 @@ EOF # Checks to see if variable is set and non-empty. # This is defined first, so we can use the function for some default variable values env_var_exists() { - if [[ ! -v "$1" ]] || [[ -z "$1" ]]; then - return 1 - else + if [[ -n "${!1}" ]]; then return 0 + else + return 1 fi } @@ -416,17 +416,6 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then } check_storage_space - - # This is the pre-install work for a kohya installation on a runpod - if [ "$RUNPOD" = true ]; then - if [ -d "$VENV_DIR" ]; then - echo "Pre-existing installation on a runpod detected." - export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:"$VENV_DIR"/lib/python3.10/site-packages/tensorrt/ - export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:"$VENV_DIR"/lib/python3.10/site-packages/nvidia/cuda_runtime/lib/ - cd "$DIR" || exit 1 - fi - fi - update_kohya_ss distro=get_distro_name @@ -506,8 +495,17 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then create_symlinks "$libnvinfer_symlink" "$libnvinfer_target" create_symlinks "$libcudart_symlink" "$libcudart_target" - export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$VENV_DIR/lib/python3.10/site-packages/tensorrt/" - export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$VENV_DIR/lib/python3.10/site-packages/nvidia/cuda_runtime/lib/" + if [ -d "${VENV_DIR}/lib/python3.10/site-packages/tensorrt/" ]; then + export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:${VENV_DIR}/lib/python3.10/site-packages/tensorrt/" + else + echo "${VENV_DIR}/lib/python3.10/site-packages/tensorrt/ not found; not linking library." + fi + + if [ -d "${VENV_DIR}/lib/python3.10/site-packages/tensorrt/" ]; then + export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:${VENV_DIR}/lib/python3.10/site-packages/nvidia/cuda_runtime/lib/" + else + echo "${VENV_DIR}/lib/python3.10/site-packages/nvidia/cuda_runtime/lib/ not found; not linking library." + fi configure_accelerate From 882e4837b36d78017e137c85d099def43e5880e3 Mon Sep 17 00:00:00 2001 From: JSTayco Date: Sat, 1 Apr 2023 09:15:07 -0700 Subject: [PATCH 43/44] Check for valid install directory and upgrade pip in venv --- setup.sh | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/setup.sh b/setup.sh index 7155f68..b516e78 100755 --- a/setup.sh +++ b/setup.sh @@ -149,6 +149,12 @@ Script directory is ${SCRIPT_DIR}." >&5 PARENT_DIR="$(dirname "${DIR}")" VENV_DIR="$DIR/venv" +if [ ! -w "$DIR" ]; then + echo "We cannot write to ${DIR}." + echo "Please ensure the install directory is accurate and you have the correct permissions." + exit 1 +fi + # Shared functions # This checks for free space on the installation drive and returns that in Gb. size_available() { @@ -213,7 +219,7 @@ install_python_dependencies() { # Updating pip if there is one echo "Checking for pip updates before Python operations." - python3 -m pip install --upgrade pip >&3 + pip install --upgrade pip >&3 echo "Installing python dependencies. This could take a few minutes as it downloads files." echo "If this operation ever runs too long, you can rerun this script in verbose mode to check." From ab6c7e20826498523b460a165bd631747a6eefba Mon Sep 17 00:00:00 2001 From: JSTayco Date: Sat, 1 Apr 2023 09:59:35 -0700 Subject: [PATCH 44/44] Fixed git cloning for new directories. --- setup.sh | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/setup.sh b/setup.sh index b516e78..b07976e 100755 --- a/setup.sh +++ b/setup.sh @@ -149,6 +149,11 @@ Script directory is ${SCRIPT_DIR}." >&5 PARENT_DIR="$(dirname "${DIR}")" VENV_DIR="$DIR/venv" +if [ -w "$PARENT_DIR" ] && [ ! -d "$DIR" ]; then + echo "Creating install folder ${DIR}." + mkdir "$DIR" +fi + if [ ! -w "$DIR" ]; then echo "We cannot write to ${DIR}." echo "Please ensure the install directory is accurate and you have the correct permissions." @@ -339,8 +344,8 @@ update_kohya_ss() { echo "Attempting to clone $GIT_REPO." if [ ! -d "$DIR/.git" ]; then - echo "Cloning and switching to $GIT_REPO:$BRANCH" >*4 - git -C "$DIR" clone -b "$BRANCH" "$GIT_REPO" "$(basename "$DIR")" >&3 + echo "Cloning and switching to $GIT_REPO:$BRANCH" >&4 + git -C "$PARENT_DIR" clone -b "$BRANCH" "$GIT_REPO" "$(basename "$DIR")" >&3 git -C "$DIR" switch "$BRANCH" >&4 else echo "git repo detected. Attempting to update repository instead."