Merge branch 'dev' into find_vae
This commit is contained in:
commit
80adb6979d
43
.github/workflows/on_pull_request.yaml
vendored
43
.github/workflows/on_pull_request.yaml
vendored
@ -18,22 +18,29 @@ jobs:
|
|||||||
steps:
|
steps:
|
||||||
- name: Checkout Code
|
- name: Checkout Code
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v3
|
||||||
- name: Set up Python 3.10
|
- uses: actions/setup-python@v4
|
||||||
uses: actions/setup-python@v4
|
|
||||||
with:
|
with:
|
||||||
python-version: 3.10.6
|
python-version: 3.11
|
||||||
cache: pip
|
# NB: there's no cache: pip here since we're not installing anything
|
||||||
cache-dependency-path: |
|
# from the requirements.txt file(s) in the repository; it's faster
|
||||||
**/requirements*txt
|
# not to have GHA download an (at the time of writing) 4 GB cache
|
||||||
- name: Install PyLint
|
# of PyTorch and other dependencies.
|
||||||
run: |
|
- name: Install Ruff
|
||||||
python -m pip install --upgrade pip
|
run: pip install ruff==0.0.265
|
||||||
pip install pylint
|
- name: Run Ruff
|
||||||
# This lets PyLint check to see if it can resolve imports
|
run: ruff .
|
||||||
- name: Install dependencies
|
|
||||||
run: |
|
# The rest are currently disabled pending fixing of e.g. installing the torch dependency.
|
||||||
export COMMANDLINE_ARGS="--skip-torch-cuda-test --exit"
|
|
||||||
python launch.py
|
# - name: Install PyLint
|
||||||
- name: Analysing the code with pylint
|
# run: |
|
||||||
run: |
|
# python -m pip install --upgrade pip
|
||||||
pylint $(git ls-files '*.py')
|
# pip install pylint
|
||||||
|
# # This lets PyLint check to see if it can resolve imports
|
||||||
|
# - name: Install dependencies
|
||||||
|
# run: |
|
||||||
|
# export COMMANDLINE_ARGS="--skip-torch-cuda-test --exit"
|
||||||
|
# python launch.py
|
||||||
|
# - name: Analysing the code with pylint
|
||||||
|
# run: |
|
||||||
|
# pylint $(git ls-files '*.py')
|
||||||
|
6
.github/workflows/run_tests.yaml
vendored
6
.github/workflows/run_tests.yaml
vendored
@ -17,8 +17,14 @@ jobs:
|
|||||||
cache: pip
|
cache: pip
|
||||||
cache-dependency-path: |
|
cache-dependency-path: |
|
||||||
**/requirements*txt
|
**/requirements*txt
|
||||||
|
launch.py
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: python launch.py --tests test --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test
|
run: python launch.py --tests test --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test
|
||||||
|
env:
|
||||||
|
PIP_DISABLE_PIP_VERSION_CHECK: "1"
|
||||||
|
PIP_PROGRESS_BAR: "off"
|
||||||
|
TORCH_INDEX_URL: https://download.pytorch.org/whl/cpu
|
||||||
|
WEBUI_LAUNCH_LIVE_OUTPUT: "1"
|
||||||
- name: Upload main app stdout-stderr
|
- name: Upload main app stdout-stderr
|
||||||
uses: actions/upload-artifact@v3
|
uses: actions/upload-artifact@v3
|
||||||
if: always()
|
if: always()
|
||||||
|
56
CHANGELOG.md
56
CHANGELOG.md
@ -1,3 +1,59 @@
|
|||||||
|
## Upcoming 1.2.1
|
||||||
|
|
||||||
|
### Features:
|
||||||
|
* add an option to always refer to lora by filenames
|
||||||
|
|
||||||
|
### Bug Fixes:
|
||||||
|
* never refer to lora by an alias if multiple loras have same alias or the alias is called none
|
||||||
|
* fix upscalers disappearing after the user reloads UI
|
||||||
|
* allow bf16 in safe unpickler (resolves problems with loading some loras)
|
||||||
|
* allow web UI to be ran fully offline
|
||||||
|
|
||||||
|
## 1.2.0
|
||||||
|
|
||||||
|
### Features:
|
||||||
|
* do not wait for stable diffusion model to load at startup
|
||||||
|
* add filename patterns: [denoising]
|
||||||
|
* directory hiding for extra networks: dirs starting with . will hide their cards on extra network tabs unless specifically searched for
|
||||||
|
* Lora: for the `<...>` text in prompt, use name of Lora that is in the metdata of the file, if present, instead of filename (both can be used to activate lora)
|
||||||
|
* Lora: read infotext params from kohya-ss's extension parameters if they are present and if his extension is not active
|
||||||
|
* Lora: Fix some Loras not working (ones that have 3x3 convolution layer)
|
||||||
|
* Lora: add an option to use old method of applying loras (producing same results as with kohya-ss)
|
||||||
|
* add version to infotext, footer and console output when starting
|
||||||
|
* add links to wiki for filename pattern settings
|
||||||
|
* add extended info for quicksettings setting and use multiselect input instead of a text field
|
||||||
|
|
||||||
|
### Minor:
|
||||||
|
* gradio bumped to 3.29.0
|
||||||
|
* torch bumped to 2.0.1
|
||||||
|
* --subpath option for gradio for use with reverse proxy
|
||||||
|
* linux/OSX: use existing virtualenv if already active (the VIRTUAL_ENV environment variable)
|
||||||
|
* possible frontend optimization: do not apply localizations if there are none
|
||||||
|
* Add extra `None` option for VAE in XYZ plot
|
||||||
|
* print error to console when batch processing in img2img fails
|
||||||
|
* create HTML for extra network pages only on demand
|
||||||
|
* allow directories starting with . to still list their models for lora, checkpoints, etc
|
||||||
|
* put infotext options into their own category in settings tab
|
||||||
|
* do not show licenses page when user selects Show all pages in settings
|
||||||
|
|
||||||
|
### Extensions:
|
||||||
|
* Tooltip localization support
|
||||||
|
* Add api method to get LoRA models with prompt
|
||||||
|
|
||||||
|
### Bug Fixes:
|
||||||
|
* re-add /docs endpoint
|
||||||
|
* fix gamepad navigation
|
||||||
|
* make the lightbox fullscreen image function properly
|
||||||
|
* fix squished thumbnails in extras tab
|
||||||
|
* keep "search" filter for extra networks when user refreshes the tab (previously it showed everthing after you refreshed)
|
||||||
|
* fix webui showing the same image if you configure the generation to always save results into same file
|
||||||
|
* fix bug with upscalers not working properly
|
||||||
|
* Fix MPS on PyTorch 2.0.1, Intel Macs
|
||||||
|
* make it so that custom context menu from contextMenu.js only disappears after user's click, ignoring non-user click events
|
||||||
|
* prevent Reload UI button/link from reloading the page when it's not yet ready
|
||||||
|
* fix prompts from file script failing to read contents from a drag/drop file
|
||||||
|
|
||||||
|
|
||||||
## 1.1.1
|
## 1.1.1
|
||||||
### Bug Fixes:
|
### Bug Fixes:
|
||||||
* fix an error that prevents running webui on torch<2.0 without --disable-safe-unpickle
|
* fix an error that prevents running webui on torch<2.0 without --disable-safe-unpickle
|
||||||
|
@ -88,7 +88,7 @@ class LDSR:
|
|||||||
|
|
||||||
x_t = None
|
x_t = None
|
||||||
logs = None
|
logs = None
|
||||||
for n in range(n_runs):
|
for _ in range(n_runs):
|
||||||
if custom_shape is not None:
|
if custom_shape is not None:
|
||||||
x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
|
x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
|
||||||
x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0])
|
x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0])
|
||||||
@ -110,7 +110,6 @@ class LDSR:
|
|||||||
diffusion_steps = int(steps)
|
diffusion_steps = int(steps)
|
||||||
eta = 1.0
|
eta = 1.0
|
||||||
|
|
||||||
down_sample_method = 'Lanczos'
|
|
||||||
|
|
||||||
gc.collect()
|
gc.collect()
|
||||||
if torch.cuda.is_available:
|
if torch.cuda.is_available:
|
||||||
@ -131,11 +130,11 @@ class LDSR:
|
|||||||
im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
|
im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
|
||||||
else:
|
else:
|
||||||
print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)")
|
print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)")
|
||||||
|
|
||||||
# pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts
|
# pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts
|
||||||
pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size
|
pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size
|
||||||
im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
|
im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
|
||||||
|
|
||||||
logs = self.run(model["model"], im_padded, diffusion_steps, eta)
|
logs = self.run(model["model"], im_padded, diffusion_steps, eta)
|
||||||
|
|
||||||
sample = logs["sample"]
|
sample = logs["sample"]
|
||||||
@ -158,7 +157,7 @@ class LDSR:
|
|||||||
|
|
||||||
|
|
||||||
def get_cond(selected_path):
|
def get_cond(selected_path):
|
||||||
example = dict()
|
example = {}
|
||||||
up_f = 4
|
up_f = 4
|
||||||
c = selected_path.convert('RGB')
|
c = selected_path.convert('RGB')
|
||||||
c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
|
c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
|
||||||
@ -196,7 +195,7 @@ def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_s
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,
|
def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,
|
||||||
corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False):
|
corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False):
|
||||||
log = dict()
|
log = {}
|
||||||
|
|
||||||
z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
|
z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
|
||||||
return_first_stage_outputs=True,
|
return_first_stage_outputs=True,
|
||||||
@ -244,7 +243,7 @@ def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize
|
|||||||
x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
|
x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
|
||||||
log["sample_noquant"] = x_sample_noquant
|
log["sample_noquant"] = x_sample_noquant
|
||||||
log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
|
log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
log["sample"] = x_sample
|
log["sample"] = x_sample
|
||||||
|
@ -7,7 +7,8 @@ from basicsr.utils.download_util import load_file_from_url
|
|||||||
from modules.upscaler import Upscaler, UpscalerData
|
from modules.upscaler import Upscaler, UpscalerData
|
||||||
from ldsr_model_arch import LDSR
|
from ldsr_model_arch import LDSR
|
||||||
from modules import shared, script_callbacks
|
from modules import shared, script_callbacks
|
||||||
import sd_hijack_autoencoder, sd_hijack_ddpm_v1
|
import sd_hijack_autoencoder # noqa: F401
|
||||||
|
import sd_hijack_ddpm_v1 # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
class UpscalerLDSR(Upscaler):
|
class UpscalerLDSR(Upscaler):
|
||||||
|
@ -1,16 +1,21 @@
|
|||||||
# The content of this file comes from the ldm/models/autoencoder.py file of the compvis/stable-diffusion repo
|
# The content of this file comes from the ldm/models/autoencoder.py file of the compvis/stable-diffusion repo
|
||||||
# The VQModel & VQModelInterface were subsequently removed from ldm/models/autoencoder.py when we moved to the stability-ai/stablediffusion repo
|
# The VQModel & VQModelInterface were subsequently removed from ldm/models/autoencoder.py when we moved to the stability-ai/stablediffusion repo
|
||||||
# As the LDSR upscaler relies on VQModel & VQModelInterface, the hijack aims to put them back into the ldm.models.autoencoder
|
# As the LDSR upscaler relies on VQModel & VQModelInterface, the hijack aims to put them back into the ldm.models.autoencoder
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
from torch.optim.lr_scheduler import LambdaLR
|
||||||
|
|
||||||
|
from ldm.modules.ema import LitEma
|
||||||
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
|
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
|
||||||
from ldm.modules.diffusionmodules.model import Encoder, Decoder
|
from ldm.modules.diffusionmodules.model import Encoder, Decoder
|
||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
|
|
||||||
import ldm.models.autoencoder
|
import ldm.models.autoencoder
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
class VQModel(pl.LightningModule):
|
class VQModel(pl.LightningModule):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@ -19,7 +24,7 @@ class VQModel(pl.LightningModule):
|
|||||||
n_embed,
|
n_embed,
|
||||||
embed_dim,
|
embed_dim,
|
||||||
ckpt_path=None,
|
ckpt_path=None,
|
||||||
ignore_keys=[],
|
ignore_keys=None,
|
||||||
image_key="image",
|
image_key="image",
|
||||||
colorize_nlabels=None,
|
colorize_nlabels=None,
|
||||||
monitor=None,
|
monitor=None,
|
||||||
@ -57,7 +62,7 @@ class VQModel(pl.LightningModule):
|
|||||||
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
||||||
|
|
||||||
if ckpt_path is not None:
|
if ckpt_path is not None:
|
||||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [])
|
||||||
self.scheduler_config = scheduler_config
|
self.scheduler_config = scheduler_config
|
||||||
self.lr_g_factor = lr_g_factor
|
self.lr_g_factor = lr_g_factor
|
||||||
|
|
||||||
@ -76,11 +81,11 @@ class VQModel(pl.LightningModule):
|
|||||||
if context is not None:
|
if context is not None:
|
||||||
print(f"{context}: Restored training weights")
|
print(f"{context}: Restored training weights")
|
||||||
|
|
||||||
def init_from_ckpt(self, path, ignore_keys=list()):
|
def init_from_ckpt(self, path, ignore_keys=None):
|
||||||
sd = torch.load(path, map_location="cpu")["state_dict"]
|
sd = torch.load(path, map_location="cpu")["state_dict"]
|
||||||
keys = list(sd.keys())
|
keys = list(sd.keys())
|
||||||
for k in keys:
|
for k in keys:
|
||||||
for ik in ignore_keys:
|
for ik in ignore_keys or []:
|
||||||
if k.startswith(ik):
|
if k.startswith(ik):
|
||||||
print("Deleting key {} from state_dict.".format(k))
|
print("Deleting key {} from state_dict.".format(k))
|
||||||
del sd[k]
|
del sd[k]
|
||||||
@ -165,7 +170,7 @@ class VQModel(pl.LightningModule):
|
|||||||
def validation_step(self, batch, batch_idx):
|
def validation_step(self, batch, batch_idx):
|
||||||
log_dict = self._validation_step(batch, batch_idx)
|
log_dict = self._validation_step(batch, batch_idx)
|
||||||
with self.ema_scope():
|
with self.ema_scope():
|
||||||
log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
|
self._validation_step(batch, batch_idx, suffix="_ema")
|
||||||
return log_dict
|
return log_dict
|
||||||
|
|
||||||
def _validation_step(self, batch, batch_idx, suffix=""):
|
def _validation_step(self, batch, batch_idx, suffix=""):
|
||||||
@ -232,7 +237,7 @@ class VQModel(pl.LightningModule):
|
|||||||
return self.decoder.conv_out.weight
|
return self.decoder.conv_out.weight
|
||||||
|
|
||||||
def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
|
def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
|
||||||
log = dict()
|
log = {}
|
||||||
x = self.get_input(batch, self.image_key)
|
x = self.get_input(batch, self.image_key)
|
||||||
x = x.to(self.device)
|
x = x.to(self.device)
|
||||||
if only_inputs:
|
if only_inputs:
|
||||||
@ -249,7 +254,8 @@ class VQModel(pl.LightningModule):
|
|||||||
if plot_ema:
|
if plot_ema:
|
||||||
with self.ema_scope():
|
with self.ema_scope():
|
||||||
xrec_ema, _ = self(x)
|
xrec_ema, _ = self(x)
|
||||||
if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
|
if x.shape[1] > 3:
|
||||||
|
xrec_ema = self.to_rgb(xrec_ema)
|
||||||
log["reconstructions_ema"] = xrec_ema
|
log["reconstructions_ema"] = xrec_ema
|
||||||
return log
|
return log
|
||||||
|
|
||||||
@ -264,7 +270,7 @@ class VQModel(pl.LightningModule):
|
|||||||
|
|
||||||
class VQModelInterface(VQModel):
|
class VQModelInterface(VQModel):
|
||||||
def __init__(self, embed_dim, *args, **kwargs):
|
def __init__(self, embed_dim, *args, **kwargs):
|
||||||
super().__init__(embed_dim=embed_dim, *args, **kwargs)
|
super().__init__(*args, embed_dim=embed_dim, **kwargs)
|
||||||
self.embed_dim = embed_dim
|
self.embed_dim = embed_dim
|
||||||
|
|
||||||
def encode(self, x):
|
def encode(self, x):
|
||||||
@ -282,5 +288,5 @@ class VQModelInterface(VQModel):
|
|||||||
dec = self.decoder(quant)
|
dec = self.decoder(quant)
|
||||||
return dec
|
return dec
|
||||||
|
|
||||||
setattr(ldm.models.autoencoder, "VQModel", VQModel)
|
ldm.models.autoencoder.VQModel = VQModel
|
||||||
setattr(ldm.models.autoencoder, "VQModelInterface", VQModelInterface)
|
ldm.models.autoencoder.VQModelInterface = VQModelInterface
|
||||||
|
@ -48,7 +48,7 @@ class DDPMV1(pl.LightningModule):
|
|||||||
beta_schedule="linear",
|
beta_schedule="linear",
|
||||||
loss_type="l2",
|
loss_type="l2",
|
||||||
ckpt_path=None,
|
ckpt_path=None,
|
||||||
ignore_keys=[],
|
ignore_keys=None,
|
||||||
load_only_unet=False,
|
load_only_unet=False,
|
||||||
monitor="val/loss",
|
monitor="val/loss",
|
||||||
use_ema=True,
|
use_ema=True,
|
||||||
@ -100,7 +100,7 @@ class DDPMV1(pl.LightningModule):
|
|||||||
if monitor is not None:
|
if monitor is not None:
|
||||||
self.monitor = monitor
|
self.monitor = monitor
|
||||||
if ckpt_path is not None:
|
if ckpt_path is not None:
|
||||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
|
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet)
|
||||||
|
|
||||||
self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
|
self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
|
||||||
linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
|
linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
|
||||||
@ -182,13 +182,13 @@ class DDPMV1(pl.LightningModule):
|
|||||||
if context is not None:
|
if context is not None:
|
||||||
print(f"{context}: Restored training weights")
|
print(f"{context}: Restored training weights")
|
||||||
|
|
||||||
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
|
def init_from_ckpt(self, path, ignore_keys=None, only_model=False):
|
||||||
sd = torch.load(path, map_location="cpu")
|
sd = torch.load(path, map_location="cpu")
|
||||||
if "state_dict" in list(sd.keys()):
|
if "state_dict" in list(sd.keys()):
|
||||||
sd = sd["state_dict"]
|
sd = sd["state_dict"]
|
||||||
keys = list(sd.keys())
|
keys = list(sd.keys())
|
||||||
for k in keys:
|
for k in keys:
|
||||||
for ik in ignore_keys:
|
for ik in ignore_keys or []:
|
||||||
if k.startswith(ik):
|
if k.startswith(ik):
|
||||||
print("Deleting key {} from state_dict.".format(k))
|
print("Deleting key {} from state_dict.".format(k))
|
||||||
del sd[k]
|
del sd[k]
|
||||||
@ -375,7 +375,7 @@ class DDPMV1(pl.LightningModule):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
|
def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
|
||||||
log = dict()
|
log = {}
|
||||||
x = self.get_input(batch, self.first_stage_key)
|
x = self.get_input(batch, self.first_stage_key)
|
||||||
N = min(x.shape[0], N)
|
N = min(x.shape[0], N)
|
||||||
n_row = min(x.shape[0], n_row)
|
n_row = min(x.shape[0], n_row)
|
||||||
@ -383,7 +383,7 @@ class DDPMV1(pl.LightningModule):
|
|||||||
log["inputs"] = x
|
log["inputs"] = x
|
||||||
|
|
||||||
# get diffusion row
|
# get diffusion row
|
||||||
diffusion_row = list()
|
diffusion_row = []
|
||||||
x_start = x[:n_row]
|
x_start = x[:n_row]
|
||||||
|
|
||||||
for t in range(self.num_timesteps):
|
for t in range(self.num_timesteps):
|
||||||
@ -444,13 +444,13 @@ class LatentDiffusionV1(DDPMV1):
|
|||||||
conditioning_key = None
|
conditioning_key = None
|
||||||
ckpt_path = kwargs.pop("ckpt_path", None)
|
ckpt_path = kwargs.pop("ckpt_path", None)
|
||||||
ignore_keys = kwargs.pop("ignore_keys", [])
|
ignore_keys = kwargs.pop("ignore_keys", [])
|
||||||
super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
|
super().__init__(*args, conditioning_key=conditioning_key, **kwargs)
|
||||||
self.concat_mode = concat_mode
|
self.concat_mode = concat_mode
|
||||||
self.cond_stage_trainable = cond_stage_trainable
|
self.cond_stage_trainable = cond_stage_trainable
|
||||||
self.cond_stage_key = cond_stage_key
|
self.cond_stage_key = cond_stage_key
|
||||||
try:
|
try:
|
||||||
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
|
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
|
||||||
except:
|
except Exception:
|
||||||
self.num_downs = 0
|
self.num_downs = 0
|
||||||
if not scale_by_std:
|
if not scale_by_std:
|
||||||
self.scale_factor = scale_factor
|
self.scale_factor = scale_factor
|
||||||
@ -460,7 +460,7 @@ class LatentDiffusionV1(DDPMV1):
|
|||||||
self.instantiate_cond_stage(cond_stage_config)
|
self.instantiate_cond_stage(cond_stage_config)
|
||||||
self.cond_stage_forward = cond_stage_forward
|
self.cond_stage_forward = cond_stage_forward
|
||||||
self.clip_denoised = False
|
self.clip_denoised = False
|
||||||
self.bbox_tokenizer = None
|
self.bbox_tokenizer = None
|
||||||
|
|
||||||
self.restarted_from_ckpt = False
|
self.restarted_from_ckpt = False
|
||||||
if ckpt_path is not None:
|
if ckpt_path is not None:
|
||||||
@ -792,7 +792,7 @@ class LatentDiffusionV1(DDPMV1):
|
|||||||
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
|
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
|
||||||
|
|
||||||
# 2. apply model loop over last dim
|
# 2. apply model loop over last dim
|
||||||
if isinstance(self.first_stage_model, VQModelInterface):
|
if isinstance(self.first_stage_model, VQModelInterface):
|
||||||
output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
|
output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
|
||||||
force_not_quantize=predict_cids or force_not_quantize)
|
force_not_quantize=predict_cids or force_not_quantize)
|
||||||
for i in range(z.shape[-1])]
|
for i in range(z.shape[-1])]
|
||||||
@ -877,16 +877,6 @@ class LatentDiffusionV1(DDPMV1):
|
|||||||
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
|
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
|
||||||
return self.p_losses(x, c, t, *args, **kwargs)
|
return self.p_losses(x, c, t, *args, **kwargs)
|
||||||
|
|
||||||
def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
|
|
||||||
def rescale_bbox(bbox):
|
|
||||||
x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
|
|
||||||
y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
|
|
||||||
w = min(bbox[2] / crop_coordinates[2], 1 - x0)
|
|
||||||
h = min(bbox[3] / crop_coordinates[3], 1 - y0)
|
|
||||||
return x0, y0, w, h
|
|
||||||
|
|
||||||
return [rescale_bbox(b) for b in bboxes]
|
|
||||||
|
|
||||||
def apply_model(self, x_noisy, t, cond, return_ids=False):
|
def apply_model(self, x_noisy, t, cond, return_ids=False):
|
||||||
|
|
||||||
if isinstance(cond, dict):
|
if isinstance(cond, dict):
|
||||||
@ -900,7 +890,7 @@ class LatentDiffusionV1(DDPMV1):
|
|||||||
|
|
||||||
if hasattr(self, "split_input_params"):
|
if hasattr(self, "split_input_params"):
|
||||||
assert len(cond) == 1 # todo can only deal with one conditioning atm
|
assert len(cond) == 1 # todo can only deal with one conditioning atm
|
||||||
assert not return_ids
|
assert not return_ids
|
||||||
ks = self.split_input_params["ks"] # eg. (128, 128)
|
ks = self.split_input_params["ks"] # eg. (128, 128)
|
||||||
stride = self.split_input_params["stride"] # eg. (64, 64)
|
stride = self.split_input_params["stride"] # eg. (64, 64)
|
||||||
|
|
||||||
@ -1126,7 +1116,7 @@ class LatentDiffusionV1(DDPMV1):
|
|||||||
if cond is not None:
|
if cond is not None:
|
||||||
if isinstance(cond, dict):
|
if isinstance(cond, dict):
|
||||||
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
|
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
|
||||||
list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
|
[x[:batch_size] for x in cond[key]] for key in cond}
|
||||||
else:
|
else:
|
||||||
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
|
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
|
||||||
|
|
||||||
@ -1157,8 +1147,10 @@ class LatentDiffusionV1(DDPMV1):
|
|||||||
|
|
||||||
if i % log_every_t == 0 or i == timesteps - 1:
|
if i % log_every_t == 0 or i == timesteps - 1:
|
||||||
intermediates.append(x0_partial)
|
intermediates.append(x0_partial)
|
||||||
if callback: callback(i)
|
if callback:
|
||||||
if img_callback: img_callback(img, i)
|
callback(i)
|
||||||
|
if img_callback:
|
||||||
|
img_callback(img, i)
|
||||||
return img, intermediates
|
return img, intermediates
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@ -1205,8 +1197,10 @@ class LatentDiffusionV1(DDPMV1):
|
|||||||
|
|
||||||
if i % log_every_t == 0 or i == timesteps - 1:
|
if i % log_every_t == 0 or i == timesteps - 1:
|
||||||
intermediates.append(img)
|
intermediates.append(img)
|
||||||
if callback: callback(i)
|
if callback:
|
||||||
if img_callback: img_callback(img, i)
|
callback(i)
|
||||||
|
if img_callback:
|
||||||
|
img_callback(img, i)
|
||||||
|
|
||||||
if return_intermediates:
|
if return_intermediates:
|
||||||
return img, intermediates
|
return img, intermediates
|
||||||
@ -1221,7 +1215,7 @@ class LatentDiffusionV1(DDPMV1):
|
|||||||
if cond is not None:
|
if cond is not None:
|
||||||
if isinstance(cond, dict):
|
if isinstance(cond, dict):
|
||||||
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
|
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
|
||||||
list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
|
[x[:batch_size] for x in cond[key]] for key in cond}
|
||||||
else:
|
else:
|
||||||
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
|
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
|
||||||
return self.p_sample_loop(cond,
|
return self.p_sample_loop(cond,
|
||||||
@ -1253,7 +1247,7 @@ class LatentDiffusionV1(DDPMV1):
|
|||||||
|
|
||||||
use_ddim = ddim_steps is not None
|
use_ddim = ddim_steps is not None
|
||||||
|
|
||||||
log = dict()
|
log = {}
|
||||||
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
|
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
|
||||||
return_first_stage_outputs=True,
|
return_first_stage_outputs=True,
|
||||||
force_c_encode=True,
|
force_c_encode=True,
|
||||||
@ -1280,7 +1274,7 @@ class LatentDiffusionV1(DDPMV1):
|
|||||||
|
|
||||||
if plot_diffusion_rows:
|
if plot_diffusion_rows:
|
||||||
# get diffusion row
|
# get diffusion row
|
||||||
diffusion_row = list()
|
diffusion_row = []
|
||||||
z_start = z[:n_row]
|
z_start = z[:n_row]
|
||||||
for t in range(self.num_timesteps):
|
for t in range(self.num_timesteps):
|
||||||
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
|
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
|
||||||
@ -1322,7 +1316,7 @@ class LatentDiffusionV1(DDPMV1):
|
|||||||
|
|
||||||
if inpaint:
|
if inpaint:
|
||||||
# make a simple center square
|
# make a simple center square
|
||||||
b, h, w = z.shape[0], z.shape[2], z.shape[3]
|
h, w = z.shape[2], z.shape[3]
|
||||||
mask = torch.ones(N, h, w).to(self.device)
|
mask = torch.ones(N, h, w).to(self.device)
|
||||||
# zeros will be filled in
|
# zeros will be filled in
|
||||||
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
|
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
|
||||||
@ -1424,10 +1418,10 @@ class Layout2ImgDiffusionV1(LatentDiffusionV1):
|
|||||||
# TODO: move all layout-specific hacks to this class
|
# TODO: move all layout-specific hacks to this class
|
||||||
def __init__(self, cond_stage_key, *args, **kwargs):
|
def __init__(self, cond_stage_key, *args, **kwargs):
|
||||||
assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
|
assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
|
||||||
super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
|
super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs)
|
||||||
|
|
||||||
def log_images(self, batch, N=8, *args, **kwargs):
|
def log_images(self, batch, N=8, *args, **kwargs):
|
||||||
logs = super().log_images(batch=batch, N=N, *args, **kwargs)
|
logs = super().log_images(*args, batch=batch, N=N, **kwargs)
|
||||||
|
|
||||||
key = 'train' if self.training else 'validation'
|
key = 'train' if self.training else 'validation'
|
||||||
dset = self.trainer.datamodule.datasets[key]
|
dset = self.trainer.datamodule.datasets[key]
|
||||||
@ -1443,7 +1437,7 @@ class Layout2ImgDiffusionV1(LatentDiffusionV1):
|
|||||||
logs['bbox_image'] = cond_img
|
logs['bbox_image'] = cond_img
|
||||||
return logs
|
return logs
|
||||||
|
|
||||||
setattr(ldm.models.diffusion.ddpm, "DDPMV1", DDPMV1)
|
ldm.models.diffusion.ddpm.DDPMV1 = DDPMV1
|
||||||
setattr(ldm.models.diffusion.ddpm, "LatentDiffusionV1", LatentDiffusionV1)
|
ldm.models.diffusion.ddpm.LatentDiffusionV1 = LatentDiffusionV1
|
||||||
setattr(ldm.models.diffusion.ddpm, "DiffusionWrapperV1", DiffusionWrapperV1)
|
ldm.models.diffusion.ddpm.DiffusionWrapperV1 = DiffusionWrapperV1
|
||||||
setattr(ldm.models.diffusion.ddpm, "Layout2ImgDiffusionV1", Layout2ImgDiffusionV1)
|
ldm.models.diffusion.ddpm.Layout2ImgDiffusionV1 = Layout2ImgDiffusionV1
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
from modules import extra_networks, shared
|
from modules import extra_networks, shared
|
||||||
import lora
|
import lora
|
||||||
|
|
||||||
|
|
||||||
class ExtraNetworkLora(extra_networks.ExtraNetwork):
|
class ExtraNetworkLora(extra_networks.ExtraNetwork):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__('lora')
|
super().__init__('lora')
|
||||||
|
@ -1,10 +1,9 @@
|
|||||||
import glob
|
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import torch
|
import torch
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
from modules import shared, devices, sd_models, errors
|
from modules import shared, devices, sd_models, errors, scripts
|
||||||
|
|
||||||
metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
|
metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
|
||||||
|
|
||||||
@ -93,6 +92,7 @@ class LoraOnDisk:
|
|||||||
self.metadata = m
|
self.metadata = m
|
||||||
|
|
||||||
self.ssmd_cover_images = self.metadata.pop('ssmd_cover_images', None) # those are cover images and they are too big to display in UI as text
|
self.ssmd_cover_images = self.metadata.pop('ssmd_cover_images', None) # those are cover images and they are too big to display in UI as text
|
||||||
|
self.alias = self.metadata.get('ss_output_name', self.name)
|
||||||
|
|
||||||
|
|
||||||
class LoraModule:
|
class LoraModule:
|
||||||
@ -165,12 +165,14 @@ def load_lora(name, filename):
|
|||||||
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
|
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
|
||||||
elif type(sd_module) == torch.nn.MultiheadAttention:
|
elif type(sd_module) == torch.nn.MultiheadAttention:
|
||||||
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
|
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
|
||||||
elif type(sd_module) == torch.nn.Conv2d:
|
elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (1, 1):
|
||||||
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
|
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
|
||||||
|
elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (3, 3):
|
||||||
|
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (3, 3), bias=False)
|
||||||
else:
|
else:
|
||||||
print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}')
|
print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}')
|
||||||
continue
|
continue
|
||||||
assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'
|
raise AssertionError(f"Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}")
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
module.weight.copy_(weight)
|
module.weight.copy_(weight)
|
||||||
@ -182,7 +184,7 @@ def load_lora(name, filename):
|
|||||||
elif lora_key == "lora_down.weight":
|
elif lora_key == "lora_down.weight":
|
||||||
lora_module.down = module
|
lora_module.down = module
|
||||||
else:
|
else:
|
||||||
assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha'
|
raise AssertionError(f"Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha")
|
||||||
|
|
||||||
if len(keys_failed_to_match) > 0:
|
if len(keys_failed_to_match) > 0:
|
||||||
print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")
|
print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")
|
||||||
@ -199,11 +201,11 @@ def load_loras(names, multipliers=None):
|
|||||||
|
|
||||||
loaded_loras.clear()
|
loaded_loras.clear()
|
||||||
|
|
||||||
loras_on_disk = [available_loras.get(name, None) for name in names]
|
loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
|
||||||
if any([x is None for x in loras_on_disk]):
|
if any(x is None for x in loras_on_disk):
|
||||||
list_available_loras()
|
list_available_loras()
|
||||||
|
|
||||||
loras_on_disk = [available_loras.get(name, None) for name in names]
|
loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
|
||||||
|
|
||||||
for i, name in enumerate(names):
|
for i, name in enumerate(names):
|
||||||
lora = already_loaded.get(name, None)
|
lora = already_loaded.get(name, None)
|
||||||
@ -232,6 +234,8 @@ def lora_calc_updown(lora, module, target):
|
|||||||
|
|
||||||
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
|
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
|
||||||
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
||||||
|
elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
|
||||||
|
updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
|
||||||
else:
|
else:
|
||||||
updown = up @ down
|
updown = up @ down
|
||||||
|
|
||||||
@ -240,6 +244,19 @@ def lora_calc_updown(lora, module, target):
|
|||||||
return updown
|
return updown
|
||||||
|
|
||||||
|
|
||||||
|
def lora_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
|
||||||
|
weights_backup = getattr(self, "lora_weights_backup", None)
|
||||||
|
|
||||||
|
if weights_backup is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if isinstance(self, torch.nn.MultiheadAttention):
|
||||||
|
self.in_proj_weight.copy_(weights_backup[0])
|
||||||
|
self.out_proj.weight.copy_(weights_backup[1])
|
||||||
|
else:
|
||||||
|
self.weight.copy_(weights_backup)
|
||||||
|
|
||||||
|
|
||||||
def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
|
def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
|
||||||
"""
|
"""
|
||||||
Applies the currently selected set of Loras to the weights of torch layer self.
|
Applies the currently selected set of Loras to the weights of torch layer self.
|
||||||
@ -264,12 +281,7 @@ def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.Mu
|
|||||||
self.lora_weights_backup = weights_backup
|
self.lora_weights_backup = weights_backup
|
||||||
|
|
||||||
if current_names != wanted_names:
|
if current_names != wanted_names:
|
||||||
if weights_backup is not None:
|
lora_restore_weights_from_backup(self)
|
||||||
if isinstance(self, torch.nn.MultiheadAttention):
|
|
||||||
self.in_proj_weight.copy_(weights_backup[0])
|
|
||||||
self.out_proj.weight.copy_(weights_backup[1])
|
|
||||||
else:
|
|
||||||
self.weight.copy_(weights_backup)
|
|
||||||
|
|
||||||
for lora in loaded_loras:
|
for lora in loaded_loras:
|
||||||
module = lora.modules.get(lora_layer_name, None)
|
module = lora.modules.get(lora_layer_name, None)
|
||||||
@ -297,15 +309,48 @@ def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.Mu
|
|||||||
|
|
||||||
print(f'failed to calculate lora weights for layer {lora_layer_name}')
|
print(f'failed to calculate lora weights for layer {lora_layer_name}')
|
||||||
|
|
||||||
setattr(self, "lora_current_names", wanted_names)
|
self.lora_current_names = wanted_names
|
||||||
|
|
||||||
|
|
||||||
|
def lora_forward(module, input, original_forward):
|
||||||
|
"""
|
||||||
|
Old way of applying Lora by executing operations during layer's forward.
|
||||||
|
Stacking many loras this way results in big performance degradation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if len(loaded_loras) == 0:
|
||||||
|
return original_forward(module, input)
|
||||||
|
|
||||||
|
input = devices.cond_cast_unet(input)
|
||||||
|
|
||||||
|
lora_restore_weights_from_backup(module)
|
||||||
|
lora_reset_cached_weight(module)
|
||||||
|
|
||||||
|
res = original_forward(module, input)
|
||||||
|
|
||||||
|
lora_layer_name = getattr(module, 'lora_layer_name', None)
|
||||||
|
for lora in loaded_loras:
|
||||||
|
module = lora.modules.get(lora_layer_name, None)
|
||||||
|
if module is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
module.up.to(device=devices.device)
|
||||||
|
module.down.to(device=devices.device)
|
||||||
|
|
||||||
|
res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
|
def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
|
||||||
setattr(self, "lora_current_names", ())
|
self.lora_current_names = ()
|
||||||
setattr(self, "lora_weights_backup", None)
|
self.lora_weights_backup = None
|
||||||
|
|
||||||
|
|
||||||
def lora_Linear_forward(self, input):
|
def lora_Linear_forward(self, input):
|
||||||
|
if shared.opts.lora_functional:
|
||||||
|
return lora_forward(self, input, torch.nn.Linear_forward_before_lora)
|
||||||
|
|
||||||
lora_apply_weights(self)
|
lora_apply_weights(self)
|
||||||
|
|
||||||
return torch.nn.Linear_forward_before_lora(self, input)
|
return torch.nn.Linear_forward_before_lora(self, input)
|
||||||
@ -318,6 +363,9 @@ def lora_Linear_load_state_dict(self, *args, **kwargs):
|
|||||||
|
|
||||||
|
|
||||||
def lora_Conv2d_forward(self, input):
|
def lora_Conv2d_forward(self, input):
|
||||||
|
if shared.opts.lora_functional:
|
||||||
|
return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora)
|
||||||
|
|
||||||
lora_apply_weights(self)
|
lora_apply_weights(self)
|
||||||
|
|
||||||
return torch.nn.Conv2d_forward_before_lora(self, input)
|
return torch.nn.Conv2d_forward_before_lora(self, input)
|
||||||
@ -343,24 +391,65 @@ def lora_MultiheadAttention_load_state_dict(self, *args, **kwargs):
|
|||||||
|
|
||||||
def list_available_loras():
|
def list_available_loras():
|
||||||
available_loras.clear()
|
available_loras.clear()
|
||||||
|
available_lora_aliases.clear()
|
||||||
|
forbidden_lora_aliases.clear()
|
||||||
|
forbidden_lora_aliases.update({"none": 1})
|
||||||
|
|
||||||
os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
|
os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
|
||||||
|
|
||||||
candidates = \
|
candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
|
||||||
glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.pt'), recursive=True) + \
|
|
||||||
glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.safetensors'), recursive=True) + \
|
|
||||||
glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.ckpt'), recursive=True)
|
|
||||||
|
|
||||||
for filename in sorted(candidates, key=str.lower):
|
for filename in sorted(candidates, key=str.lower):
|
||||||
if os.path.isdir(filename):
|
if os.path.isdir(filename):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
name = os.path.splitext(os.path.basename(filename))[0]
|
name = os.path.splitext(os.path.basename(filename))[0]
|
||||||
|
entry = LoraOnDisk(name, filename)
|
||||||
|
|
||||||
available_loras[name] = LoraOnDisk(name, filename)
|
available_loras[name] = entry
|
||||||
|
|
||||||
|
if entry.alias in available_lora_aliases:
|
||||||
|
forbidden_lora_aliases[entry.alias.lower()] = 1
|
||||||
|
|
||||||
|
available_lora_aliases[name] = entry
|
||||||
|
available_lora_aliases[entry.alias] = entry
|
||||||
|
|
||||||
|
|
||||||
|
re_lora_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")
|
||||||
|
|
||||||
|
|
||||||
|
def infotext_pasted(infotext, params):
|
||||||
|
if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]:
|
||||||
|
return # if the other extension is active, it will handle those fields, no need to do anything
|
||||||
|
|
||||||
|
added = []
|
||||||
|
|
||||||
|
for k in params:
|
||||||
|
if not k.startswith("AddNet Model "):
|
||||||
|
continue
|
||||||
|
|
||||||
|
num = k[13:]
|
||||||
|
|
||||||
|
if params.get("AddNet Module " + num) != "LoRA":
|
||||||
|
continue
|
||||||
|
|
||||||
|
name = params.get("AddNet Model " + num)
|
||||||
|
if name is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
m = re_lora_name.match(name)
|
||||||
|
if m:
|
||||||
|
name = m.group(1)
|
||||||
|
|
||||||
|
multiplier = params.get("AddNet Weight A " + num, "1.0")
|
||||||
|
|
||||||
|
added.append(f"<lora:{name}:{multiplier}>")
|
||||||
|
|
||||||
|
if added:
|
||||||
|
params["Prompt"] += "\n" + "".join(added)
|
||||||
|
|
||||||
available_loras = {}
|
available_loras = {}
|
||||||
|
available_lora_aliases = {}
|
||||||
|
forbidden_lora_aliases = {}
|
||||||
loaded_loras = []
|
loaded_loras = []
|
||||||
|
|
||||||
list_available_loras()
|
list_available_loras()
|
||||||
|
@ -1,12 +1,12 @@
|
|||||||
import torch
|
import torch
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
import lora
|
import lora
|
||||||
import extra_networks_lora
|
import extra_networks_lora
|
||||||
import ui_extra_networks_lora
|
import ui_extra_networks_lora
|
||||||
from modules import script_callbacks, ui_extra_networks, extra_networks, shared
|
from modules import script_callbacks, ui_extra_networks, extra_networks, shared
|
||||||
|
|
||||||
|
|
||||||
def unload():
|
def unload():
|
||||||
torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
|
torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
|
||||||
torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora
|
torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora
|
||||||
@ -49,8 +49,34 @@ torch.nn.MultiheadAttention._load_from_state_dict = lora.lora_MultiheadAttention
|
|||||||
script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
|
script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
|
||||||
script_callbacks.on_script_unloaded(unload)
|
script_callbacks.on_script_unloaded(unload)
|
||||||
script_callbacks.on_before_ui(before_ui)
|
script_callbacks.on_before_ui(before_ui)
|
||||||
|
script_callbacks.on_infotext_pasted(lora.infotext_pasted)
|
||||||
|
|
||||||
|
|
||||||
shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
|
shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
|
||||||
"sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
|
"sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None", *lora.available_loras]}, refresh=lora.list_available_loras),
|
||||||
|
"lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
|
||||||
|
shared.options_templates.update(shared.options_section(('compatibility', "Compatibility"), {
|
||||||
|
"lora_functional": shared.OptionInfo(False, "Lora: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension"),
|
||||||
|
}))
|
||||||
|
|
||||||
|
|
||||||
|
def create_lora_json(obj: lora.LoraOnDisk):
|
||||||
|
return {
|
||||||
|
"name": obj.name,
|
||||||
|
"alias": obj.alias,
|
||||||
|
"path": obj.filename,
|
||||||
|
"metadata": obj.metadata,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def api_loras(_: gr.Blocks, app: FastAPI):
|
||||||
|
@app.get("/sdapi/v1/loras")
|
||||||
|
async def get_loras():
|
||||||
|
return [create_lora_json(obj) for obj in lora.available_loras.values()]
|
||||||
|
|
||||||
|
|
||||||
|
script_callbacks.on_app_started(api_loras)
|
||||||
|
|
||||||
|
@ -15,13 +15,19 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
|
|||||||
def list_items(self):
|
def list_items(self):
|
||||||
for name, lora_on_disk in lora.available_loras.items():
|
for name, lora_on_disk in lora.available_loras.items():
|
||||||
path, ext = os.path.splitext(lora_on_disk.filename)
|
path, ext = os.path.splitext(lora_on_disk.filename)
|
||||||
|
|
||||||
|
if shared.opts.lora_preferred_name == "Filename" or lora_on_disk.alias.lower() in lora.forbidden_lora_aliases:
|
||||||
|
alias = name
|
||||||
|
else:
|
||||||
|
alias = lora_on_disk.alias
|
||||||
|
|
||||||
yield {
|
yield {
|
||||||
"name": name,
|
"name": name,
|
||||||
"filename": path,
|
"filename": path,
|
||||||
"preview": self.find_preview(path),
|
"preview": self.find_preview(path),
|
||||||
"description": self.find_description(path),
|
"description": self.find_description(path),
|
||||||
"search_term": self.search_terms_from_path(lora_on_disk.filename),
|
"search_term": self.search_terms_from_path(lora_on_disk.filename),
|
||||||
"prompt": json.dumps(f"<lora:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
|
"prompt": json.dumps(f"<lora:{alias}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
|
||||||
"local_preview": f"{path}.{shared.opts.samples_format}",
|
"local_preview": f"{path}.{shared.opts.samples_format}",
|
||||||
"metadata": json.dumps(lora_on_disk.metadata, indent=4) if lora_on_disk.metadata else None,
|
"metadata": json.dumps(lora_on_disk.metadata, indent=4) if lora_on_disk.metadata else None,
|
||||||
}
|
}
|
||||||
|
@ -10,10 +10,9 @@ from tqdm import tqdm
|
|||||||
from basicsr.utils.download_util import load_file_from_url
|
from basicsr.utils.download_util import load_file_from_url
|
||||||
|
|
||||||
import modules.upscaler
|
import modules.upscaler
|
||||||
from modules import devices, modelloader
|
from modules import devices, modelloader, script_callbacks
|
||||||
from scunet_model_arch import SCUNet as net
|
from scunet_model_arch import SCUNet as net
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
from modules import images
|
|
||||||
|
|
||||||
|
|
||||||
class UpscalerScuNET(modules.upscaler.Upscaler):
|
class UpscalerScuNET(modules.upscaler.Upscaler):
|
||||||
@ -133,8 +132,19 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
|
|||||||
model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
|
model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
|
||||||
model.load_state_dict(torch.load(filename), strict=True)
|
model.load_state_dict(torch.load(filename), strict=True)
|
||||||
model.eval()
|
model.eval()
|
||||||
for k, v in model.named_parameters():
|
for _, v in model.named_parameters():
|
||||||
v.requires_grad = False
|
v.requires_grad = False
|
||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def on_ui_settings():
|
||||||
|
import gradio as gr
|
||||||
|
from modules import shared
|
||||||
|
|
||||||
|
shared.opts.add_option("SCUNET_tile", shared.OptionInfo(256, "Tile size for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")).info("0 = no tiling"))
|
||||||
|
shared.opts.add_option("SCUNET_tile_overlap", shared.OptionInfo(8, "Tile overlap for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, section=('upscaling', "Upscaling")).info("Low values = visible seam"))
|
||||||
|
|
||||||
|
|
||||||
|
script_callbacks.on_ui_settings(on_ui_settings)
|
||||||
|
@ -61,7 +61,9 @@ class WMSA(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
output: tensor shape [b h w c]
|
output: tensor shape [b h w c]
|
||||||
"""
|
"""
|
||||||
if self.type != 'W': x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2))
|
if self.type != 'W':
|
||||||
|
x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2))
|
||||||
|
|
||||||
x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size)
|
x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size)
|
||||||
h_windows = x.size(1)
|
h_windows = x.size(1)
|
||||||
w_windows = x.size(2)
|
w_windows = x.size(2)
|
||||||
@ -85,8 +87,9 @@ class WMSA(nn.Module):
|
|||||||
output = self.linear(output)
|
output = self.linear(output)
|
||||||
output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size)
|
output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size)
|
||||||
|
|
||||||
if self.type != 'W': output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2),
|
if self.type != 'W':
|
||||||
dims=(1, 2))
|
output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2), dims=(1, 2))
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def relative_embedding(self):
|
def relative_embedding(self):
|
||||||
@ -262,4 +265,4 @@ class SCUNet(nn.Module):
|
|||||||
nn.init.constant_(m.bias, 0)
|
nn.init.constant_(m.bias, 0)
|
||||||
elif isinstance(m, nn.LayerNorm):
|
elif isinstance(m, nn.LayerNorm):
|
||||||
nn.init.constant_(m.bias, 0)
|
nn.init.constant_(m.bias, 0)
|
||||||
nn.init.constant_(m.weight, 1.0)
|
nn.init.constant_(m.weight, 1.0)
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
import contextlib
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -8,7 +7,7 @@ from basicsr.utils.download_util import load_file_from_url
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from modules import modelloader, devices, script_callbacks, shared
|
from modules import modelloader, devices, script_callbacks, shared
|
||||||
from modules.shared import cmd_opts, opts, state
|
from modules.shared import opts, state
|
||||||
from swinir_model_arch import SwinIR as net
|
from swinir_model_arch import SwinIR as net
|
||||||
from swinir_model_arch_v2 import Swin2SR as net2
|
from swinir_model_arch_v2 import Swin2SR as net2
|
||||||
from modules.upscaler import Upscaler, UpscalerData
|
from modules.upscaler import Upscaler, UpscalerData
|
||||||
@ -45,7 +44,7 @@ class UpscalerSwinIR(Upscaler):
|
|||||||
img = upscale(img, model)
|
img = upscale(img, model)
|
||||||
try:
|
try:
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
return img
|
return img
|
||||||
|
|
||||||
@ -151,7 +150,7 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
|
|||||||
for w_idx in w_idx_list:
|
for w_idx in w_idx_list:
|
||||||
if state.interrupted or state.skipped:
|
if state.interrupted or state.skipped:
|
||||||
break
|
break
|
||||||
|
|
||||||
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
|
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
|
||||||
out_patch = model(in_patch)
|
out_patch = model(in_patch)
|
||||||
out_patch_mask = torch.ones_like(out_patch)
|
out_patch_mask = torch.ones_like(out_patch)
|
||||||
|
@ -644,7 +644,7 @@ class SwinIR(nn.Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, img_size=64, patch_size=1, in_chans=3,
|
def __init__(self, img_size=64, patch_size=1, in_chans=3,
|
||||||
embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
|
embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
|
||||||
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
|
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
|
||||||
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
|
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
|
||||||
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
|
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
|
||||||
@ -805,7 +805,7 @@ class SwinIR(nn.Module):
|
|||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
H, W = x.shape[2:]
|
H, W = x.shape[2:]
|
||||||
x = self.check_image_size(x)
|
x = self.check_image_size(x)
|
||||||
|
|
||||||
self.mean = self.mean.type_as(x)
|
self.mean = self.mean.type_as(x)
|
||||||
x = (x - self.mean) * self.img_range
|
x = (x - self.mean) * self.img_range
|
||||||
|
|
||||||
@ -844,7 +844,7 @@ class SwinIR(nn.Module):
|
|||||||
H, W = self.patches_resolution
|
H, W = self.patches_resolution
|
||||||
flops += H * W * 3 * self.embed_dim * 9
|
flops += H * W * 3 * self.embed_dim * 9
|
||||||
flops += self.patch_embed.flops()
|
flops += self.patch_embed.flops()
|
||||||
for i, layer in enumerate(self.layers):
|
for layer in self.layers:
|
||||||
flops += layer.flops()
|
flops += layer.flops()
|
||||||
flops += H * W * 3 * self.embed_dim * self.embed_dim
|
flops += H * W * 3 * self.embed_dim * self.embed_dim
|
||||||
flops += self.upsample.flops()
|
flops += self.upsample.flops()
|
||||||
|
@ -74,7 +74,7 @@ class WindowAttention(nn.Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
|
def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
|
||||||
pretrained_window_size=[0, 0]):
|
pretrained_window_size=(0, 0)):
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
@ -241,7 +241,7 @@ class SwinTransformerBlock(nn.Module):
|
|||||||
attn_mask = None
|
attn_mask = None
|
||||||
|
|
||||||
self.register_buffer("attn_mask", attn_mask)
|
self.register_buffer("attn_mask", attn_mask)
|
||||||
|
|
||||||
def calculate_mask(self, x_size):
|
def calculate_mask(self, x_size):
|
||||||
# calculate attention mask for SW-MSA
|
# calculate attention mask for SW-MSA
|
||||||
H, W = x_size
|
H, W = x_size
|
||||||
@ -263,7 +263,7 @@ class SwinTransformerBlock(nn.Module):
|
|||||||
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
||||||
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
||||||
|
|
||||||
return attn_mask
|
return attn_mask
|
||||||
|
|
||||||
def forward(self, x, x_size):
|
def forward(self, x, x_size):
|
||||||
H, W = x_size
|
H, W = x_size
|
||||||
@ -288,7 +288,7 @@ class SwinTransformerBlock(nn.Module):
|
|||||||
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
|
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
|
||||||
else:
|
else:
|
||||||
attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
|
attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
|
||||||
|
|
||||||
# merge windows
|
# merge windows
|
||||||
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
|
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
|
||||||
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
|
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
|
||||||
@ -369,7 +369,7 @@ class PatchMerging(nn.Module):
|
|||||||
H, W = self.input_resolution
|
H, W = self.input_resolution
|
||||||
flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
|
flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
|
||||||
flops += H * W * self.dim // 2
|
flops += H * W * self.dim // 2
|
||||||
return flops
|
return flops
|
||||||
|
|
||||||
class BasicLayer(nn.Module):
|
class BasicLayer(nn.Module):
|
||||||
""" A basic Swin Transformer layer for one stage.
|
""" A basic Swin Transformer layer for one stage.
|
||||||
@ -447,7 +447,7 @@ class BasicLayer(nn.Module):
|
|||||||
nn.init.constant_(blk.norm1.weight, 0)
|
nn.init.constant_(blk.norm1.weight, 0)
|
||||||
nn.init.constant_(blk.norm2.bias, 0)
|
nn.init.constant_(blk.norm2.bias, 0)
|
||||||
nn.init.constant_(blk.norm2.weight, 0)
|
nn.init.constant_(blk.norm2.weight, 0)
|
||||||
|
|
||||||
class PatchEmbed(nn.Module):
|
class PatchEmbed(nn.Module):
|
||||||
r""" Image to Patch Embedding
|
r""" Image to Patch Embedding
|
||||||
Args:
|
Args:
|
||||||
@ -492,7 +492,7 @@ class PatchEmbed(nn.Module):
|
|||||||
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
|
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
|
||||||
if self.norm is not None:
|
if self.norm is not None:
|
||||||
flops += Ho * Wo * self.embed_dim
|
flops += Ho * Wo * self.embed_dim
|
||||||
return flops
|
return flops
|
||||||
|
|
||||||
class RSTB(nn.Module):
|
class RSTB(nn.Module):
|
||||||
"""Residual Swin Transformer Block (RSTB).
|
"""Residual Swin Transformer Block (RSTB).
|
||||||
@ -531,7 +531,7 @@ class RSTB(nn.Module):
|
|||||||
num_heads=num_heads,
|
num_heads=num_heads,
|
||||||
window_size=window_size,
|
window_size=window_size,
|
||||||
mlp_ratio=mlp_ratio,
|
mlp_ratio=mlp_ratio,
|
||||||
qkv_bias=qkv_bias,
|
qkv_bias=qkv_bias,
|
||||||
drop=drop, attn_drop=attn_drop,
|
drop=drop, attn_drop=attn_drop,
|
||||||
drop_path=drop_path,
|
drop_path=drop_path,
|
||||||
norm_layer=norm_layer,
|
norm_layer=norm_layer,
|
||||||
@ -622,7 +622,7 @@ class Upsample(nn.Sequential):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
|
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
|
||||||
super(Upsample, self).__init__(*m)
|
super(Upsample, self).__init__(*m)
|
||||||
|
|
||||||
class Upsample_hf(nn.Sequential):
|
class Upsample_hf(nn.Sequential):
|
||||||
"""Upsample module.
|
"""Upsample module.
|
||||||
|
|
||||||
@ -642,7 +642,7 @@ class Upsample_hf(nn.Sequential):
|
|||||||
m.append(nn.PixelShuffle(3))
|
m.append(nn.PixelShuffle(3))
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
|
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
|
||||||
super(Upsample_hf, self).__init__(*m)
|
super(Upsample_hf, self).__init__(*m)
|
||||||
|
|
||||||
|
|
||||||
class UpsampleOneStep(nn.Sequential):
|
class UpsampleOneStep(nn.Sequential):
|
||||||
@ -667,8 +667,8 @@ class UpsampleOneStep(nn.Sequential):
|
|||||||
H, W = self.input_resolution
|
H, W = self.input_resolution
|
||||||
flops = H * W * self.num_feat * 3 * 9
|
flops = H * W * self.num_feat * 3 * 9
|
||||||
return flops
|
return flops
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Swin2SR(nn.Module):
|
class Swin2SR(nn.Module):
|
||||||
r""" Swin2SR
|
r""" Swin2SR
|
||||||
@ -698,8 +698,8 @@ class Swin2SR(nn.Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, img_size=64, patch_size=1, in_chans=3,
|
def __init__(self, img_size=64, patch_size=1, in_chans=3,
|
||||||
embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
|
embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
|
||||||
window_size=7, mlp_ratio=4., qkv_bias=True,
|
window_size=7, mlp_ratio=4., qkv_bias=True,
|
||||||
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
|
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
|
||||||
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
|
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
|
||||||
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
|
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
|
||||||
@ -764,7 +764,7 @@ class Swin2SR(nn.Module):
|
|||||||
num_heads=num_heads[i_layer],
|
num_heads=num_heads[i_layer],
|
||||||
window_size=window_size,
|
window_size=window_size,
|
||||||
mlp_ratio=self.mlp_ratio,
|
mlp_ratio=self.mlp_ratio,
|
||||||
qkv_bias=qkv_bias,
|
qkv_bias=qkv_bias,
|
||||||
drop=drop_rate, attn_drop=attn_drop_rate,
|
drop=drop_rate, attn_drop=attn_drop_rate,
|
||||||
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
|
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
|
||||||
norm_layer=norm_layer,
|
norm_layer=norm_layer,
|
||||||
@ -776,7 +776,7 @@ class Swin2SR(nn.Module):
|
|||||||
|
|
||||||
)
|
)
|
||||||
self.layers.append(layer)
|
self.layers.append(layer)
|
||||||
|
|
||||||
if self.upsampler == 'pixelshuffle_hf':
|
if self.upsampler == 'pixelshuffle_hf':
|
||||||
self.layers_hf = nn.ModuleList()
|
self.layers_hf = nn.ModuleList()
|
||||||
for i_layer in range(self.num_layers):
|
for i_layer in range(self.num_layers):
|
||||||
@ -787,7 +787,7 @@ class Swin2SR(nn.Module):
|
|||||||
num_heads=num_heads[i_layer],
|
num_heads=num_heads[i_layer],
|
||||||
window_size=window_size,
|
window_size=window_size,
|
||||||
mlp_ratio=self.mlp_ratio,
|
mlp_ratio=self.mlp_ratio,
|
||||||
qkv_bias=qkv_bias,
|
qkv_bias=qkv_bias,
|
||||||
drop=drop_rate, attn_drop=attn_drop_rate,
|
drop=drop_rate, attn_drop=attn_drop_rate,
|
||||||
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
|
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
|
||||||
norm_layer=norm_layer,
|
norm_layer=norm_layer,
|
||||||
@ -799,7 +799,7 @@ class Swin2SR(nn.Module):
|
|||||||
|
|
||||||
)
|
)
|
||||||
self.layers_hf.append(layer)
|
self.layers_hf.append(layer)
|
||||||
|
|
||||||
self.norm = norm_layer(self.num_features)
|
self.norm = norm_layer(self.num_features)
|
||||||
|
|
||||||
# build the last conv layer in deep feature extraction
|
# build the last conv layer in deep feature extraction
|
||||||
@ -829,10 +829,10 @@ class Swin2SR(nn.Module):
|
|||||||
self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
||||||
self.conv_after_aux = nn.Sequential(
|
self.conv_after_aux = nn.Sequential(
|
||||||
nn.Conv2d(3, num_feat, 3, 1, 1),
|
nn.Conv2d(3, num_feat, 3, 1, 1),
|
||||||
nn.LeakyReLU(inplace=True))
|
nn.LeakyReLU(inplace=True))
|
||||||
self.upsample = Upsample(upscale, num_feat)
|
self.upsample = Upsample(upscale, num_feat)
|
||||||
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
||||||
|
|
||||||
elif self.upsampler == 'pixelshuffle_hf':
|
elif self.upsampler == 'pixelshuffle_hf':
|
||||||
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
||||||
nn.LeakyReLU(inplace=True))
|
nn.LeakyReLU(inplace=True))
|
||||||
@ -846,7 +846,7 @@ class Swin2SR(nn.Module):
|
|||||||
nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
||||||
nn.LeakyReLU(inplace=True))
|
nn.LeakyReLU(inplace=True))
|
||||||
self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
||||||
|
|
||||||
elif self.upsampler == 'pixelshuffledirect':
|
elif self.upsampler == 'pixelshuffledirect':
|
||||||
# for lightweight SR (to save parameters)
|
# for lightweight SR (to save parameters)
|
||||||
self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
|
self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
|
||||||
@ -905,7 +905,7 @@ class Swin2SR(nn.Module):
|
|||||||
x = self.patch_unembed(x, x_size)
|
x = self.patch_unembed(x, x_size)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def forward_features_hf(self, x):
|
def forward_features_hf(self, x):
|
||||||
x_size = (x.shape[2], x.shape[3])
|
x_size = (x.shape[2], x.shape[3])
|
||||||
x = self.patch_embed(x)
|
x = self.patch_embed(x)
|
||||||
@ -919,7 +919,7 @@ class Swin2SR(nn.Module):
|
|||||||
x = self.norm(x) # B L C
|
x = self.norm(x) # B L C
|
||||||
x = self.patch_unembed(x, x_size)
|
x = self.patch_unembed(x, x_size)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
H, W = x.shape[2:]
|
H, W = x.shape[2:]
|
||||||
@ -951,7 +951,7 @@ class Swin2SR(nn.Module):
|
|||||||
x = self.conv_after_body(self.forward_features(x)) + x
|
x = self.conv_after_body(self.forward_features(x)) + x
|
||||||
x_before = self.conv_before_upsample(x)
|
x_before = self.conv_before_upsample(x)
|
||||||
x_out = self.conv_last(self.upsample(x_before))
|
x_out = self.conv_last(self.upsample(x_before))
|
||||||
|
|
||||||
x_hf = self.conv_first_hf(x_before)
|
x_hf = self.conv_first_hf(x_before)
|
||||||
x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf
|
x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf
|
||||||
x_hf = self.conv_before_upsample_hf(x_hf)
|
x_hf = self.conv_before_upsample_hf(x_hf)
|
||||||
@ -977,15 +977,15 @@ class Swin2SR(nn.Module):
|
|||||||
x_first = self.conv_first(x)
|
x_first = self.conv_first(x)
|
||||||
res = self.conv_after_body(self.forward_features(x_first)) + x_first
|
res = self.conv_after_body(self.forward_features(x_first)) + x_first
|
||||||
x = x + self.conv_last(res)
|
x = x + self.conv_last(res)
|
||||||
|
|
||||||
x = x / self.img_range + self.mean
|
x = x / self.img_range + self.mean
|
||||||
if self.upsampler == "pixelshuffle_aux":
|
if self.upsampler == "pixelshuffle_aux":
|
||||||
return x[:, :, :H*self.upscale, :W*self.upscale], aux
|
return x[:, :, :H*self.upscale, :W*self.upscale], aux
|
||||||
|
|
||||||
elif self.upsampler == "pixelshuffle_hf":
|
elif self.upsampler == "pixelshuffle_hf":
|
||||||
x_out = x_out / self.img_range + self.mean
|
x_out = x_out / self.img_range + self.mean
|
||||||
return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale]
|
return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale]
|
||||||
|
|
||||||
else:
|
else:
|
||||||
return x[:, :, :H*self.upscale, :W*self.upscale]
|
return x[:, :, :H*self.upscale, :W*self.upscale]
|
||||||
|
|
||||||
@ -994,7 +994,7 @@ class Swin2SR(nn.Module):
|
|||||||
H, W = self.patches_resolution
|
H, W = self.patches_resolution
|
||||||
flops += H * W * 3 * self.embed_dim * 9
|
flops += H * W * 3 * self.embed_dim * 9
|
||||||
flops += self.patch_embed.flops()
|
flops += self.patch_embed.flops()
|
||||||
for i, layer in enumerate(self.layers):
|
for layer in self.layers:
|
||||||
flops += layer.flops()
|
flops += layer.flops()
|
||||||
flops += H * W * 3 * self.embed_dim * self.embed_dim
|
flops += H * W * 3 * self.embed_dim * self.embed_dim
|
||||||
flops += self.upsample.flops()
|
flops += self.upsample.flops()
|
||||||
@ -1014,4 +1014,4 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
x = torch.randn((1, 3, height, width))
|
x = torch.randn((1, 3, height, width))
|
||||||
x = model(x)
|
x = model(x)
|
||||||
print(x.shape)
|
print(x.shape)
|
||||||
|
@ -6,7 +6,7 @@
|
|||||||
<ul>
|
<ul>
|
||||||
<a href="#" title="replace preview image with currently selected in gallery" onclick={save_card_preview}>replace preview</a>
|
<a href="#" title="replace preview image with currently selected in gallery" onclick={save_card_preview}>replace preview</a>
|
||||||
</ul>
|
</ul>
|
||||||
<span style="display:none" class='search_term'>{search_term}</span>
|
<span style="display:none" class='search_term{serach_only}'>{search_term}</span>
|
||||||
</div>
|
</div>
|
||||||
<span class='name'>{name}</span>
|
<span class='name'>{name}</span>
|
||||||
<span class='description'>{description}</span>
|
<span class='description'>{description}</span>
|
||||||
|
@ -45,29 +45,24 @@ function dimensionChange(e, is_width, is_height){
|
|||||||
|
|
||||||
var viewportOffset = targetElement.getBoundingClientRect();
|
var viewportOffset = targetElement.getBoundingClientRect();
|
||||||
|
|
||||||
viewportscale = Math.min( targetElement.clientWidth/targetElement.naturalWidth, targetElement.clientHeight/targetElement.naturalHeight )
|
var viewportscale = Math.min( targetElement.clientWidth/targetElement.naturalWidth, targetElement.clientHeight/targetElement.naturalHeight )
|
||||||
|
|
||||||
scaledx = targetElement.naturalWidth*viewportscale
|
var scaledx = targetElement.naturalWidth*viewportscale
|
||||||
scaledy = targetElement.naturalHeight*viewportscale
|
var scaledy = targetElement.naturalHeight*viewportscale
|
||||||
|
|
||||||
cleintRectTop = (viewportOffset.top+window.scrollY)
|
var cleintRectTop = (viewportOffset.top+window.scrollY)
|
||||||
cleintRectLeft = (viewportOffset.left+window.scrollX)
|
var cleintRectLeft = (viewportOffset.left+window.scrollX)
|
||||||
cleintRectCentreY = cleintRectTop + (targetElement.clientHeight/2)
|
var cleintRectCentreY = cleintRectTop + (targetElement.clientHeight/2)
|
||||||
cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth/2)
|
var cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth/2)
|
||||||
|
|
||||||
viewRectTop = cleintRectCentreY-(scaledy/2)
|
var arscale = Math.min( scaledx/currentWidth, scaledy/currentHeight )
|
||||||
viewRectLeft = cleintRectCentreX-(scaledx/2)
|
var arscaledx = currentWidth*arscale
|
||||||
arRectWidth = scaledx
|
var arscaledy = currentHeight*arscale
|
||||||
arRectHeight = scaledy
|
|
||||||
|
|
||||||
arscale = Math.min( arRectWidth/currentWidth, arRectHeight/currentHeight )
|
var arRectTop = cleintRectCentreY-(arscaledy/2)
|
||||||
arscaledx = currentWidth*arscale
|
var arRectLeft = cleintRectCentreX-(arscaledx/2)
|
||||||
arscaledy = currentHeight*arscale
|
var arRectWidth = arscaledx
|
||||||
|
var arRectHeight = arscaledy
|
||||||
arRectTop = cleintRectCentreY-(arscaledy/2)
|
|
||||||
arRectLeft = cleintRectCentreX-(arscaledx/2)
|
|
||||||
arRectWidth = arscaledx
|
|
||||||
arRectHeight = arscaledy
|
|
||||||
|
|
||||||
arPreviewRect.style.top = arRectTop+'px';
|
arPreviewRect.style.top = arRectTop+'px';
|
||||||
arPreviewRect.style.left = arRectLeft+'px';
|
arPreviewRect.style.left = arRectLeft+'px';
|
||||||
|
@ -4,7 +4,7 @@ contextMenuInit = function(){
|
|||||||
let menuSpecs = new Map();
|
let menuSpecs = new Map();
|
||||||
|
|
||||||
const uid = function(){
|
const uid = function(){
|
||||||
return Date.now().toString(36) + Math.random().toString(36).substr(2);
|
return Date.now().toString(36) + Math.random().toString(36).substring(2);
|
||||||
}
|
}
|
||||||
|
|
||||||
function showContextMenu(event,element,menuEntries){
|
function showContextMenu(event,element,menuEntries){
|
||||||
@ -16,8 +16,7 @@ contextMenuInit = function(){
|
|||||||
oldMenu.remove()
|
oldMenu.remove()
|
||||||
}
|
}
|
||||||
|
|
||||||
let tabButton = uiCurrentTab
|
let baseStyle = window.getComputedStyle(uiCurrentTab)
|
||||||
let baseStyle = window.getComputedStyle(tabButton)
|
|
||||||
|
|
||||||
const contextMenu = document.createElement('nav')
|
const contextMenu = document.createElement('nav')
|
||||||
contextMenu.id = "context-menu"
|
contextMenu.id = "context-menu"
|
||||||
@ -36,7 +35,7 @@ contextMenuInit = function(){
|
|||||||
menuEntries.forEach(function(entry){
|
menuEntries.forEach(function(entry){
|
||||||
let contextMenuEntry = document.createElement('a')
|
let contextMenuEntry = document.createElement('a')
|
||||||
contextMenuEntry.innerHTML = entry['name']
|
contextMenuEntry.innerHTML = entry['name']
|
||||||
contextMenuEntry.addEventListener("click", function(e) {
|
contextMenuEntry.addEventListener("click", function() {
|
||||||
entry['func']();
|
entry['func']();
|
||||||
})
|
})
|
||||||
contextMenuList.append(contextMenuEntry);
|
contextMenuList.append(contextMenuEntry);
|
||||||
@ -63,7 +62,7 @@ contextMenuInit = function(){
|
|||||||
|
|
||||||
function appendContextMenuOption(targetElementSelector,entryName,entryFunction){
|
function appendContextMenuOption(targetElementSelector,entryName,entryFunction){
|
||||||
|
|
||||||
currentItems = menuSpecs.get(targetElementSelector)
|
var currentItems = menuSpecs.get(targetElementSelector)
|
||||||
|
|
||||||
if(!currentItems){
|
if(!currentItems){
|
||||||
currentItems = []
|
currentItems = []
|
||||||
@ -79,7 +78,7 @@ contextMenuInit = function(){
|
|||||||
}
|
}
|
||||||
|
|
||||||
function removeContextMenuOption(uid){
|
function removeContextMenuOption(uid){
|
||||||
menuSpecs.forEach(function(v,k) {
|
menuSpecs.forEach(function(v) {
|
||||||
let index = -1
|
let index = -1
|
||||||
v.forEach(function(e,ei){if(e['id']==uid){index=ei}})
|
v.forEach(function(e,ei){if(e['id']==uid){index=ei}})
|
||||||
if(index>=0){
|
if(index>=0){
|
||||||
@ -93,8 +92,7 @@ contextMenuInit = function(){
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
gradioApp().addEventListener("click", function(e) {
|
gradioApp().addEventListener("click", function(e) {
|
||||||
let source = e.composedPath()[0]
|
if(! e.isTrusted){
|
||||||
if(source.id && source.id.indexOf('check_progress')>-1){
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -112,7 +110,6 @@ contextMenuInit = function(){
|
|||||||
if(e.composedPath()[0].matches(k)){
|
if(e.composedPath()[0].matches(k)){
|
||||||
showContextMenu(e,e.composedPath()[0],v)
|
showContextMenu(e,e.composedPath()[0],v)
|
||||||
e.preventDefault()
|
e.preventDefault()
|
||||||
return
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
});
|
});
|
||||||
|
@ -69,8 +69,8 @@ function keyupEditAttention(event){
|
|||||||
|
|
||||||
event.preventDefault();
|
event.preventDefault();
|
||||||
|
|
||||||
closeCharacter = ')'
|
var closeCharacter = ')'
|
||||||
delta = opts.keyedit_precision_attention
|
var delta = opts.keyedit_precision_attention
|
||||||
|
|
||||||
if (selectionStart > 0 && text[selectionStart - 1] == '<'){
|
if (selectionStart > 0 && text[selectionStart - 1] == '<'){
|
||||||
closeCharacter = '>'
|
closeCharacter = '>'
|
||||||
@ -91,8 +91,8 @@ function keyupEditAttention(event){
|
|||||||
selectionEnd += 1;
|
selectionEnd += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
|
var end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
|
||||||
weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end));
|
var weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end));
|
||||||
if (isNaN(weight)) return;
|
if (isNaN(weight)) return;
|
||||||
|
|
||||||
weight += isPlus ? delta : -delta;
|
weight += isPlus ? delta : -delta;
|
||||||
|
@ -1,14 +1,14 @@
|
|||||||
|
|
||||||
function extensions_apply(_, _, disable_all){
|
function extensions_apply(_disabled_list, _update_list, disable_all){
|
||||||
var disable = []
|
var disable = []
|
||||||
var update = []
|
var update = []
|
||||||
|
|
||||||
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
|
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
|
||||||
if(x.name.startsWith("enable_") && ! x.checked)
|
if(x.name.startsWith("enable_") && ! x.checked)
|
||||||
disable.push(x.name.substr(7))
|
disable.push(x.name.substring(7))
|
||||||
|
|
||||||
if(x.name.startsWith("update_") && x.checked)
|
if(x.name.startsWith("update_") && x.checked)
|
||||||
update.push(x.name.substr(7))
|
update.push(x.name.substring(7))
|
||||||
})
|
})
|
||||||
|
|
||||||
restart_reload()
|
restart_reload()
|
||||||
@ -16,12 +16,12 @@ function extensions_apply(_, _, disable_all){
|
|||||||
return [JSON.stringify(disable), JSON.stringify(update), disable_all]
|
return [JSON.stringify(disable), JSON.stringify(update), disable_all]
|
||||||
}
|
}
|
||||||
|
|
||||||
function extensions_check(_, _){
|
function extensions_check(){
|
||||||
var disable = []
|
var disable = []
|
||||||
|
|
||||||
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
|
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
|
||||||
if(x.name.startsWith("enable_") && ! x.checked)
|
if(x.name.startsWith("enable_") && ! x.checked)
|
||||||
disable.push(x.name.substr(7))
|
disable.push(x.name.substring(7))
|
||||||
})
|
})
|
||||||
|
|
||||||
gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
|
gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
|
||||||
@ -41,7 +41,7 @@ function install_extension_from_index(button, url){
|
|||||||
button.disabled = "disabled"
|
button.disabled = "disabled"
|
||||||
button.value = "Installing..."
|
button.value = "Installing..."
|
||||||
|
|
||||||
textarea = gradioApp().querySelector('#extension_to_install textarea')
|
var textarea = gradioApp().querySelector('#extension_to_install textarea')
|
||||||
textarea.value = url
|
textarea.value = url
|
||||||
updateInput(textarea)
|
updateInput(textarea)
|
||||||
|
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
|
|
||||||
function setupExtraNetworksForTab(tabname){
|
function setupExtraNetworksForTab(tabname){
|
||||||
gradioApp().querySelector('#'+tabname+'_extra_tabs').classList.add('extra-networks')
|
gradioApp().querySelector('#'+tabname+'_extra_tabs').classList.add('extra-networks')
|
||||||
|
|
||||||
@ -10,16 +9,34 @@ function setupExtraNetworksForTab(tabname){
|
|||||||
tabs.appendChild(search)
|
tabs.appendChild(search)
|
||||||
tabs.appendChild(refresh)
|
tabs.appendChild(refresh)
|
||||||
|
|
||||||
search.addEventListener("input", function(evt){
|
var applyFilter = function(){
|
||||||
searchTerm = search.value.toLowerCase()
|
var searchTerm = search.value.toLowerCase()
|
||||||
|
|
||||||
gradioApp().querySelectorAll('#'+tabname+'_extra_tabs div.card').forEach(function(elem){
|
gradioApp().querySelectorAll('#'+tabname+'_extra_tabs div.card').forEach(function(elem){
|
||||||
text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase()
|
var searchOnly = elem.querySelector('.search_only')
|
||||||
elem.style.display = text.indexOf(searchTerm) == -1 ? "none" : ""
|
var text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase()
|
||||||
|
|
||||||
|
var visible = text.indexOf(searchTerm) != -1
|
||||||
|
|
||||||
|
if(searchOnly && searchTerm.length < 4){
|
||||||
|
visible = false
|
||||||
|
}
|
||||||
|
|
||||||
|
elem.style.display = visible ? "" : "none"
|
||||||
})
|
})
|
||||||
});
|
}
|
||||||
|
|
||||||
|
search.addEventListener("input", applyFilter);
|
||||||
|
applyFilter();
|
||||||
|
|
||||||
|
extraNetworksApplyFilter[tabname] = applyFilter;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function applyExtraNetworkFilter(tabname){
|
||||||
|
setTimeout(extraNetworksApplyFilter[tabname], 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
var extraNetworksApplyFilter = {}
|
||||||
var activePromptTextarea = {};
|
var activePromptTextarea = {};
|
||||||
|
|
||||||
function setupExtraNetworks(){
|
function setupExtraNetworks(){
|
||||||
@ -55,7 +72,7 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text){
|
|||||||
|
|
||||||
var partToSearch = m[1]
|
var partToSearch = m[1]
|
||||||
var replaced = false
|
var replaced = false
|
||||||
var newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found, index){
|
var newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found){
|
||||||
m = found.match(re_extranet);
|
m = found.match(re_extranet);
|
||||||
if(m[1] == partToSearch){
|
if(m[1] == partToSearch){
|
||||||
replaced = true;
|
replaced = true;
|
||||||
@ -96,9 +113,9 @@ function saveCardPreview(event, tabname, filename){
|
|||||||
}
|
}
|
||||||
|
|
||||||
function extraNetworksSearchButton(tabs_id, event){
|
function extraNetworksSearchButton(tabs_id, event){
|
||||||
searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea')
|
var searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea')
|
||||||
button = event.target
|
var button = event.target
|
||||||
text = button.classList.contains("search-all") ? "" : button.textContent.trim()
|
var text = button.classList.contains("search-all") ? "" : button.textContent.trim()
|
||||||
|
|
||||||
searchTextarea.value = text
|
searchTextarea.value = text
|
||||||
updateInput(searchTextarea)
|
updateInput(searchTextarea)
|
||||||
@ -133,7 +150,7 @@ function popup(contents){
|
|||||||
}
|
}
|
||||||
|
|
||||||
function extraNetworksShowMetadata(text){
|
function extraNetworksShowMetadata(text){
|
||||||
elem = document.createElement('pre')
|
var elem = document.createElement('pre')
|
||||||
elem.classList.add('popup-metadata');
|
elem.classList.add('popup-metadata');
|
||||||
elem.textContent = text;
|
elem.textContent = text;
|
||||||
|
|
||||||
@ -165,7 +182,7 @@ function requestGet(url, data, handler, errorHandler){
|
|||||||
}
|
}
|
||||||
|
|
||||||
function extraNetworksRequestMetadata(event, extraPage, cardName){
|
function extraNetworksRequestMetadata(event, extraPage, cardName){
|
||||||
showError = function(){ extraNetworksShowMetadata("there was an error getting metadata"); }
|
var showError = function(){ extraNetworksShowMetadata("there was an error getting metadata"); }
|
||||||
|
|
||||||
requestGet("./sd_extra_networks/metadata", {"page": extraPage, "item": cardName}, function(data){
|
requestGet("./sd_extra_networks/metadata", {"page": extraPage, "item": cardName}, function(data){
|
||||||
if(data && data.metadata){
|
if(data && data.metadata){
|
||||||
|
@ -23,7 +23,7 @@ let modalObserver = new MutationObserver(function(mutations) {
|
|||||||
});
|
});
|
||||||
|
|
||||||
function attachGalleryListeners(tab_name) {
|
function attachGalleryListeners(tab_name) {
|
||||||
gallery = gradioApp().querySelector('#'+tab_name+'_gallery')
|
var gallery = gradioApp().querySelector('#'+tab_name+'_gallery')
|
||||||
gallery?.addEventListener('click', () => gradioApp().getElementById(tab_name+"_generation_info_button").click());
|
gallery?.addEventListener('click', () => gradioApp().getElementById(tab_name+"_generation_info_button").click());
|
||||||
gallery?.addEventListener('keydown', (e) => {
|
gallery?.addEventListener('keydown', (e) => {
|
||||||
if (e.keyCode == 37 || e.keyCode == 39) // left or right arrow
|
if (e.keyCode == 37 || e.keyCode == 39) // left or right arrow
|
||||||
|
@ -66,8 +66,8 @@ titles = {
|
|||||||
|
|
||||||
"Interrogate": "Reconstruct prompt from existing image and put it into the prompt field.",
|
"Interrogate": "Reconstruct prompt from existing image and put it into the prompt field.",
|
||||||
|
|
||||||
"Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [clip_skip], [batch_number], [generation_number], [prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp], [hasprompt<prompt1|default><prompt2>..]; leave empty for default.",
|
"Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [denoising], [clip_skip], [batch_number], [generation_number], [prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp], [hasprompt<prompt1|default><prompt2>..]; leave empty for default.",
|
||||||
"Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg], [clip_skip], [batch_number], [generation_number], [prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp], [hasprompt<prompt1|default><prompt2>..]; leave empty for default.",
|
"Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg], [denoising], [clip_skip], [batch_number], [generation_number], [prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp], [hasprompt<prompt1|default><prompt2>..]; leave empty for default.",
|
||||||
"Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle",
|
"Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle",
|
||||||
|
|
||||||
"Loopback": "Performs img2img processing multiple times. Output images are used as input for the next loop.",
|
"Loopback": "Performs img2img processing multiple times. Output images are used as input for the next loop.",
|
||||||
@ -118,16 +118,18 @@ titles = {
|
|||||||
|
|
||||||
onUiUpdate(function(){
|
onUiUpdate(function(){
|
||||||
gradioApp().querySelectorAll('span, button, select, p').forEach(function(span){
|
gradioApp().querySelectorAll('span, button, select, p').forEach(function(span){
|
||||||
tooltip = titles[span.textContent];
|
if (span.title) return; // already has a title
|
||||||
|
|
||||||
if(!tooltip){
|
let tooltip = localization[titles[span.textContent]] || titles[span.textContent];
|
||||||
tooltip = titles[span.value];
|
|
||||||
|
if(!tooltip){
|
||||||
|
tooltip = localization[titles[span.value]] || titles[span.value];
|
||||||
}
|
}
|
||||||
|
|
||||||
if(!tooltip){
|
if(!tooltip){
|
||||||
for (const c of span.classList) {
|
for (const c of span.classList) {
|
||||||
if (c in titles) {
|
if (c in titles) {
|
||||||
tooltip = titles[c];
|
tooltip = localization[titles[c]] || titles[c];
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -142,7 +144,7 @@ onUiUpdate(function(){
|
|||||||
if (select.onchange != null) return;
|
if (select.onchange != null) return;
|
||||||
|
|
||||||
select.onchange = function(){
|
select.onchange = function(){
|
||||||
select.title = titles[select.value] || "";
|
select.title = localization[titles[select.value]] || titles[select.value] || "";
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
@ -1,16 +1,12 @@
|
|||||||
|
|
||||||
function setInactive(elem, inactive){
|
|
||||||
if(inactive){
|
|
||||||
elem.classList.add('inactive')
|
|
||||||
} else{
|
|
||||||
elem.classList.remove('inactive')
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y){
|
function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y){
|
||||||
hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale')
|
function setInactive(elem, inactive){
|
||||||
hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x')
|
elem.classList.toggle('inactive', !!inactive)
|
||||||
hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y')
|
}
|
||||||
|
|
||||||
|
var hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale')
|
||||||
|
var hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x')
|
||||||
|
var hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y')
|
||||||
|
|
||||||
gradioApp().getElementById('txt2img_hires_fix_row2').style.display = opts.use_old_hires_fix_width_height ? "none" : ""
|
gradioApp().getElementById('txt2img_hires_fix_row2').style.display = opts.use_old_hires_fix_width_height ? "none" : ""
|
||||||
|
|
||||||
|
@ -2,11 +2,10 @@
|
|||||||
* temporary fix for https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/668
|
* temporary fix for https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/668
|
||||||
* @see https://github.com/gradio-app/gradio/issues/1721
|
* @see https://github.com/gradio-app/gradio/issues/1721
|
||||||
*/
|
*/
|
||||||
window.addEventListener( 'resize', () => imageMaskResize());
|
|
||||||
function imageMaskResize() {
|
function imageMaskResize() {
|
||||||
const canvases = gradioApp().querySelectorAll('#img2maskimg .touch-none canvas');
|
const canvases = gradioApp().querySelectorAll('#img2maskimg .touch-none canvas');
|
||||||
if ( ! canvases.length ) {
|
if ( ! canvases.length ) {
|
||||||
canvases_fixed = false;
|
canvases_fixed = false; // TODO: this is unused..?
|
||||||
window.removeEventListener( 'resize', imageMaskResize );
|
window.removeEventListener( 'resize', imageMaskResize );
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -15,7 +14,7 @@ function imageMaskResize() {
|
|||||||
const previewImage = wrapper.previousElementSibling;
|
const previewImage = wrapper.previousElementSibling;
|
||||||
|
|
||||||
if ( ! previewImage.complete ) {
|
if ( ! previewImage.complete ) {
|
||||||
previewImage.addEventListener( 'load', () => imageMaskResize());
|
previewImage.addEventListener( 'load', imageMaskResize);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -24,7 +23,6 @@ function imageMaskResize() {
|
|||||||
const nw = previewImage.naturalWidth;
|
const nw = previewImage.naturalWidth;
|
||||||
const nh = previewImage.naturalHeight;
|
const nh = previewImage.naturalHeight;
|
||||||
const portrait = nh > nw;
|
const portrait = nh > nw;
|
||||||
const factor = portrait;
|
|
||||||
|
|
||||||
const wW = Math.min(w, portrait ? h/nh*nw : w/nw*nw);
|
const wW = Math.min(w, portrait ? h/nh*nw : w/nw*nw);
|
||||||
const wH = Math.min(h, portrait ? h/nh*nh : w/nw*nh);
|
const wH = Math.min(h, portrait ? h/nh*nh : w/nw*nh);
|
||||||
@ -40,6 +38,7 @@ function imageMaskResize() {
|
|||||||
c.style.maxHeight = '100%';
|
c.style.maxHeight = '100%';
|
||||||
c.style.objectFit = 'contain';
|
c.style.objectFit = 'contain';
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
onUiUpdate(() => imageMaskResize());
|
onUiUpdate(imageMaskResize);
|
||||||
|
window.addEventListener( 'resize', imageMaskResize);
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
window.onload = (function(){
|
window.onload = (function(){
|
||||||
window.addEventListener('drop', e => {
|
window.addEventListener('drop', e => {
|
||||||
const target = e.composedPath()[0];
|
const target = e.composedPath()[0];
|
||||||
const idx = selected_gallery_index();
|
|
||||||
if (target.placeholder.indexOf("Prompt") == -1) return;
|
if (target.placeholder.indexOf("Prompt") == -1) return;
|
||||||
|
|
||||||
let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image";
|
let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image";
|
||||||
|
@ -57,7 +57,7 @@ function modalImageSwitch(offset) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
if (result != -1) {
|
if (result != -1) {
|
||||||
nextButton = galleryButtons[negmod((result + offset), galleryButtons.length)]
|
var nextButton = galleryButtons[negmod((result + offset), galleryButtons.length)]
|
||||||
nextButton.click()
|
nextButton.click()
|
||||||
const modalImage = gradioApp().getElementById("modalImage");
|
const modalImage = gradioApp().getElementById("modalImage");
|
||||||
const modal = gradioApp().getElementById("lightboxModal");
|
const modal = gradioApp().getElementById("lightboxModal");
|
||||||
@ -144,15 +144,11 @@ function setupImageForLightbox(e) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
function modalZoomSet(modalImage, enable) {
|
function modalZoomSet(modalImage, enable) {
|
||||||
if (enable) {
|
if(modalImage) modalImage.classList.toggle('modalImageFullscreen', !!enable);
|
||||||
modalImage.classList.add('modalImageFullscreen');
|
|
||||||
} else {
|
|
||||||
modalImage.classList.remove('modalImageFullscreen');
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function modalZoomToggle(event) {
|
function modalZoomToggle(event) {
|
||||||
modalImage = gradioApp().getElementById("modalImage");
|
var modalImage = gradioApp().getElementById("modalImage");
|
||||||
modalZoomSet(modalImage, !modalImage.classList.contains('modalImageFullscreen'))
|
modalZoomSet(modalImage, !modalImage.classList.contains('modalImageFullscreen'))
|
||||||
event.stopPropagation()
|
event.stopPropagation()
|
||||||
}
|
}
|
||||||
@ -179,7 +175,7 @@ function galleryImageHandler(e) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
onUiUpdate(function() {
|
onUiUpdate(function() {
|
||||||
fullImg_preview = gradioApp().querySelectorAll('.gradio-gallery > div > img')
|
var fullImg_preview = gradioApp().querySelectorAll('.gradio-gallery > div > img')
|
||||||
if (fullImg_preview != null) {
|
if (fullImg_preview != null) {
|
||||||
fullImg_preview.forEach(setupImageForLightbox);
|
fullImg_preview.forEach(setupImageForLightbox);
|
||||||
}
|
}
|
||||||
|
@ -1,36 +1,57 @@
|
|||||||
let delay = 350//ms
|
window.addEventListener('gamepadconnected', (e) => {
|
||||||
window.addEventListener('gamepadconnected', (e) => {
|
const index = e.gamepad.index;
|
||||||
console.log("Gamepad connected!")
|
let isWaiting = false;
|
||||||
const gamepad = e.gamepad;
|
setInterval(async () => {
|
||||||
setInterval(() => {
|
if (!opts.js_modal_lightbox_gamepad || isWaiting) return;
|
||||||
const xValue = gamepad.axes[0].toFixed(2);
|
const gamepad = navigator.getGamepads()[index];
|
||||||
if (xValue < -0.3) {
|
const xValue = gamepad.axes[0];
|
||||||
modalPrevImage(e);
|
if (xValue <= -0.3) {
|
||||||
} else if (xValue > 0.3) {
|
|
||||||
modalNextImage(e);
|
|
||||||
}
|
|
||||||
|
|
||||||
}, delay);
|
|
||||||
});
|
|
||||||
|
|
||||||
|
|
||||||
/*
|
|
||||||
Primarily for vr controller type pointer devices.
|
|
||||||
I use the wheel event because there's currently no way to do it properly with web xr.
|
|
||||||
*/
|
|
||||||
|
|
||||||
let isScrolling = false;
|
|
||||||
window.addEventListener('wheel', (e) => {
|
|
||||||
if (isScrolling) return;
|
|
||||||
isScrolling = true;
|
|
||||||
|
|
||||||
if (e.deltaX <= -0.6) {
|
|
||||||
modalPrevImage(e);
|
modalPrevImage(e);
|
||||||
} else if (e.deltaX >= 0.6) {
|
isWaiting = true;
|
||||||
|
} else if (xValue >= 0.3) {
|
||||||
modalNextImage(e);
|
modalNextImage(e);
|
||||||
|
isWaiting = true;
|
||||||
}
|
}
|
||||||
|
if (isWaiting) {
|
||||||
|
await sleepUntil(() => {
|
||||||
|
const xValue = navigator.getGamepads()[index].axes[0]
|
||||||
|
if (xValue < 0.3 && xValue > -0.3) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}, opts.js_modal_lightbox_gamepad_repeat);
|
||||||
|
isWaiting = false;
|
||||||
|
}
|
||||||
|
}, 10);
|
||||||
|
});
|
||||||
|
|
||||||
setTimeout(() => {
|
/*
|
||||||
isScrolling = false;
|
Primarily for vr controller type pointer devices.
|
||||||
}, delay);
|
I use the wheel event because there's currently no way to do it properly with web xr.
|
||||||
});
|
*/
|
||||||
|
let isScrolling = false;
|
||||||
|
window.addEventListener('wheel', (e) => {
|
||||||
|
if (!opts.js_modal_lightbox_gamepad || isScrolling) return;
|
||||||
|
isScrolling = true;
|
||||||
|
|
||||||
|
if (e.deltaX <= -0.6) {
|
||||||
|
modalPrevImage(e);
|
||||||
|
} else if (e.deltaX >= 0.6) {
|
||||||
|
modalNextImage(e);
|
||||||
|
}
|
||||||
|
|
||||||
|
setTimeout(() => {
|
||||||
|
isScrolling = false;
|
||||||
|
}, opts.js_modal_lightbox_gamepad_repeat);
|
||||||
|
});
|
||||||
|
|
||||||
|
function sleepUntil(f, timeout) {
|
||||||
|
return new Promise((resolve) => {
|
||||||
|
const timeStart = new Date();
|
||||||
|
const wait = setInterval(function() {
|
||||||
|
if (f() || new Date() - timeStart > timeout) {
|
||||||
|
clearInterval(wait);
|
||||||
|
resolve();
|
||||||
|
}
|
||||||
|
}, 20);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
@ -25,6 +25,10 @@ re_emoji = /[\p{Extended_Pictographic}\u{1F3FB}-\u{1F3FF}\u{1F9B0}-\u{1F9B3}]/u
|
|||||||
original_lines = {}
|
original_lines = {}
|
||||||
translated_lines = {}
|
translated_lines = {}
|
||||||
|
|
||||||
|
function hasLocalization() {
|
||||||
|
return window.localization && Object.keys(window.localization).length > 0;
|
||||||
|
}
|
||||||
|
|
||||||
function textNodesUnder(el){
|
function textNodesUnder(el){
|
||||||
var n, a=[], walk=document.createTreeWalker(el,NodeFilter.SHOW_TEXT,null,false);
|
var n, a=[], walk=document.createTreeWalker(el,NodeFilter.SHOW_TEXT,null,false);
|
||||||
while(n=walk.nextNode()) a.push(n);
|
while(n=walk.nextNode()) a.push(n);
|
||||||
@ -35,11 +39,11 @@ function canBeTranslated(node, text){
|
|||||||
if(! text) return false;
|
if(! text) return false;
|
||||||
if(! node.parentElement) return false;
|
if(! node.parentElement) return false;
|
||||||
|
|
||||||
parentType = node.parentElement.nodeName
|
var parentType = node.parentElement.nodeName
|
||||||
if(parentType=='SCRIPT' || parentType=='STYLE' || parentType=='TEXTAREA') return false;
|
if(parentType=='SCRIPT' || parentType=='STYLE' || parentType=='TEXTAREA') return false;
|
||||||
|
|
||||||
if (parentType=='OPTION' || parentType=='SPAN'){
|
if (parentType=='OPTION' || parentType=='SPAN'){
|
||||||
pnode = node
|
var pnode = node
|
||||||
for(var level=0; level<4; level++){
|
for(var level=0; level<4; level++){
|
||||||
pnode = pnode.parentElement
|
pnode = pnode.parentElement
|
||||||
if(! pnode) break;
|
if(! pnode) break;
|
||||||
@ -69,7 +73,7 @@ function getTranslation(text){
|
|||||||
}
|
}
|
||||||
|
|
||||||
function processTextNode(node){
|
function processTextNode(node){
|
||||||
text = node.textContent.trim()
|
var text = node.textContent.trim()
|
||||||
|
|
||||||
if(! canBeTranslated(node, text)) return
|
if(! canBeTranslated(node, text)) return
|
||||||
|
|
||||||
@ -105,30 +109,52 @@ function processNode(node){
|
|||||||
}
|
}
|
||||||
|
|
||||||
function dumpTranslations(){
|
function dumpTranslations(){
|
||||||
dumped = {}
|
if(!hasLocalization()) {
|
||||||
|
// If we don't have any localization,
|
||||||
|
// we will not have traversed the app to find
|
||||||
|
// original_lines, so do that now.
|
||||||
|
processNode(gradioApp());
|
||||||
|
}
|
||||||
|
var dumped = {}
|
||||||
if (localization.rtl) {
|
if (localization.rtl) {
|
||||||
dumped.rtl = true
|
dumped.rtl = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
Object.keys(original_lines).forEach(function(text){
|
for (const text in original_lines) {
|
||||||
if(dumped[text] !== undefined) return
|
if(dumped[text] !== undefined) continue;
|
||||||
|
dumped[text] = localization[text] || text;
|
||||||
|
}
|
||||||
|
|
||||||
dumped[text] = localization[text] || text
|
return dumped;
|
||||||
})
|
|
||||||
|
|
||||||
return dumped
|
|
||||||
}
|
}
|
||||||
|
|
||||||
onUiUpdate(function(m){
|
function download_localization() {
|
||||||
m.forEach(function(mutation){
|
var text = JSON.stringify(dumpTranslations(), null, 4)
|
||||||
mutation.addedNodes.forEach(function(node){
|
|
||||||
processNode(node)
|
|
||||||
})
|
|
||||||
});
|
|
||||||
})
|
|
||||||
|
|
||||||
|
var element = document.createElement('a');
|
||||||
|
element.setAttribute('href', 'data:text/plain;charset=utf-8,' + encodeURIComponent(text));
|
||||||
|
element.setAttribute('download', "localization.json");
|
||||||
|
element.style.display = 'none';
|
||||||
|
document.body.appendChild(element);
|
||||||
|
|
||||||
|
element.click();
|
||||||
|
|
||||||
|
document.body.removeChild(element);
|
||||||
|
}
|
||||||
|
|
||||||
|
document.addEventListener("DOMContentLoaded", function () {
|
||||||
|
if (!hasLocalization()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
onUiUpdate(function (m) {
|
||||||
|
m.forEach(function (mutation) {
|
||||||
|
mutation.addedNodes.forEach(function (node) {
|
||||||
|
processNode(node)
|
||||||
|
})
|
||||||
|
});
|
||||||
|
})
|
||||||
|
|
||||||
document.addEventListener("DOMContentLoaded", function() {
|
|
||||||
processNode(gradioApp())
|
processNode(gradioApp())
|
||||||
|
|
||||||
if (localization.rtl) { // if the language is from right to left,
|
if (localization.rtl) { // if the language is from right to left,
|
||||||
@ -149,17 +175,3 @@ document.addEventListener("DOMContentLoaded", function() {
|
|||||||
})).observe(gradioApp(), { childList: true });
|
})).observe(gradioApp(), { childList: true });
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
function download_localization() {
|
|
||||||
text = JSON.stringify(dumpTranslations(), null, 4)
|
|
||||||
|
|
||||||
var element = document.createElement('a');
|
|
||||||
element.setAttribute('href', 'data:text/plain;charset=utf-8,' + encodeURIComponent(text));
|
|
||||||
element.setAttribute('download', "localization.json");
|
|
||||||
element.style.display = 'none';
|
|
||||||
document.body.appendChild(element);
|
|
||||||
|
|
||||||
element.click();
|
|
||||||
|
|
||||||
document.body.removeChild(element);
|
|
||||||
}
|
|
||||||
|
@ -2,15 +2,15 @@
|
|||||||
|
|
||||||
let lastHeadImg = null;
|
let lastHeadImg = null;
|
||||||
|
|
||||||
notificationButton = null
|
let notificationButton = null;
|
||||||
|
|
||||||
onUiUpdate(function(){
|
onUiUpdate(function(){
|
||||||
if(notificationButton == null){
|
if(notificationButton == null){
|
||||||
notificationButton = gradioApp().getElementById('request_notifications')
|
notificationButton = gradioApp().getElementById('request_notifications')
|
||||||
|
|
||||||
if(notificationButton != null){
|
if(notificationButton != null){
|
||||||
notificationButton.addEventListener('click', function (evt) {
|
notificationButton.addEventListener('click', () => {
|
||||||
Notification.requestPermission();
|
void Notification.requestPermission();
|
||||||
},true);
|
},true);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,16 +1,15 @@
|
|||||||
// code related to showing and updating progressbar shown as the image is being made
|
// code related to showing and updating progressbar shown as the image is being made
|
||||||
|
|
||||||
function rememberGallerySelection(id_gallery){
|
function rememberGallerySelection(){
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function getGallerySelectedIndex(id_gallery){
|
function getGallerySelectedIndex(){
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function request(url, data, handler, errorHandler){
|
function request(url, data, handler, errorHandler){
|
||||||
var xhr = new XMLHttpRequest();
|
var xhr = new XMLHttpRequest();
|
||||||
var url = url;
|
|
||||||
xhr.open("POST", url, true);
|
xhr.open("POST", url, true);
|
||||||
xhr.setRequestHeader("Content-Type", "application/json");
|
xhr.setRequestHeader("Content-Type", "application/json");
|
||||||
xhr.onreadystatechange = function () {
|
xhr.onreadystatechange = function () {
|
||||||
@ -107,7 +106,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
|
|||||||
divProgress.style.width = rect.width + "px";
|
divProgress.style.width = rect.width + "px";
|
||||||
}
|
}
|
||||||
|
|
||||||
progressText = ""
|
let progressText = ""
|
||||||
|
|
||||||
divInner.style.width = ((res.progress || 0) * 100.0) + '%'
|
divInner.style.width = ((res.progress || 0) * 100.0) + '%'
|
||||||
divInner.style.background = res.progress ? "" : "transparent"
|
divInner.style.background = res.progress ? "" : "transparent"
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
// various functions for interaction with ui.py not large enough to warrant putting them in separate files
|
// various functions for interaction with ui.py not large enough to warrant putting them in separate files
|
||||||
|
|
||||||
function set_theme(theme){
|
function set_theme(theme){
|
||||||
gradioURL = window.location.href
|
var gradioURL = window.location.href
|
||||||
if (!gradioURL.includes('?__theme=')) {
|
if (!gradioURL.includes('?__theme=')) {
|
||||||
window.location.replace(gradioURL + '?__theme=' + theme);
|
window.location.replace(gradioURL + '?__theme=' + theme);
|
||||||
}
|
}
|
||||||
@ -47,7 +47,7 @@ function extract_image_from_gallery(gallery){
|
|||||||
return [gallery[0]];
|
return [gallery[0]];
|
||||||
}
|
}
|
||||||
|
|
||||||
index = selected_gallery_index()
|
var index = selected_gallery_index()
|
||||||
|
|
||||||
if (index < 0 || index >= gallery.length){
|
if (index < 0 || index >= gallery.length){
|
||||||
// Use the first image in the gallery as the default
|
// Use the first image in the gallery as the default
|
||||||
@ -58,7 +58,7 @@ function extract_image_from_gallery(gallery){
|
|||||||
}
|
}
|
||||||
|
|
||||||
function args_to_array(args){
|
function args_to_array(args){
|
||||||
res = []
|
var res = []
|
||||||
for(var i=0;i<args.length;i++){
|
for(var i=0;i<args.length;i++){
|
||||||
res.push(args[i])
|
res.push(args[i])
|
||||||
}
|
}
|
||||||
@ -138,7 +138,7 @@ function get_img2img_tab_index() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
function create_submit_args(args){
|
function create_submit_args(args){
|
||||||
res = []
|
var res = []
|
||||||
for(var i=0;i<args.length;i++){
|
for(var i=0;i<args.length;i++){
|
||||||
res.push(args[i])
|
res.push(args[i])
|
||||||
}
|
}
|
||||||
@ -160,7 +160,7 @@ function showSubmitButtons(tabname, show){
|
|||||||
}
|
}
|
||||||
|
|
||||||
function showRestoreProgressButton(tabname, show){
|
function showRestoreProgressButton(tabname, show){
|
||||||
button = gradioApp().getElementById(tabname + "_restore_progress")
|
var button = gradioApp().getElementById(tabname + "_restore_progress")
|
||||||
if(! button) return
|
if(! button) return
|
||||||
|
|
||||||
button.style.display = show ? "flex" : "none"
|
button.style.display = show ? "flex" : "none"
|
||||||
@ -207,8 +207,9 @@ function submit_img2img(){
|
|||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
function restoreProgressTxt2img(x){
|
function restoreProgressTxt2img(){
|
||||||
showRestoreProgressButton("txt2img", false)
|
showRestoreProgressButton("txt2img", false)
|
||||||
|
var id = localStorage.getItem("txt2img_task_id")
|
||||||
|
|
||||||
id = localStorage.getItem("txt2img_task_id")
|
id = localStorage.getItem("txt2img_task_id")
|
||||||
|
|
||||||
@ -220,10 +221,11 @@ function restoreProgressTxt2img(x){
|
|||||||
|
|
||||||
return id
|
return id
|
||||||
}
|
}
|
||||||
function restoreProgressImg2img(x){
|
|
||||||
showRestoreProgressButton("img2img", false)
|
|
||||||
|
|
||||||
id = localStorage.getItem("img2img_task_id")
|
function restoreProgressImg2img(){
|
||||||
|
showRestoreProgressButton("img2img", false)
|
||||||
|
|
||||||
|
var id = localStorage.getItem("img2img_task_id")
|
||||||
|
|
||||||
if(id) {
|
if(id) {
|
||||||
requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function(){
|
requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function(){
|
||||||
@ -252,7 +254,7 @@ function modelmerger(){
|
|||||||
|
|
||||||
|
|
||||||
function ask_for_style_name(_, prompt_text, negative_prompt_text) {
|
function ask_for_style_name(_, prompt_text, negative_prompt_text) {
|
||||||
name_ = prompt('Style name:')
|
var name_ = prompt('Style name:')
|
||||||
return [name_, prompt_text, negative_prompt_text]
|
return [name_, prompt_text, negative_prompt_text]
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -287,11 +289,11 @@ function recalculate_prompts_img2img(){
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
opts = {}
|
var opts = {}
|
||||||
onUiUpdate(function(){
|
onUiUpdate(function(){
|
||||||
if(Object.keys(opts).length != 0) return;
|
if(Object.keys(opts).length != 0) return;
|
||||||
|
|
||||||
json_elem = gradioApp().getElementById('settings_json')
|
var json_elem = gradioApp().getElementById('settings_json')
|
||||||
if(json_elem == null) return;
|
if(json_elem == null) return;
|
||||||
|
|
||||||
var textarea = json_elem.querySelector('textarea')
|
var textarea = json_elem.querySelector('textarea')
|
||||||
@ -340,12 +342,15 @@ onUiUpdate(function(){
|
|||||||
registerTextarea('img2img_prompt', 'img2img_token_counter', 'img2img_token_button')
|
registerTextarea('img2img_prompt', 'img2img_token_counter', 'img2img_token_button')
|
||||||
registerTextarea('img2img_neg_prompt', 'img2img_negative_token_counter', 'img2img_negative_token_button')
|
registerTextarea('img2img_neg_prompt', 'img2img_negative_token_counter', 'img2img_negative_token_button')
|
||||||
|
|
||||||
show_all_pages = gradioApp().getElementById('settings_show_all_pages')
|
var show_all_pages = gradioApp().getElementById('settings_show_all_pages')
|
||||||
settings_tabs = gradioApp().querySelector('#settings div')
|
var settings_tabs = gradioApp().querySelector('#settings div')
|
||||||
if(show_all_pages && settings_tabs){
|
if(show_all_pages && settings_tabs){
|
||||||
settings_tabs.appendChild(show_all_pages)
|
settings_tabs.appendChild(show_all_pages)
|
||||||
show_all_pages.onclick = function(){
|
show_all_pages.onclick = function(){
|
||||||
gradioApp().querySelectorAll('#settings > div').forEach(function(elem){
|
gradioApp().querySelectorAll('#settings > div').forEach(function(elem){
|
||||||
|
if(elem.id == "settings_tab_licenses")
|
||||||
|
return;
|
||||||
|
|
||||||
elem.style.display = "block";
|
elem.style.display = "block";
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -353,9 +358,9 @@ onUiUpdate(function(){
|
|||||||
})
|
})
|
||||||
|
|
||||||
onOptionsChanged(function(){
|
onOptionsChanged(function(){
|
||||||
elem = gradioApp().getElementById('sd_checkpoint_hash')
|
var elem = gradioApp().getElementById('sd_checkpoint_hash')
|
||||||
sd_checkpoint_hash = opts.sd_checkpoint_hash || ""
|
var sd_checkpoint_hash = opts.sd_checkpoint_hash || ""
|
||||||
shorthash = sd_checkpoint_hash.substr(0,10)
|
var shorthash = sd_checkpoint_hash.substring(0,10)
|
||||||
|
|
||||||
if(elem && elem.textContent != shorthash){
|
if(elem && elem.textContent != shorthash){
|
||||||
elem.textContent = shorthash
|
elem.textContent = shorthash
|
||||||
@ -390,7 +395,16 @@ function update_token_counter(button_id) {
|
|||||||
|
|
||||||
function restart_reload(){
|
function restart_reload(){
|
||||||
document.body.innerHTML='<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>';
|
document.body.innerHTML='<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>';
|
||||||
setTimeout(function(){location.reload()},2000)
|
|
||||||
|
var requestPing = function(){
|
||||||
|
requestGet("./internal/ping", {}, function(data){
|
||||||
|
location.reload();
|
||||||
|
}, function(){
|
||||||
|
setTimeout(requestPing, 500);
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
setTimeout(requestPing, 2000);
|
||||||
|
|
||||||
return []
|
return []
|
||||||
}
|
}
|
||||||
|
62
javascript/ui_settings_hints.js
Normal file
62
javascript/ui_settings_hints.js
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
// various hints and extra info for the settings tab
|
||||||
|
|
||||||
|
settingsHintsSetup = false
|
||||||
|
|
||||||
|
onOptionsChanged(function(){
|
||||||
|
if(settingsHintsSetup) return
|
||||||
|
settingsHintsSetup = true
|
||||||
|
|
||||||
|
gradioApp().querySelectorAll('#settings [id^=setting_]').forEach(function(div){
|
||||||
|
var name = div.id.substr(8)
|
||||||
|
var commentBefore = opts._comments_before[name]
|
||||||
|
var commentAfter = opts._comments_after[name]
|
||||||
|
|
||||||
|
if(! commentBefore && !commentAfter) return
|
||||||
|
|
||||||
|
var span = null
|
||||||
|
if(div.classList.contains('gradio-checkbox')) span = div.querySelector('label span')
|
||||||
|
else if(div.classList.contains('gradio-checkboxgroup')) span = div.querySelector('span').firstChild
|
||||||
|
else if(div.classList.contains('gradio-radio')) span = div.querySelector('span').firstChild
|
||||||
|
else span = div.querySelector('label span').firstChild
|
||||||
|
|
||||||
|
if(!span) return
|
||||||
|
|
||||||
|
if(commentBefore){
|
||||||
|
var comment = document.createElement('DIV')
|
||||||
|
comment.className = 'settings-comment'
|
||||||
|
comment.innerHTML = commentBefore
|
||||||
|
span.parentElement.insertBefore(document.createTextNode('\xa0'), span)
|
||||||
|
span.parentElement.insertBefore(comment, span)
|
||||||
|
span.parentElement.insertBefore(document.createTextNode('\xa0'), span)
|
||||||
|
}
|
||||||
|
if(commentAfter){
|
||||||
|
var comment = document.createElement('DIV')
|
||||||
|
comment.className = 'settings-comment'
|
||||||
|
comment.innerHTML = commentAfter
|
||||||
|
span.parentElement.insertBefore(comment, span.nextSibling)
|
||||||
|
span.parentElement.insertBefore(document.createTextNode('\xa0'), span.nextSibling)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
function settingsHintsShowQuicksettings(){
|
||||||
|
requestGet("./internal/quicksettings-hint", {}, function(data){
|
||||||
|
var table = document.createElement('table')
|
||||||
|
table.className = 'settings-value-table'
|
||||||
|
|
||||||
|
data.forEach(function(obj){
|
||||||
|
var tr = document.createElement('tr')
|
||||||
|
var td = document.createElement('td')
|
||||||
|
td.textContent = obj.name
|
||||||
|
tr.appendChild(td)
|
||||||
|
|
||||||
|
var td = document.createElement('td')
|
||||||
|
td.textContent = obj.label
|
||||||
|
tr.appendChild(td)
|
||||||
|
|
||||||
|
table.appendChild(tr)
|
||||||
|
})
|
||||||
|
|
||||||
|
popup(table);
|
||||||
|
})
|
||||||
|
}
|
105
launch.py
105
launch.py
@ -3,24 +3,23 @@ import subprocess
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import importlib.util
|
import importlib.util
|
||||||
import shlex
|
|
||||||
import platform
|
import platform
|
||||||
import json
|
import json
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
from modules import cmd_args
|
from modules import cmd_args
|
||||||
from modules.paths_internal import script_path, extensions_dir
|
from modules.paths_internal import script_path, extensions_dir
|
||||||
|
|
||||||
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
|
|
||||||
sys.argv += shlex.split(commandline_args)
|
|
||||||
|
|
||||||
args, _ = cmd_args.parser.parse_known_args()
|
args, _ = cmd_args.parser.parse_known_args()
|
||||||
|
|
||||||
python = sys.executable
|
python = sys.executable
|
||||||
git = os.environ.get('GIT', "git")
|
git = os.environ.get('GIT', "git")
|
||||||
index_url = os.environ.get('INDEX_URL', "")
|
index_url = os.environ.get('INDEX_URL', "")
|
||||||
stored_commit_hash = None
|
|
||||||
dir_repos = "repositories"
|
dir_repos = "repositories"
|
||||||
|
|
||||||
|
# Whether to default to printing command output
|
||||||
|
default_command_live = (os.environ.get('WEBUI_LAUNCH_LIVE_OUTPUT') == "1")
|
||||||
|
|
||||||
if 'GRADIO_ANALYTICS_ENABLED' not in os.environ:
|
if 'GRADIO_ANALYTICS_ENABLED' not in os.environ:
|
||||||
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
|
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
|
||||||
|
|
||||||
@ -56,51 +55,52 @@ Use --skip-python-version-check to suppress this warning.
|
|||||||
""")
|
""")
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
def commit_hash():
|
def commit_hash():
|
||||||
global stored_commit_hash
|
|
||||||
|
|
||||||
if stored_commit_hash is not None:
|
|
||||||
return stored_commit_hash
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
stored_commit_hash = run(f"{git} rev-parse HEAD").strip()
|
return subprocess.check_output(f"{git} rev-parse HEAD", encoding='utf8').strip()
|
||||||
except Exception:
|
except Exception:
|
||||||
stored_commit_hash = "<none>"
|
return "<none>"
|
||||||
|
|
||||||
return stored_commit_hash
|
|
||||||
|
|
||||||
|
|
||||||
def run(command, desc=None, errdesc=None, custom_env=None, live=False):
|
@lru_cache()
|
||||||
|
def git_tag():
|
||||||
|
try:
|
||||||
|
return subprocess.check_output(f"{git} describe --tags", encoding='utf8').strip()
|
||||||
|
except Exception:
|
||||||
|
return "<none>"
|
||||||
|
|
||||||
|
|
||||||
|
def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live) -> str:
|
||||||
if desc is not None:
|
if desc is not None:
|
||||||
print(desc)
|
print(desc)
|
||||||
|
|
||||||
if live:
|
run_kwargs = {
|
||||||
result = subprocess.run(command, shell=True, env=os.environ if custom_env is None else custom_env)
|
"args": command,
|
||||||
if result.returncode != 0:
|
"shell": True,
|
||||||
raise RuntimeError(f"""{errdesc or 'Error running command'}.
|
"env": os.environ if custom_env is None else custom_env,
|
||||||
Command: {command}
|
"encoding": 'utf8',
|
||||||
Error code: {result.returncode}""")
|
"errors": 'ignore',
|
||||||
|
}
|
||||||
|
|
||||||
return ""
|
if not live:
|
||||||
|
run_kwargs["stdout"] = run_kwargs["stderr"] = subprocess.PIPE
|
||||||
|
|
||||||
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env)
|
result = subprocess.run(**run_kwargs)
|
||||||
|
|
||||||
if result.returncode != 0:
|
if result.returncode != 0:
|
||||||
|
error_bits = [
|
||||||
|
f"{errdesc or 'Error running command'}.",
|
||||||
|
f"Command: {command}",
|
||||||
|
f"Error code: {result.returncode}",
|
||||||
|
]
|
||||||
|
if result.stdout:
|
||||||
|
error_bits.append(f"stdout: {result.stdout}")
|
||||||
|
if result.stderr:
|
||||||
|
error_bits.append(f"stderr: {result.stderr}")
|
||||||
|
raise RuntimeError("\n".join(error_bits))
|
||||||
|
|
||||||
message = f"""{errdesc or 'Error running command'}.
|
return (result.stdout or "")
|
||||||
Command: {command}
|
|
||||||
Error code: {result.returncode}
|
|
||||||
stdout: {result.stdout.decode(encoding="utf8", errors="ignore") if len(result.stdout)>0 else '<empty>'}
|
|
||||||
stderr: {result.stderr.decode(encoding="utf8", errors="ignore") if len(result.stderr)>0 else '<empty>'}
|
|
||||||
"""
|
|
||||||
raise RuntimeError(message)
|
|
||||||
|
|
||||||
return result.stdout.decode(encoding="utf8", errors="ignore")
|
|
||||||
|
|
||||||
|
|
||||||
def check_run(command):
|
|
||||||
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
|
|
||||||
return result.returncode == 0
|
|
||||||
|
|
||||||
|
|
||||||
def is_installed(package):
|
def is_installed(package):
|
||||||
@ -116,11 +116,7 @@ def repo_dir(name):
|
|||||||
return os.path.join(script_path, dir_repos, name)
|
return os.path.join(script_path, dir_repos, name)
|
||||||
|
|
||||||
|
|
||||||
def run_python(code, desc=None, errdesc=None):
|
def run_pip(command, desc=None, live=default_command_live):
|
||||||
return run(f'"{python}" -c "{code}"', desc, errdesc)
|
|
||||||
|
|
||||||
|
|
||||||
def run_pip(command, desc=None, live=False):
|
|
||||||
if args.skip_install:
|
if args.skip_install:
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -128,8 +124,9 @@ def run_pip(command, desc=None, live=False):
|
|||||||
return run(f'"{python}" -m pip {command} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}", live=live)
|
return run(f'"{python}" -m pip {command} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}", live=live)
|
||||||
|
|
||||||
|
|
||||||
def check_run_python(code):
|
def check_run_python(code: str) -> bool:
|
||||||
return check_run(f'"{python}" -c "{code}"')
|
result = subprocess.run([python, "-c", code], capture_output=True, shell=True)
|
||||||
|
return result.returncode == 0
|
||||||
|
|
||||||
|
|
||||||
def git_clone(url, dir, name, commithash=None):
|
def git_clone(url, dir, name, commithash=None):
|
||||||
@ -222,13 +219,14 @@ def run_extensions_installers(settings_file):
|
|||||||
|
|
||||||
|
|
||||||
def prepare_environment():
|
def prepare_environment():
|
||||||
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==2.0.0 torchvision==0.15.1 --extra-index-url https://download.pytorch.org/whl/cu118")
|
torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118")
|
||||||
|
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}")
|
||||||
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
|
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
|
||||||
|
|
||||||
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.17')
|
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.17')
|
||||||
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
|
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "https://github.com/TencentARC/GFPGAN/archive/8d2447a2d918f8eba5a4a01463fd48e45126a379.zip")
|
||||||
clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
|
clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
|
||||||
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b")
|
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
|
||||||
|
|
||||||
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
|
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
|
||||||
taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git")
|
taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git")
|
||||||
@ -246,15 +244,20 @@ def prepare_environment():
|
|||||||
check_python_version()
|
check_python_version()
|
||||||
|
|
||||||
commit = commit_hash()
|
commit = commit_hash()
|
||||||
|
tag = git_tag()
|
||||||
|
|
||||||
print(f"Python {sys.version}")
|
print(f"Python {sys.version}")
|
||||||
|
print(f"Version: {tag}")
|
||||||
print(f"Commit hash: {commit}")
|
print(f"Commit hash: {commit}")
|
||||||
|
|
||||||
if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
|
if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
|
||||||
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
|
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
|
||||||
|
|
||||||
if not args.skip_torch_cuda_test:
|
if not args.skip_torch_cuda_test and not check_run_python("import torch; assert torch.cuda.is_available()"):
|
||||||
run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'")
|
raise RuntimeError(
|
||||||
|
'Torch is not able to use GPU; '
|
||||||
|
'add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'
|
||||||
|
)
|
||||||
|
|
||||||
if not is_installed("gfpgan"):
|
if not is_installed("gfpgan"):
|
||||||
run_pip(f"install {gfpgan_package}", "gfpgan")
|
run_pip(f"install {gfpgan_package}", "gfpgan")
|
||||||
@ -302,7 +305,7 @@ def prepare_environment():
|
|||||||
|
|
||||||
if args.update_all_extensions:
|
if args.update_all_extensions:
|
||||||
git_pull_recursive(extensions_dir)
|
git_pull_recursive(extensions_dir)
|
||||||
|
|
||||||
if "--exit" in sys.argv:
|
if "--exit" in sys.argv:
|
||||||
print("Exiting because of --exit argument")
|
print("Exiting because of --exit argument")
|
||||||
exit(0)
|
exit(0)
|
||||||
|
BIN
modules/Roboto-Regular.ttf
Normal file
BIN
modules/Roboto-Regular.ttf
Normal file
Binary file not shown.
@ -15,7 +15,8 @@ from secrets import compare_digest
|
|||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing
|
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing
|
||||||
from modules.api.models import *
|
from modules.api import models
|
||||||
|
from modules.shared import opts
|
||||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
||||||
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
|
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
|
||||||
from modules.textual_inversion.preprocess import preprocess
|
from modules.textual_inversion.preprocess import preprocess
|
||||||
@ -25,21 +26,24 @@ from modules.sd_models import checkpoints_list, unload_model_weights, reload_mod
|
|||||||
from modules.sd_models_config import find_checkpoint_config_near_filename
|
from modules.sd_models_config import find_checkpoint_config_near_filename
|
||||||
from modules.realesrgan_model import get_realesrgan_models
|
from modules.realesrgan_model import get_realesrgan_models
|
||||||
from modules import devices
|
from modules import devices
|
||||||
from typing import List
|
from typing import Dict, List, Any
|
||||||
import piexif
|
import piexif
|
||||||
import piexif.helper
|
import piexif.helper
|
||||||
|
|
||||||
|
|
||||||
def upscaler_to_index(name: str):
|
def upscaler_to_index(name: str):
|
||||||
try:
|
try:
|
||||||
return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
|
return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
|
||||||
except:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in sd_upscalers])}")
|
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in shared.sd_upscalers])}") from e
|
||||||
|
|
||||||
|
|
||||||
def script_name_to_index(name, scripts):
|
def script_name_to_index(name, scripts):
|
||||||
try:
|
try:
|
||||||
return [script.title().lower() for script in scripts].index(name.lower())
|
return [script.title().lower() for script in scripts].index(name.lower())
|
||||||
except:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=422, detail=f"Script '{name}' not found")
|
raise HTTPException(status_code=422, detail=f"Script '{name}' not found") from e
|
||||||
|
|
||||||
|
|
||||||
def validate_sampler_name(name):
|
def validate_sampler_name(name):
|
||||||
config = sd_samplers.all_samplers_map.get(name, None)
|
config = sd_samplers.all_samplers_map.get(name, None)
|
||||||
@ -48,20 +52,23 @@ def validate_sampler_name(name):
|
|||||||
|
|
||||||
return name
|
return name
|
||||||
|
|
||||||
|
|
||||||
def setUpscalers(req: dict):
|
def setUpscalers(req: dict):
|
||||||
reqDict = vars(req)
|
reqDict = vars(req)
|
||||||
reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None)
|
reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None)
|
||||||
reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None)
|
reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None)
|
||||||
return reqDict
|
return reqDict
|
||||||
|
|
||||||
|
|
||||||
def decode_base64_to_image(encoding):
|
def decode_base64_to_image(encoding):
|
||||||
if encoding.startswith("data:image/"):
|
if encoding.startswith("data:image/"):
|
||||||
encoding = encoding.split(";")[1].split(",")[1]
|
encoding = encoding.split(";")[1].split(",")[1]
|
||||||
try:
|
try:
|
||||||
image = Image.open(BytesIO(base64.b64decode(encoding)))
|
image = Image.open(BytesIO(base64.b64decode(encoding)))
|
||||||
return image
|
return image
|
||||||
except Exception as err:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail="Invalid encoded image")
|
raise HTTPException(status_code=500, detail="Invalid encoded image") from e
|
||||||
|
|
||||||
|
|
||||||
def encode_pil_to_base64(image):
|
def encode_pil_to_base64(image):
|
||||||
with io.BytesIO() as output_bytes:
|
with io.BytesIO() as output_bytes:
|
||||||
@ -92,6 +99,7 @@ def encode_pil_to_base64(image):
|
|||||||
|
|
||||||
return base64.b64encode(bytes_data)
|
return base64.b64encode(bytes_data)
|
||||||
|
|
||||||
|
|
||||||
def api_middleware(app: FastAPI):
|
def api_middleware(app: FastAPI):
|
||||||
rich_available = True
|
rich_available = True
|
||||||
try:
|
try:
|
||||||
@ -99,7 +107,7 @@ def api_middleware(app: FastAPI):
|
|||||||
import starlette # importing just so it can be placed on silent list
|
import starlette # importing just so it can be placed on silent list
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
console = Console()
|
console = Console()
|
||||||
except:
|
except Exception:
|
||||||
import traceback
|
import traceback
|
||||||
rich_available = False
|
rich_available = False
|
||||||
|
|
||||||
@ -157,7 +165,7 @@ def api_middleware(app: FastAPI):
|
|||||||
class Api:
|
class Api:
|
||||||
def __init__(self, app: FastAPI, queue_lock: Lock):
|
def __init__(self, app: FastAPI, queue_lock: Lock):
|
||||||
if shared.cmd_opts.api_auth:
|
if shared.cmd_opts.api_auth:
|
||||||
self.credentials = dict()
|
self.credentials = {}
|
||||||
for auth in shared.cmd_opts.api_auth.split(","):
|
for auth in shared.cmd_opts.api_auth.split(","):
|
||||||
user, password = auth.split(":")
|
user, password = auth.split(":")
|
||||||
self.credentials[user] = password
|
self.credentials[user] = password
|
||||||
@ -166,36 +174,36 @@ class Api:
|
|||||||
self.app = app
|
self.app = app
|
||||||
self.queue_lock = queue_lock
|
self.queue_lock = queue_lock
|
||||||
api_middleware(self.app)
|
api_middleware(self.app)
|
||||||
self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
|
self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=models.TextToImageResponse)
|
||||||
self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
|
self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=models.ImageToImageResponse)
|
||||||
self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
|
self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ExtrasSingleImageResponse)
|
||||||
self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
|
self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=models.ExtrasBatchImagesResponse)
|
||||||
self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
|
self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=models.PNGInfoResponse)
|
||||||
self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
|
self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=models.ProgressResponse)
|
||||||
self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
|
self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
|
||||||
self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
|
self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
|
||||||
self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
|
self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
|
||||||
self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
|
self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
|
||||||
self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
|
self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
|
||||||
self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel)
|
self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
|
||||||
self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem])
|
self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
|
||||||
self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem])
|
self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
|
||||||
self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem])
|
self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
|
||||||
self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem])
|
self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
|
||||||
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
|
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem])
|
||||||
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
|
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem])
|
||||||
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
|
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem])
|
||||||
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse)
|
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
|
||||||
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
|
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
|
||||||
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
|
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
|
||||||
self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=CreateResponse)
|
self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
|
||||||
self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=PreprocessResponse)
|
self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
|
||||||
self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
|
self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
|
||||||
self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
|
self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
|
||||||
self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=MemoryResponse)
|
self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
|
||||||
self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
|
self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
|
||||||
self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
|
self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
|
||||||
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=ScriptsList)
|
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
|
||||||
|
|
||||||
self.default_script_arg_txt2img = []
|
self.default_script_arg_txt2img = []
|
||||||
self.default_script_arg_img2img = []
|
self.default_script_arg_img2img = []
|
||||||
@ -219,17 +227,17 @@ class Api:
|
|||||||
script_idx = script_name_to_index(script_name, script_runner.selectable_scripts)
|
script_idx = script_name_to_index(script_name, script_runner.selectable_scripts)
|
||||||
script = script_runner.selectable_scripts[script_idx]
|
script = script_runner.selectable_scripts[script_idx]
|
||||||
return script, script_idx
|
return script, script_idx
|
||||||
|
|
||||||
def get_scripts_list(self):
|
def get_scripts_list(self):
|
||||||
t2ilist = [str(title.lower()) for title in scripts.scripts_txt2img.titles]
|
t2ilist = [str(title.lower()) for title in scripts.scripts_txt2img.titles]
|
||||||
i2ilist = [str(title.lower()) for title in scripts.scripts_img2img.titles]
|
i2ilist = [str(title.lower()) for title in scripts.scripts_img2img.titles]
|
||||||
|
|
||||||
return ScriptsList(txt2img = t2ilist, img2img = i2ilist)
|
return models.ScriptsList(txt2img=t2ilist, img2img=i2ilist)
|
||||||
|
|
||||||
def get_script(self, script_name, script_runner):
|
def get_script(self, script_name, script_runner):
|
||||||
if script_name is None or script_name == "":
|
if script_name is None or script_name == "":
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
script_idx = script_name_to_index(script_name, script_runner.scripts)
|
script_idx = script_name_to_index(script_name, script_runner.scripts)
|
||||||
return script_runner.scripts[script_idx]
|
return script_runner.scripts[script_idx]
|
||||||
|
|
||||||
@ -264,11 +272,11 @@ class Api:
|
|||||||
if request.alwayson_scripts and (len(request.alwayson_scripts) > 0):
|
if request.alwayson_scripts and (len(request.alwayson_scripts) > 0):
|
||||||
for alwayson_script_name in request.alwayson_scripts.keys():
|
for alwayson_script_name in request.alwayson_scripts.keys():
|
||||||
alwayson_script = self.get_script(alwayson_script_name, script_runner)
|
alwayson_script = self.get_script(alwayson_script_name, script_runner)
|
||||||
if alwayson_script == None:
|
if alwayson_script is None:
|
||||||
raise HTTPException(status_code=422, detail=f"always on script {alwayson_script_name} not found")
|
raise HTTPException(status_code=422, detail=f"always on script {alwayson_script_name} not found")
|
||||||
# Selectable script in always on script param check
|
# Selectable script in always on script param check
|
||||||
if alwayson_script.alwayson == False:
|
if alwayson_script.alwayson is False:
|
||||||
raise HTTPException(status_code=422, detail=f"Cannot have a selectable script in the always on scripts params")
|
raise HTTPException(status_code=422, detail="Cannot have a selectable script in the always on scripts params")
|
||||||
# always on script with no arg should always run so you don't really need to add them to the requests
|
# always on script with no arg should always run so you don't really need to add them to the requests
|
||||||
if "args" in request.alwayson_scripts[alwayson_script_name]:
|
if "args" in request.alwayson_scripts[alwayson_script_name]:
|
||||||
# min between arg length in scriptrunner and arg length in the request
|
# min between arg length in scriptrunner and arg length in the request
|
||||||
@ -276,7 +284,7 @@ class Api:
|
|||||||
script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx]
|
script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx]
|
||||||
return script_args
|
return script_args
|
||||||
|
|
||||||
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
|
def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
|
||||||
script_runner = scripts.scripts_txt2img
|
script_runner = scripts.scripts_txt2img
|
||||||
if not script_runner.scripts:
|
if not script_runner.scripts:
|
||||||
script_runner.initialize_scripts(False)
|
script_runner.initialize_scripts(False)
|
||||||
@ -310,7 +318,7 @@ class Api:
|
|||||||
p.outpath_samples = opts.outdir_txt2img_samples
|
p.outpath_samples = opts.outdir_txt2img_samples
|
||||||
|
|
||||||
shared.state.begin()
|
shared.state.begin()
|
||||||
if selectable_scripts != None:
|
if selectable_scripts is not None:
|
||||||
p.script_args = script_args
|
p.script_args = script_args
|
||||||
processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
|
processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
|
||||||
else:
|
else:
|
||||||
@ -320,9 +328,9 @@ class Api:
|
|||||||
|
|
||||||
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
|
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
|
||||||
|
|
||||||
return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
|
return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
|
||||||
|
|
||||||
def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
|
def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI):
|
||||||
init_images = img2imgreq.init_images
|
init_images = img2imgreq.init_images
|
||||||
if init_images is None:
|
if init_images is None:
|
||||||
raise HTTPException(status_code=404, detail="Init image not found")
|
raise HTTPException(status_code=404, detail="Init image not found")
|
||||||
@ -367,7 +375,7 @@ class Api:
|
|||||||
p.outpath_samples = opts.outdir_img2img_samples
|
p.outpath_samples = opts.outdir_img2img_samples
|
||||||
|
|
||||||
shared.state.begin()
|
shared.state.begin()
|
||||||
if selectable_scripts != None:
|
if selectable_scripts is not None:
|
||||||
p.script_args = script_args
|
p.script_args = script_args
|
||||||
processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
|
processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
|
||||||
else:
|
else:
|
||||||
@ -381,9 +389,9 @@ class Api:
|
|||||||
img2imgreq.init_images = None
|
img2imgreq.init_images = None
|
||||||
img2imgreq.mask = None
|
img2imgreq.mask = None
|
||||||
|
|
||||||
return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
|
return models.ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
|
||||||
|
|
||||||
def extras_single_image_api(self, req: ExtrasSingleImageRequest):
|
def extras_single_image_api(self, req: models.ExtrasSingleImageRequest):
|
||||||
reqDict = setUpscalers(req)
|
reqDict = setUpscalers(req)
|
||||||
|
|
||||||
reqDict['image'] = decode_base64_to_image(reqDict['image'])
|
reqDict['image'] = decode_base64_to_image(reqDict['image'])
|
||||||
@ -391,9 +399,9 @@ class Api:
|
|||||||
with self.queue_lock:
|
with self.queue_lock:
|
||||||
result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
|
result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
|
||||||
|
|
||||||
return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
|
return models.ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
|
||||||
|
|
||||||
def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
|
def extras_batch_images_api(self, req: models.ExtrasBatchImagesRequest):
|
||||||
reqDict = setUpscalers(req)
|
reqDict = setUpscalers(req)
|
||||||
|
|
||||||
image_list = reqDict.pop('imageList', [])
|
image_list = reqDict.pop('imageList', [])
|
||||||
@ -402,15 +410,15 @@ class Api:
|
|||||||
with self.queue_lock:
|
with self.queue_lock:
|
||||||
result = postprocessing.run_extras(extras_mode=1, image_folder=image_folder, image="", input_dir="", output_dir="", save_output=False, **reqDict)
|
result = postprocessing.run_extras(extras_mode=1, image_folder=image_folder, image="", input_dir="", output_dir="", save_output=False, **reqDict)
|
||||||
|
|
||||||
return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
|
return models.ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
|
||||||
|
|
||||||
def pnginfoapi(self, req: PNGInfoRequest):
|
def pnginfoapi(self, req: models.PNGInfoRequest):
|
||||||
if(not req.image.strip()):
|
if(not req.image.strip()):
|
||||||
return PNGInfoResponse(info="")
|
return models.PNGInfoResponse(info="")
|
||||||
|
|
||||||
image = decode_base64_to_image(req.image.strip())
|
image = decode_base64_to_image(req.image.strip())
|
||||||
if image is None:
|
if image is None:
|
||||||
return PNGInfoResponse(info="")
|
return models.PNGInfoResponse(info="")
|
||||||
|
|
||||||
geninfo, items = images.read_info_from_image(image)
|
geninfo, items = images.read_info_from_image(image)
|
||||||
if geninfo is None:
|
if geninfo is None:
|
||||||
@ -418,13 +426,13 @@ class Api:
|
|||||||
|
|
||||||
items = {**{'parameters': geninfo}, **items}
|
items = {**{'parameters': geninfo}, **items}
|
||||||
|
|
||||||
return PNGInfoResponse(info=geninfo, items=items)
|
return models.PNGInfoResponse(info=geninfo, items=items)
|
||||||
|
|
||||||
def progressapi(self, req: ProgressRequest = Depends()):
|
def progressapi(self, req: models.ProgressRequest = Depends()):
|
||||||
# copy from check_progress_call of ui.py
|
# copy from check_progress_call of ui.py
|
||||||
|
|
||||||
if shared.state.job_count == 0:
|
if shared.state.job_count == 0:
|
||||||
return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)
|
return models.ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)
|
||||||
|
|
||||||
# avoid dividing zero
|
# avoid dividing zero
|
||||||
progress = 0.01
|
progress = 0.01
|
||||||
@ -446,9 +454,9 @@ class Api:
|
|||||||
if shared.state.current_image and not req.skip_current_image:
|
if shared.state.current_image and not req.skip_current_image:
|
||||||
current_image = encode_pil_to_base64(shared.state.current_image)
|
current_image = encode_pil_to_base64(shared.state.current_image)
|
||||||
|
|
||||||
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)
|
return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)
|
||||||
|
|
||||||
def interrogateapi(self, interrogatereq: InterrogateRequest):
|
def interrogateapi(self, interrogatereq: models.InterrogateRequest):
|
||||||
image_b64 = interrogatereq.image
|
image_b64 = interrogatereq.image
|
||||||
if image_b64 is None:
|
if image_b64 is None:
|
||||||
raise HTTPException(status_code=404, detail="Image not found")
|
raise HTTPException(status_code=404, detail="Image not found")
|
||||||
@ -465,7 +473,7 @@ class Api:
|
|||||||
else:
|
else:
|
||||||
raise HTTPException(status_code=404, detail="Model not found")
|
raise HTTPException(status_code=404, detail="Model not found")
|
||||||
|
|
||||||
return InterrogateResponse(caption=processed)
|
return models.InterrogateResponse(caption=processed)
|
||||||
|
|
||||||
def interruptapi(self):
|
def interruptapi(self):
|
||||||
shared.state.interrupt()
|
shared.state.interrupt()
|
||||||
@ -570,36 +578,36 @@ class Api:
|
|||||||
filename = create_embedding(**args) # create empty embedding
|
filename = create_embedding(**args) # create empty embedding
|
||||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
|
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return CreateResponse(info = "create embedding filename: {filename}".format(filename = filename))
|
return models.CreateResponse(info=f"create embedding filename: {filename}")
|
||||||
except AssertionError as e:
|
except AssertionError as e:
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return TrainResponse(info = "create embedding error: {error}".format(error = e))
|
return models.TrainResponse(info=f"create embedding error: {e}")
|
||||||
|
|
||||||
def create_hypernetwork(self, args: dict):
|
def create_hypernetwork(self, args: dict):
|
||||||
try:
|
try:
|
||||||
shared.state.begin()
|
shared.state.begin()
|
||||||
filename = create_hypernetwork(**args) # create empty embedding
|
filename = create_hypernetwork(**args) # create empty embedding
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return CreateResponse(info = "create hypernetwork filename: {filename}".format(filename = filename))
|
return models.CreateResponse(info=f"create hypernetwork filename: {filename}")
|
||||||
except AssertionError as e:
|
except AssertionError as e:
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return TrainResponse(info = "create hypernetwork error: {error}".format(error = e))
|
return models.TrainResponse(info=f"create hypernetwork error: {e}")
|
||||||
|
|
||||||
def preprocess(self, args: dict):
|
def preprocess(self, args: dict):
|
||||||
try:
|
try:
|
||||||
shared.state.begin()
|
shared.state.begin()
|
||||||
preprocess(**args) # quick operation unless blip/booru interrogation is enabled
|
preprocess(**args) # quick operation unless blip/booru interrogation is enabled
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return PreprocessResponse(info = 'preprocess complete')
|
return models.PreprocessResponse(info = 'preprocess complete')
|
||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return PreprocessResponse(info = "preprocess error: invalid token: {error}".format(error = e))
|
return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}")
|
||||||
except AssertionError as e:
|
except AssertionError as e:
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return PreprocessResponse(info = "preprocess error: {error}".format(error = e))
|
return models.PreprocessResponse(info=f"preprocess error: {e}")
|
||||||
except FileNotFoundError as e:
|
except FileNotFoundError as e:
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return PreprocessResponse(info = 'preprocess error: {error}'.format(error = e))
|
return models.PreprocessResponse(info=f'preprocess error: {e}')
|
||||||
|
|
||||||
def train_embedding(self, args: dict):
|
def train_embedding(self, args: dict):
|
||||||
try:
|
try:
|
||||||
@ -617,10 +625,10 @@ class Api:
|
|||||||
if not apply_optimizations:
|
if not apply_optimizations:
|
||||||
sd_hijack.apply_optimizations()
|
sd_hijack.apply_optimizations()
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error))
|
return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
|
||||||
except AssertionError as msg:
|
except AssertionError as msg:
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return TrainResponse(info = "train embedding error: {msg}".format(msg = msg))
|
return models.TrainResponse(info=f"train embedding error: {msg}")
|
||||||
|
|
||||||
def train_hypernetwork(self, args: dict):
|
def train_hypernetwork(self, args: dict):
|
||||||
try:
|
try:
|
||||||
@ -641,14 +649,15 @@ class Api:
|
|||||||
if not apply_optimizations:
|
if not apply_optimizations:
|
||||||
sd_hijack.apply_optimizations()
|
sd_hijack.apply_optimizations()
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return TrainResponse(info="train embedding complete: filename: {filename} error: {error}".format(filename=filename, error=error))
|
return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
|
||||||
except AssertionError as msg:
|
except AssertionError:
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return TrainResponse(info="train embedding error: {error}".format(error=error))
|
return models.TrainResponse(info=f"train embedding error: {error}")
|
||||||
|
|
||||||
def get_memory(self):
|
def get_memory(self):
|
||||||
try:
|
try:
|
||||||
import os, psutil
|
import os
|
||||||
|
import psutil
|
||||||
process = psutil.Process(os.getpid())
|
process = psutil.Process(os.getpid())
|
||||||
res = process.memory_info() # only rss is cross-platform guaranteed so we dont rely on other values
|
res = process.memory_info() # only rss is cross-platform guaranteed so we dont rely on other values
|
||||||
ram_total = 100 * res.rss / process.memory_percent() # and total memory is calculated as actual value is not cross-platform safe
|
ram_total = 100 * res.rss / process.memory_percent() # and total memory is calculated as actual value is not cross-platform safe
|
||||||
@ -675,10 +684,10 @@ class Api:
|
|||||||
'events': warnings,
|
'events': warnings,
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
cuda = { 'error': 'unavailable' }
|
cuda = {'error': 'unavailable'}
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
cuda = { 'error': f'{err}' }
|
cuda = {'error': f'{err}'}
|
||||||
return MemoryResponse(ram = ram, cuda = cuda)
|
return models.MemoryResponse(ram=ram, cuda=cuda)
|
||||||
|
|
||||||
def launch(self, server_name, port):
|
def launch(self, server_name, port):
|
||||||
self.app.include_router(self.router)
|
self.app.include_router(self.router)
|
||||||
|
@ -223,8 +223,9 @@ for key in _options:
|
|||||||
if(_options[key].dest != 'help'):
|
if(_options[key].dest != 'help'):
|
||||||
flag = _options[key]
|
flag = _options[key]
|
||||||
_type = str
|
_type = str
|
||||||
if _options[key].default is not None: _type = type(_options[key].default)
|
if _options[key].default is not None:
|
||||||
flags.update({flag.dest: (_type,Field(default=flag.default, description=flag.help))})
|
_type = type(_options[key].default)
|
||||||
|
flags.update({flag.dest: (_type, Field(default=flag.default, description=flag.help))})
|
||||||
|
|
||||||
FlagsModel = create_model("Flags", **flags)
|
FlagsModel = create_model("Flags", **flags)
|
||||||
|
|
||||||
@ -288,4 +289,4 @@ class MemoryResponse(BaseModel):
|
|||||||
|
|
||||||
class ScriptsList(BaseModel):
|
class ScriptsList(BaseModel):
|
||||||
txt2img: list = Field(default=None,title="Txt2img", description="Titles of scripts (txt2img)")
|
txt2img: list = Field(default=None,title="Txt2img", description="Titles of scripts (txt2img)")
|
||||||
img2img: list = Field(default=None,title="Img2img", description="Titles of scripts (img2img)")
|
img2img: list = Field(default=None,title="Img2img", description="Titles of scripts (img2img)")
|
||||||
|
@ -60,7 +60,7 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
|
|||||||
max_debug_str_len = 131072 # (1024*1024)/8
|
max_debug_str_len = 131072 # (1024*1024)/8
|
||||||
|
|
||||||
print("Error completing request", file=sys.stderr)
|
print("Error completing request", file=sys.stderr)
|
||||||
argStr = f"Arguments: {str(args)} {str(kwargs)}"
|
argStr = f"Arguments: {args} {kwargs}"
|
||||||
print(argStr[:max_debug_str_len], file=sys.stderr)
|
print(argStr[:max_debug_str_len], file=sys.stderr)
|
||||||
if len(argStr) > max_debug_str_len:
|
if len(argStr) > max_debug_str_len:
|
||||||
print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
|
print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
|
||||||
@ -73,7 +73,8 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
|
|||||||
if extra_outputs_array is None:
|
if extra_outputs_array is None:
|
||||||
extra_outputs_array = [None, '']
|
extra_outputs_array = [None, '']
|
||||||
|
|
||||||
res = extra_outputs_array + [f"<div class='error'>{html.escape(type(e).__name__+': '+str(e))}</div>"]
|
error_message = f'{type(e).__name__}: {e}'
|
||||||
|
res = extra_outputs_array + [f"<div class='error'>{html.escape(error_message)}</div>"]
|
||||||
|
|
||||||
shared.state.skipped = False
|
shared.state.skipped = False
|
||||||
shared.state.interrupted = False
|
shared.state.interrupted = False
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file
|
from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file # noqa: F401
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
@ -102,3 +102,4 @@ parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gra
|
|||||||
parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
|
parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
|
||||||
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
|
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
|
||||||
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
|
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
|
||||||
|
parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
|
||||||
|
@ -1,14 +1,12 @@
|
|||||||
# this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py
|
# this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py
|
||||||
|
|
||||||
import math
|
import math
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn, Tensor
|
from torch import nn, Tensor
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from typing import Optional, List
|
from typing import Optional
|
||||||
|
|
||||||
from modules.codeformer.vqgan_arch import *
|
from modules.codeformer.vqgan_arch import VQAutoEncoder, ResBlock
|
||||||
from basicsr.utils import get_root_logger
|
|
||||||
from basicsr.utils.registry import ARCH_REGISTRY
|
from basicsr.utils.registry import ARCH_REGISTRY
|
||||||
|
|
||||||
def calc_mean_std(feat, eps=1e-5):
|
def calc_mean_std(feat, eps=1e-5):
|
||||||
@ -121,7 +119,7 @@ class TransformerSALayer(nn.Module):
|
|||||||
tgt_mask: Optional[Tensor] = None,
|
tgt_mask: Optional[Tensor] = None,
|
||||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||||
query_pos: Optional[Tensor] = None):
|
query_pos: Optional[Tensor] = None):
|
||||||
|
|
||||||
# self attention
|
# self attention
|
||||||
tgt2 = self.norm1(tgt)
|
tgt2 = self.norm1(tgt)
|
||||||
q = k = self.with_pos_embed(tgt2, query_pos)
|
q = k = self.with_pos_embed(tgt2, query_pos)
|
||||||
@ -161,10 +159,10 @@ class Fuse_sft_block(nn.Module):
|
|||||||
|
|
||||||
@ARCH_REGISTRY.register()
|
@ARCH_REGISTRY.register()
|
||||||
class CodeFormer(VQAutoEncoder):
|
class CodeFormer(VQAutoEncoder):
|
||||||
def __init__(self, dim_embd=512, n_head=8, n_layers=9,
|
def __init__(self, dim_embd=512, n_head=8, n_layers=9,
|
||||||
codebook_size=1024, latent_size=256,
|
codebook_size=1024, latent_size=256,
|
||||||
connect_list=['32', '64', '128', '256'],
|
connect_list=('32', '64', '128', '256'),
|
||||||
fix_modules=['quantize','generator']):
|
fix_modules=('quantize', 'generator')):
|
||||||
super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
|
super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
|
||||||
|
|
||||||
if fix_modules is not None:
|
if fix_modules is not None:
|
||||||
@ -181,14 +179,14 @@ class CodeFormer(VQAutoEncoder):
|
|||||||
self.feat_emb = nn.Linear(256, self.dim_embd)
|
self.feat_emb = nn.Linear(256, self.dim_embd)
|
||||||
|
|
||||||
# transformer
|
# transformer
|
||||||
self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
|
self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
|
||||||
for _ in range(self.n_layers)])
|
for _ in range(self.n_layers)])
|
||||||
|
|
||||||
# logits_predict head
|
# logits_predict head
|
||||||
self.idx_pred_layer = nn.Sequential(
|
self.idx_pred_layer = nn.Sequential(
|
||||||
nn.LayerNorm(dim_embd),
|
nn.LayerNorm(dim_embd),
|
||||||
nn.Linear(dim_embd, codebook_size, bias=False))
|
nn.Linear(dim_embd, codebook_size, bias=False))
|
||||||
|
|
||||||
self.channels = {
|
self.channels = {
|
||||||
'16': 512,
|
'16': 512,
|
||||||
'32': 256,
|
'32': 256,
|
||||||
@ -223,7 +221,7 @@ class CodeFormer(VQAutoEncoder):
|
|||||||
enc_feat_dict = {}
|
enc_feat_dict = {}
|
||||||
out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
|
out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
|
||||||
for i, block in enumerate(self.encoder.blocks):
|
for i, block in enumerate(self.encoder.blocks):
|
||||||
x = block(x)
|
x = block(x)
|
||||||
if i in out_list:
|
if i in out_list:
|
||||||
enc_feat_dict[str(x.shape[-1])] = x.clone()
|
enc_feat_dict[str(x.shape[-1])] = x.clone()
|
||||||
|
|
||||||
@ -268,11 +266,11 @@ class CodeFormer(VQAutoEncoder):
|
|||||||
fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
|
fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
|
||||||
|
|
||||||
for i, block in enumerate(self.generator.blocks):
|
for i, block in enumerate(self.generator.blocks):
|
||||||
x = block(x)
|
x = block(x)
|
||||||
if i in fuse_list: # fuse after i-th block
|
if i in fuse_list: # fuse after i-th block
|
||||||
f_size = str(x.shape[-1])
|
f_size = str(x.shape[-1])
|
||||||
if w>0:
|
if w>0:
|
||||||
x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
|
x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
|
||||||
out = x
|
out = x
|
||||||
# logits doesn't need softmax before cross_entropy loss
|
# logits doesn't need softmax before cross_entropy loss
|
||||||
return out, logits, lq_feat
|
return out, logits, lq_feat
|
||||||
|
@ -5,17 +5,15 @@ VQGAN code, adapted from the original created by the Unleashing Transformers aut
|
|||||||
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
|
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
|
||||||
|
|
||||||
'''
|
'''
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import copy
|
|
||||||
from basicsr.utils import get_root_logger
|
from basicsr.utils import get_root_logger
|
||||||
from basicsr.utils.registry import ARCH_REGISTRY
|
from basicsr.utils.registry import ARCH_REGISTRY
|
||||||
|
|
||||||
def normalize(in_channels):
|
def normalize(in_channels):
|
||||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||||
|
|
||||||
|
|
||||||
@torch.jit.script
|
@torch.jit.script
|
||||||
def swish(x):
|
def swish(x):
|
||||||
@ -212,15 +210,15 @@ class AttnBlock(nn.Module):
|
|||||||
# compute attention
|
# compute attention
|
||||||
b, c, h, w = q.shape
|
b, c, h, w = q.shape
|
||||||
q = q.reshape(b, c, h*w)
|
q = q.reshape(b, c, h*w)
|
||||||
q = q.permute(0, 2, 1)
|
q = q.permute(0, 2, 1)
|
||||||
k = k.reshape(b, c, h*w)
|
k = k.reshape(b, c, h*w)
|
||||||
w_ = torch.bmm(q, k)
|
w_ = torch.bmm(q, k)
|
||||||
w_ = w_ * (int(c)**(-0.5))
|
w_ = w_ * (int(c)**(-0.5))
|
||||||
w_ = F.softmax(w_, dim=2)
|
w_ = F.softmax(w_, dim=2)
|
||||||
|
|
||||||
# attend to values
|
# attend to values
|
||||||
v = v.reshape(b, c, h*w)
|
v = v.reshape(b, c, h*w)
|
||||||
w_ = w_.permute(0, 2, 1)
|
w_ = w_.permute(0, 2, 1)
|
||||||
h_ = torch.bmm(v, w_)
|
h_ = torch.bmm(v, w_)
|
||||||
h_ = h_.reshape(b, c, h, w)
|
h_ = h_.reshape(b, c, h, w)
|
||||||
|
|
||||||
@ -272,18 +270,18 @@ class Encoder(nn.Module):
|
|||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
for block in self.blocks:
|
for block in self.blocks:
|
||||||
x = block(x)
|
x = block(x)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class Generator(nn.Module):
|
class Generator(nn.Module):
|
||||||
def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
|
def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.nf = nf
|
self.nf = nf
|
||||||
self.ch_mult = ch_mult
|
self.ch_mult = ch_mult
|
||||||
self.num_resolutions = len(self.ch_mult)
|
self.num_resolutions = len(self.ch_mult)
|
||||||
self.num_res_blocks = res_blocks
|
self.num_res_blocks = res_blocks
|
||||||
self.resolution = img_size
|
self.resolution = img_size
|
||||||
self.attn_resolutions = attn_resolutions
|
self.attn_resolutions = attn_resolutions
|
||||||
self.in_channels = emb_dim
|
self.in_channels = emb_dim
|
||||||
self.out_channels = 3
|
self.out_channels = 3
|
||||||
@ -317,29 +315,29 @@ class Generator(nn.Module):
|
|||||||
blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
|
blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
|
||||||
|
|
||||||
self.blocks = nn.ModuleList(blocks)
|
self.blocks = nn.ModuleList(blocks)
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
for block in self.blocks:
|
for block in self.blocks:
|
||||||
x = block(x)
|
x = block(x)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@ARCH_REGISTRY.register()
|
@ARCH_REGISTRY.register()
|
||||||
class VQAutoEncoder(nn.Module):
|
class VQAutoEncoder(nn.Module):
|
||||||
def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
|
def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=None, codebook_size=1024, emb_dim=256,
|
||||||
beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
|
beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
logger = get_root_logger()
|
logger = get_root_logger()
|
||||||
self.in_channels = 3
|
self.in_channels = 3
|
||||||
self.nf = nf
|
self.nf = nf
|
||||||
self.n_blocks = res_blocks
|
self.n_blocks = res_blocks
|
||||||
self.codebook_size = codebook_size
|
self.codebook_size = codebook_size
|
||||||
self.embed_dim = emb_dim
|
self.embed_dim = emb_dim
|
||||||
self.ch_mult = ch_mult
|
self.ch_mult = ch_mult
|
||||||
self.resolution = img_size
|
self.resolution = img_size
|
||||||
self.attn_resolutions = attn_resolutions
|
self.attn_resolutions = attn_resolutions or [16]
|
||||||
self.quantizer_type = quantizer
|
self.quantizer_type = quantizer
|
||||||
self.encoder = Encoder(
|
self.encoder = Encoder(
|
||||||
self.in_channels,
|
self.in_channels,
|
||||||
@ -365,11 +363,11 @@ class VQAutoEncoder(nn.Module):
|
|||||||
self.kl_weight
|
self.kl_weight
|
||||||
)
|
)
|
||||||
self.generator = Generator(
|
self.generator = Generator(
|
||||||
self.nf,
|
self.nf,
|
||||||
self.embed_dim,
|
self.embed_dim,
|
||||||
self.ch_mult,
|
self.ch_mult,
|
||||||
self.n_blocks,
|
self.n_blocks,
|
||||||
self.resolution,
|
self.resolution,
|
||||||
self.attn_resolutions
|
self.attn_resolutions
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -434,4 +432,4 @@ class VQGANDiscriminator(nn.Module):
|
|||||||
raise ValueError('Wrong params!')
|
raise ValueError('Wrong params!')
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.main(x)
|
return self.main(x)
|
||||||
|
@ -33,11 +33,9 @@ def setup_model(dirname):
|
|||||||
try:
|
try:
|
||||||
from torchvision.transforms.functional import normalize
|
from torchvision.transforms.functional import normalize
|
||||||
from modules.codeformer.codeformer_arch import CodeFormer
|
from modules.codeformer.codeformer_arch import CodeFormer
|
||||||
from basicsr.utils.download_util import load_file_from_url
|
from basicsr.utils import img2tensor, tensor2img
|
||||||
from basicsr.utils import imwrite, img2tensor, tensor2img
|
|
||||||
from facelib.utils.face_restoration_helper import FaceRestoreHelper
|
from facelib.utils.face_restoration_helper import FaceRestoreHelper
|
||||||
from facelib.detection.retinaface import retinaface
|
from facelib.detection.retinaface import retinaface
|
||||||
from modules.shared import cmd_opts
|
|
||||||
|
|
||||||
net_class = CodeFormer
|
net_class = CodeFormer
|
||||||
|
|
||||||
@ -96,7 +94,7 @@ def setup_model(dirname):
|
|||||||
self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
|
self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
|
||||||
self.face_helper.align_warp_face()
|
self.face_helper.align_warp_face()
|
||||||
|
|
||||||
for idx, cropped_face in enumerate(self.face_helper.cropped_faces):
|
for cropped_face in self.face_helper.cropped_faces:
|
||||||
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
|
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
|
||||||
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
|
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
|
||||||
cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
|
cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
|
||||||
|
@ -14,7 +14,7 @@ from collections import OrderedDict
|
|||||||
import git
|
import git
|
||||||
|
|
||||||
from modules import shared, extensions
|
from modules import shared, extensions
|
||||||
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path, config_states_dir
|
from modules.paths_internal import script_path, config_states_dir
|
||||||
|
|
||||||
|
|
||||||
all_config_states = OrderedDict()
|
all_config_states = OrderedDict()
|
||||||
@ -35,7 +35,7 @@ def list_config_states():
|
|||||||
j["filepath"] = path
|
j["filepath"] = path
|
||||||
config_states.append(j)
|
config_states.append(j)
|
||||||
|
|
||||||
config_states = list(sorted(config_states, key=lambda cs: cs["created_at"], reverse=True))
|
config_states = sorted(config_states, key=lambda cs: cs["created_at"], reverse=True)
|
||||||
|
|
||||||
for cs in config_states:
|
for cs in config_states:
|
||||||
timestamp = time.asctime(time.gmtime(cs["created_at"]))
|
timestamp = time.asctime(time.gmtime(cs["created_at"]))
|
||||||
|
@ -2,7 +2,6 @@ import os
|
|||||||
import re
|
import re
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from modules import modelloader, paths, deepbooru_model, devices, images, shared
|
from modules import modelloader, paths, deepbooru_model, devices, images, shared
|
||||||
@ -79,7 +78,7 @@ class DeepDanbooru:
|
|||||||
|
|
||||||
res = []
|
res = []
|
||||||
|
|
||||||
filtertags = set([x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")])
|
filtertags = {x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")}
|
||||||
|
|
||||||
for tag in [x for x in tags if x not in filtertags]:
|
for tag in [x for x in tags if x not in filtertags]:
|
||||||
probability = probability_dict[tag]
|
probability = probability_dict[tag]
|
||||||
|
@ -65,7 +65,7 @@ def enable_tf32():
|
|||||||
|
|
||||||
# enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
|
# enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
|
||||||
# see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
|
# see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
|
||||||
if any([torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())]):
|
if any(torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())):
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
|
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
@ -6,7 +6,7 @@ from PIL import Image
|
|||||||
from basicsr.utils.download_util import load_file_from_url
|
from basicsr.utils.download_util import load_file_from_url
|
||||||
|
|
||||||
import modules.esrgan_model_arch as arch
|
import modules.esrgan_model_arch as arch
|
||||||
from modules import shared, modelloader, images, devices
|
from modules import modelloader, images, devices
|
||||||
from modules.upscaler import Upscaler, UpscalerData
|
from modules.upscaler import Upscaler, UpscalerData
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
|
|
||||||
@ -16,9 +16,7 @@ def mod2normal(state_dict):
|
|||||||
# this code is copied from https://github.com/victorca25/iNNfer
|
# this code is copied from https://github.com/victorca25/iNNfer
|
||||||
if 'conv_first.weight' in state_dict:
|
if 'conv_first.weight' in state_dict:
|
||||||
crt_net = {}
|
crt_net = {}
|
||||||
items = []
|
items = list(state_dict)
|
||||||
for k, v in state_dict.items():
|
|
||||||
items.append(k)
|
|
||||||
|
|
||||||
crt_net['model.0.weight'] = state_dict['conv_first.weight']
|
crt_net['model.0.weight'] = state_dict['conv_first.weight']
|
||||||
crt_net['model.0.bias'] = state_dict['conv_first.bias']
|
crt_net['model.0.bias'] = state_dict['conv_first.bias']
|
||||||
@ -52,9 +50,7 @@ def resrgan2normal(state_dict, nb=23):
|
|||||||
if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
|
if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
|
||||||
re8x = 0
|
re8x = 0
|
||||||
crt_net = {}
|
crt_net = {}
|
||||||
items = []
|
items = list(state_dict)
|
||||||
for k, v in state_dict.items():
|
|
||||||
items.append(k)
|
|
||||||
|
|
||||||
crt_net['model.0.weight'] = state_dict['conv_first.weight']
|
crt_net['model.0.weight'] = state_dict['conv_first.weight']
|
||||||
crt_net['model.0.bias'] = state_dict['conv_first.bias']
|
crt_net['model.0.bias'] = state_dict['conv_first.bias']
|
||||||
@ -156,13 +152,16 @@ class UpscalerESRGAN(Upscaler):
|
|||||||
|
|
||||||
def load_model(self, path: str):
|
def load_model(self, path: str):
|
||||||
if "http" in path:
|
if "http" in path:
|
||||||
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path,
|
filename = load_file_from_url(
|
||||||
file_name="%s.pth" % self.model_name,
|
url=self.model_url,
|
||||||
progress=True)
|
model_dir=self.model_path,
|
||||||
|
file_name=f"{self.model_name}.pth",
|
||||||
|
progress=True,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
filename = path
|
filename = path
|
||||||
if not os.path.exists(filename) or filename is None:
|
if not os.path.exists(filename) or filename is None:
|
||||||
print("Unable to load %s from %s" % (self.model_path, filename))
|
print(f"Unable to load {self.model_path} from {filename}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
|
state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
|
||||||
|
@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
import math
|
import math
|
||||||
import functools
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@ -38,7 +37,7 @@ class RRDBNet(nn.Module):
|
|||||||
elif upsample_mode == 'pixelshuffle':
|
elif upsample_mode == 'pixelshuffle':
|
||||||
upsample_block = pixelshuffle_block
|
upsample_block = pixelshuffle_block
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
|
raise NotImplementedError(f'upsample mode [{upsample_mode}] is not found')
|
||||||
if upscale == 3:
|
if upscale == 3:
|
||||||
upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
|
upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
|
||||||
else:
|
else:
|
||||||
@ -106,7 +105,7 @@ class ResidualDenseBlock_5C(nn.Module):
|
|||||||
Modified options that can be used:
|
Modified options that can be used:
|
||||||
- "Partial Convolution based Padding" arXiv:1811.11718
|
- "Partial Convolution based Padding" arXiv:1811.11718
|
||||||
- "Spectral normalization" arXiv:1802.05957
|
- "Spectral normalization" arXiv:1802.05957
|
||||||
- "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
|
- "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
|
||||||
{Rakotonirina} and A. {Rasoanaivo}
|
{Rakotonirina} and A. {Rasoanaivo}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -171,7 +170,7 @@ class GaussianNoise(nn.Module):
|
|||||||
scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
|
scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
|
||||||
sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
|
sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
|
||||||
x = x + sampled_noise
|
x = x + sampled_noise
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def conv1x1(in_planes, out_planes, stride=1):
|
def conv1x1(in_planes, out_planes, stride=1):
|
||||||
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
||||||
@ -261,10 +260,10 @@ class Upsample(nn.Module):
|
|||||||
|
|
||||||
def extra_repr(self):
|
def extra_repr(self):
|
||||||
if self.scale_factor is not None:
|
if self.scale_factor is not None:
|
||||||
info = 'scale_factor=' + str(self.scale_factor)
|
info = f'scale_factor={self.scale_factor}'
|
||||||
else:
|
else:
|
||||||
info = 'size=' + str(self.size)
|
info = f'size={self.size}'
|
||||||
info += ', mode=' + self.mode
|
info += f', mode={self.mode}'
|
||||||
return info
|
return info
|
||||||
|
|
||||||
|
|
||||||
@ -350,7 +349,7 @@ def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
|
|||||||
elif act_type == 'sigmoid': # [0, 1] range output
|
elif act_type == 'sigmoid': # [0, 1] range output
|
||||||
layer = nn.Sigmoid()
|
layer = nn.Sigmoid()
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type))
|
raise NotImplementedError(f'activation layer [{act_type}] is not found')
|
||||||
return layer
|
return layer
|
||||||
|
|
||||||
|
|
||||||
@ -372,7 +371,7 @@ def norm(norm_type, nc):
|
|||||||
elif norm_type == 'none':
|
elif norm_type == 'none':
|
||||||
def norm_layer(x): return Identity()
|
def norm_layer(x): return Identity()
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type))
|
raise NotImplementedError(f'normalization layer [{norm_type}] is not found')
|
||||||
return layer
|
return layer
|
||||||
|
|
||||||
|
|
||||||
@ -388,7 +387,7 @@ def pad(pad_type, padding):
|
|||||||
elif pad_type == 'zero':
|
elif pad_type == 'zero':
|
||||||
layer = nn.ZeroPad2d(padding)
|
layer = nn.ZeroPad2d(padding)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type))
|
raise NotImplementedError(f'padding layer [{pad_type}] is not implemented')
|
||||||
return layer
|
return layer
|
||||||
|
|
||||||
|
|
||||||
@ -432,15 +431,17 @@ def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=
|
|||||||
pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D',
|
pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D',
|
||||||
spectral_norm=False):
|
spectral_norm=False):
|
||||||
""" Conv layer with padding, normalization, activation """
|
""" Conv layer with padding, normalization, activation """
|
||||||
assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
|
assert mode in ['CNA', 'NAC', 'CNAC'], f'Wrong conv mode [{mode}]'
|
||||||
padding = get_valid_padding(kernel_size, dilation)
|
padding = get_valid_padding(kernel_size, dilation)
|
||||||
p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
|
p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
|
||||||
padding = padding if pad_type == 'zero' else 0
|
padding = padding if pad_type == 'zero' else 0
|
||||||
|
|
||||||
if convtype=='PartialConv2D':
|
if convtype=='PartialConv2D':
|
||||||
|
from torchvision.ops import PartialConv2d # this is definitely not going to work, but PartialConv2d doesn't work anyway and this shuts up static analyzer
|
||||||
c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
||||||
dilation=dilation, bias=bias, groups=groups)
|
dilation=dilation, bias=bias, groups=groups)
|
||||||
elif convtype=='DeformConv2D':
|
elif convtype=='DeformConv2D':
|
||||||
|
from torchvision.ops import DeformConv2d # not tested
|
||||||
c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
||||||
dilation=dilation, bias=bias, groups=groups)
|
dilation=dilation, bias=bias, groups=groups)
|
||||||
elif convtype=='Conv3D':
|
elif convtype=='Conv3D':
|
||||||
|
@ -3,11 +3,10 @@ import sys
|
|||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
import time
|
import time
|
||||||
from datetime import datetime
|
|
||||||
import git
|
import git
|
||||||
|
|
||||||
from modules import shared
|
from modules import shared
|
||||||
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path
|
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
|
||||||
|
|
||||||
extensions = []
|
extensions = []
|
||||||
|
|
||||||
|
@ -91,7 +91,7 @@ def deactivate(p, extra_network_data):
|
|||||||
"""call deactivate for extra networks in extra_network_data in specified order, then call
|
"""call deactivate for extra networks in extra_network_data in specified order, then call
|
||||||
deactivate for all remaining registered networks"""
|
deactivate for all remaining registered networks"""
|
||||||
|
|
||||||
for extra_network_name, extra_network_args in extra_network_data.items():
|
for extra_network_name in extra_network_data:
|
||||||
extra_network = extra_network_registry.get(extra_network_name, None)
|
extra_network = extra_network_registry.get(extra_network_name, None)
|
||||||
if extra_network is None:
|
if extra_network is None:
|
||||||
continue
|
continue
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from modules import extra_networks, shared, extra_networks
|
from modules import extra_networks, shared
|
||||||
from modules.hypernetworks import hypernetwork
|
from modules.hypernetworks import hypernetwork
|
||||||
|
|
||||||
|
|
||||||
@ -10,7 +10,8 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
|
|||||||
additional = shared.opts.sd_hypernetwork
|
additional = shared.opts.sd_hypernetwork
|
||||||
|
|
||||||
if additional != "None" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
|
if additional != "None" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
|
||||||
p.all_prompts = [x + f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
|
hypernet_prompt_text = f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>"
|
||||||
|
p.all_prompts = [f"{prompt}{hypernet_prompt_text}" for prompt in p.all_prompts]
|
||||||
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
|
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
|
||||||
|
|
||||||
names = []
|
names = []
|
||||||
|
@ -136,14 +136,14 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
|||||||
result_is_instruct_pix2pix_model = False
|
result_is_instruct_pix2pix_model = False
|
||||||
|
|
||||||
if theta_func2:
|
if theta_func2:
|
||||||
shared.state.textinfo = f"Loading B"
|
shared.state.textinfo = "Loading B"
|
||||||
print(f"Loading {secondary_model_info.filename}...")
|
print(f"Loading {secondary_model_info.filename}...")
|
||||||
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
|
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
|
||||||
else:
|
else:
|
||||||
theta_1 = None
|
theta_1 = None
|
||||||
|
|
||||||
if theta_func1:
|
if theta_func1:
|
||||||
shared.state.textinfo = f"Loading C"
|
shared.state.textinfo = "Loading C"
|
||||||
print(f"Loading {tertiary_model_info.filename}...")
|
print(f"Loading {tertiary_model_info.filename}...")
|
||||||
theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
|
theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
|
||||||
|
|
||||||
@ -199,7 +199,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
|||||||
result_is_inpainting_model = True
|
result_is_inpainting_model = True
|
||||||
else:
|
else:
|
||||||
theta_0[key] = theta_func2(a, b, multiplier)
|
theta_0[key] = theta_func2(a, b, multiplier)
|
||||||
|
|
||||||
theta_0[key] = to_half(theta_0[key], save_as_half)
|
theta_0[key] = to_half(theta_0[key], save_as_half)
|
||||||
|
|
||||||
shared.state.sampling_step += 1
|
shared.state.sampling_step += 1
|
||||||
|
@ -1,15 +1,11 @@
|
|||||||
import base64
|
import base64
|
||||||
import html
|
|
||||||
import io
|
import io
|
||||||
import math
|
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
from modules.paths import data_path
|
from modules.paths import data_path
|
||||||
from modules import shared, ui_tempdir, script_callbacks
|
from modules import shared, ui_tempdir, script_callbacks
|
||||||
import tempfile
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
|
re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
|
||||||
@ -23,14 +19,14 @@ registered_param_bindings = []
|
|||||||
|
|
||||||
|
|
||||||
class ParamBinding:
|
class ParamBinding:
|
||||||
def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=[]):
|
def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=None):
|
||||||
self.paste_button = paste_button
|
self.paste_button = paste_button
|
||||||
self.tabname = tabname
|
self.tabname = tabname
|
||||||
self.source_text_component = source_text_component
|
self.source_text_component = source_text_component
|
||||||
self.source_image_component = source_image_component
|
self.source_image_component = source_image_component
|
||||||
self.source_tabname = source_tabname
|
self.source_tabname = source_tabname
|
||||||
self.override_settings_component = override_settings_component
|
self.override_settings_component = override_settings_component
|
||||||
self.paste_field_names = paste_field_names
|
self.paste_field_names = paste_field_names or []
|
||||||
|
|
||||||
|
|
||||||
def reset():
|
def reset():
|
||||||
@ -59,6 +55,7 @@ def image_from_url_text(filedata):
|
|||||||
is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename)
|
is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename)
|
||||||
assert is_in_right_dir, 'trying to open image file outside of allowed directories'
|
assert is_in_right_dir, 'trying to open image file outside of allowed directories'
|
||||||
|
|
||||||
|
filename = filename.rsplit('?', 1)[0]
|
||||||
return Image.open(filename)
|
return Image.open(filename)
|
||||||
|
|
||||||
if type(filedata) == list:
|
if type(filedata) == list:
|
||||||
@ -129,6 +126,7 @@ def connect_paste_params_buttons():
|
|||||||
_js=jsfunc,
|
_js=jsfunc,
|
||||||
inputs=[binding.source_image_component],
|
inputs=[binding.source_image_component],
|
||||||
outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
|
outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
|
||||||
|
show_progress=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
if binding.source_text_component is not None and fields is not None:
|
if binding.source_text_component is not None and fields is not None:
|
||||||
@ -140,6 +138,7 @@ def connect_paste_params_buttons():
|
|||||||
fn=lambda *x: x,
|
fn=lambda *x: x,
|
||||||
inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names],
|
inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names],
|
||||||
outputs=[field for field, name in fields if name in paste_field_names],
|
outputs=[field for field, name in fields if name in paste_field_names],
|
||||||
|
show_progress=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
binding.paste_button.click(
|
binding.paste_button.click(
|
||||||
@ -147,6 +146,7 @@ def connect_paste_params_buttons():
|
|||||||
_js=f"switch_to_{binding.tabname}",
|
_js=f"switch_to_{binding.tabname}",
|
||||||
inputs=None,
|
inputs=None,
|
||||||
outputs=None,
|
outputs=None,
|
||||||
|
show_progress=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -247,7 +247,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
|||||||
lines.append(lastline)
|
lines.append(lastline)
|
||||||
lastline = ''
|
lastline = ''
|
||||||
|
|
||||||
for i, line in enumerate(lines):
|
for line in lines:
|
||||||
line = line.strip()
|
line = line.strip()
|
||||||
if line.startswith("Negative prompt:"):
|
if line.startswith("Negative prompt:"):
|
||||||
done_with_prompt = True
|
done_with_prompt = True
|
||||||
@ -265,8 +265,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
|||||||
v = v[1:-1] if v[0] == '"' and v[-1] == '"' else v
|
v = v[1:-1] if v[0] == '"' and v[-1] == '"' else v
|
||||||
m = re_imagesize.match(v)
|
m = re_imagesize.match(v)
|
||||||
if m is not None:
|
if m is not None:
|
||||||
res[k+"-1"] = m.group(1)
|
res[f"{k}-1"] = m.group(1)
|
||||||
res[k+"-2"] = m.group(2)
|
res[f"{k}-2"] = m.group(2)
|
||||||
else:
|
else:
|
||||||
res[k] = v
|
res[k] = v
|
||||||
|
|
||||||
@ -308,6 +308,8 @@ infotext_to_setting_name_mapping = [
|
|||||||
('UniPC skip type', 'uni_pc_skip_type'),
|
('UniPC skip type', 'uni_pc_skip_type'),
|
||||||
('UniPC order', 'uni_pc_order'),
|
('UniPC order', 'uni_pc_order'),
|
||||||
('UniPC lower order final', 'uni_pc_lower_order_final'),
|
('UniPC lower order final', 'uni_pc_lower_order_final'),
|
||||||
|
('Token merging ratio', 'token_merging_ratio'),
|
||||||
|
('Token merging ratio hr', 'token_merging_ratio_hr'),
|
||||||
('RNG', 'randn_source'),
|
('RNG', 'randn_source'),
|
||||||
('NGMS', 's_min_uncond'),
|
('NGMS', 's_min_uncond'),
|
||||||
]
|
]
|
||||||
@ -409,12 +411,14 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
|
|||||||
fn=paste_func,
|
fn=paste_func,
|
||||||
inputs=[input_comp],
|
inputs=[input_comp],
|
||||||
outputs=[x[0] for x in paste_fields],
|
outputs=[x[0] for x in paste_fields],
|
||||||
|
show_progress=False,
|
||||||
)
|
)
|
||||||
button.click(
|
button.click(
|
||||||
fn=None,
|
fn=None,
|
||||||
_js=f"recalculate_prompts_{tabname}",
|
_js=f"recalculate_prompts_{tabname}",
|
||||||
inputs=[],
|
inputs=[],
|
||||||
outputs=[],
|
outputs=[],
|
||||||
|
show_progress=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -78,7 +78,7 @@ def setup_model(dirname):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
from gfpgan import GFPGANer
|
from gfpgan import GFPGANer
|
||||||
from facexlib import detection, parsing
|
from facexlib import detection, parsing # noqa: F401
|
||||||
global user_path
|
global user_path
|
||||||
global have_gfpgan
|
global have_gfpgan
|
||||||
global gfpgan_constructor
|
global gfpgan_constructor
|
||||||
|
@ -13,7 +13,7 @@ cache_data = None
|
|||||||
|
|
||||||
|
|
||||||
def dump_cache():
|
def dump_cache():
|
||||||
with filelock.FileLock(cache_filename+".lock"):
|
with filelock.FileLock(f"{cache_filename}.lock"):
|
||||||
with open(cache_filename, "w", encoding="utf8") as file:
|
with open(cache_filename, "w", encoding="utf8") as file:
|
||||||
json.dump(cache_data, file, indent=4)
|
json.dump(cache_data, file, indent=4)
|
||||||
|
|
||||||
@ -22,7 +22,7 @@ def cache(subsection):
|
|||||||
global cache_data
|
global cache_data
|
||||||
|
|
||||||
if cache_data is None:
|
if cache_data is None:
|
||||||
with filelock.FileLock(cache_filename+".lock"):
|
with filelock.FileLock(f"{cache_filename}.lock"):
|
||||||
if not os.path.isfile(cache_filename):
|
if not os.path.isfile(cache_filename):
|
||||||
cache_data = {}
|
cache_data = {}
|
||||||
else:
|
else:
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
import csv
|
|
||||||
import datetime
|
import datetime
|
||||||
import glob
|
import glob
|
||||||
import html
|
import html
|
||||||
@ -18,7 +17,7 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
|||||||
from torch import einsum
|
from torch import einsum
|
||||||
from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
|
from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
|
||||||
|
|
||||||
from collections import defaultdict, deque
|
from collections import deque
|
||||||
from statistics import stdev, mean
|
from statistics import stdev, mean
|
||||||
|
|
||||||
|
|
||||||
@ -178,34 +177,34 @@ class Hypernetwork:
|
|||||||
|
|
||||||
def weights(self):
|
def weights(self):
|
||||||
res = []
|
res = []
|
||||||
for k, layers in self.layers.items():
|
for layers in self.layers.values():
|
||||||
for layer in layers:
|
for layer in layers:
|
||||||
res += layer.parameters()
|
res += layer.parameters()
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def train(self, mode=True):
|
def train(self, mode=True):
|
||||||
for k, layers in self.layers.items():
|
for layers in self.layers.values():
|
||||||
for layer in layers:
|
for layer in layers:
|
||||||
layer.train(mode=mode)
|
layer.train(mode=mode)
|
||||||
for param in layer.parameters():
|
for param in layer.parameters():
|
||||||
param.requires_grad = mode
|
param.requires_grad = mode
|
||||||
|
|
||||||
def to(self, device):
|
def to(self, device):
|
||||||
for k, layers in self.layers.items():
|
for layers in self.layers.values():
|
||||||
for layer in layers:
|
for layer in layers:
|
||||||
layer.to(device)
|
layer.to(device)
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def set_multiplier(self, multiplier):
|
def set_multiplier(self, multiplier):
|
||||||
for k, layers in self.layers.items():
|
for layers in self.layers.values():
|
||||||
for layer in layers:
|
for layer in layers:
|
||||||
layer.multiplier = multiplier
|
layer.multiplier = multiplier
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def eval(self):
|
def eval(self):
|
||||||
for k, layers in self.layers.items():
|
for layers in self.layers.values():
|
||||||
for layer in layers:
|
for layer in layers:
|
||||||
layer.eval()
|
layer.eval()
|
||||||
for param in layer.parameters():
|
for param in layer.parameters():
|
||||||
@ -404,7 +403,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
|
|||||||
k = self.to_k(context_k)
|
k = self.to_k(context_k)
|
||||||
v = self.to_v(context_v)
|
v = self.to_v(context_v)
|
||||||
|
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
|
||||||
|
|
||||||
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||||
|
|
||||||
@ -541,7 +540,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||||||
return hypernetwork, filename
|
return hypernetwork, filename
|
||||||
|
|
||||||
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
|
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
|
||||||
|
|
||||||
clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None
|
clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None
|
||||||
if clip_grad:
|
if clip_grad:
|
||||||
clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
|
clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
|
||||||
@ -594,7 +593,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||||||
print(e)
|
print(e)
|
||||||
|
|
||||||
scaler = torch.cuda.amp.GradScaler()
|
scaler = torch.cuda.amp.GradScaler()
|
||||||
|
|
||||||
batch_size = ds.batch_size
|
batch_size = ds.batch_size
|
||||||
gradient_step = ds.gradient_step
|
gradient_step = ds.gradient_step
|
||||||
# n steps = batch_size * gradient_step * n image processed
|
# n steps = batch_size * gradient_step * n image processed
|
||||||
@ -620,7 +619,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||||||
try:
|
try:
|
||||||
sd_hijack_checkpoint.add()
|
sd_hijack_checkpoint.add()
|
||||||
|
|
||||||
for i in range((steps-initial_step) * gradient_step):
|
for _ in range((steps-initial_step) * gradient_step):
|
||||||
if scheduler.finished:
|
if scheduler.finished:
|
||||||
break
|
break
|
||||||
if shared.state.interrupted:
|
if shared.state.interrupted:
|
||||||
@ -637,7 +636,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||||||
|
|
||||||
if clip_grad:
|
if clip_grad:
|
||||||
clip_grad_sched.step(hypernetwork.step)
|
clip_grad_sched.step(hypernetwork.step)
|
||||||
|
|
||||||
with devices.autocast():
|
with devices.autocast():
|
||||||
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
|
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
|
||||||
if use_weight:
|
if use_weight:
|
||||||
@ -658,14 +657,14 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||||||
|
|
||||||
_loss_step += loss.item()
|
_loss_step += loss.item()
|
||||||
scaler.scale(loss).backward()
|
scaler.scale(loss).backward()
|
||||||
|
|
||||||
# go back until we reach gradient accumulation steps
|
# go back until we reach gradient accumulation steps
|
||||||
if (j + 1) % gradient_step != 0:
|
if (j + 1) % gradient_step != 0:
|
||||||
continue
|
continue
|
||||||
loss_logging.append(_loss_step)
|
loss_logging.append(_loss_step)
|
||||||
if clip_grad:
|
if clip_grad:
|
||||||
clip_grad(weights, clip_grad_sched.learn_rate)
|
clip_grad(weights, clip_grad_sched.learn_rate)
|
||||||
|
|
||||||
scaler.step(optimizer)
|
scaler.step(optimizer)
|
||||||
scaler.update()
|
scaler.update()
|
||||||
hypernetwork.step += 1
|
hypernetwork.step += 1
|
||||||
@ -675,7 +674,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||||||
_loss_step = 0
|
_loss_step = 0
|
||||||
|
|
||||||
steps_done = hypernetwork.step + 1
|
steps_done = hypernetwork.step + 1
|
||||||
|
|
||||||
epoch_num = hypernetwork.step // steps_per_epoch
|
epoch_num = hypernetwork.step // steps_per_epoch
|
||||||
epoch_step = hypernetwork.step % steps_per_epoch
|
epoch_step = hypernetwork.step % steps_per_epoch
|
||||||
|
|
||||||
|
@ -1,19 +1,17 @@
|
|||||||
import html
|
import html
|
||||||
import os
|
|
||||||
import re
|
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import modules.hypernetworks.hypernetwork
|
import modules.hypernetworks.hypernetwork
|
||||||
from modules import devices, sd_hijack, shared
|
from modules import devices, sd_hijack, shared
|
||||||
|
|
||||||
not_available = ["hardswish", "multiheadattention"]
|
not_available = ["hardswish", "multiheadattention"]
|
||||||
keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
|
keys = [x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict if x not in not_available]
|
||||||
|
|
||||||
|
|
||||||
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
|
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
|
||||||
filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)
|
filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)
|
||||||
|
|
||||||
return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {filename}", ""
|
return gr.Dropdown.update(choices=sorted(shared.hypernetworks)), f"Created: {filename}", ""
|
||||||
|
|
||||||
|
|
||||||
def train_hypernetwork(*args):
|
def train_hypernetwork(*args):
|
||||||
|
@ -13,17 +13,24 @@ import numpy as np
|
|||||||
import piexif
|
import piexif
|
||||||
import piexif.helper
|
import piexif.helper
|
||||||
from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
|
from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
|
||||||
from fonts.ttf import Roboto
|
|
||||||
import string
|
import string
|
||||||
import json
|
import json
|
||||||
import hashlib
|
import hashlib
|
||||||
|
|
||||||
from modules import sd_samplers, shared, script_callbacks, errors
|
from modules import sd_samplers, shared, script_callbacks, errors
|
||||||
from modules.shared import opts, cmd_opts
|
from modules.paths_internal import roboto_ttf_file
|
||||||
|
from modules.shared import opts
|
||||||
|
|
||||||
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
|
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
|
||||||
|
|
||||||
|
|
||||||
|
def get_font(fontsize: int):
|
||||||
|
try:
|
||||||
|
return ImageFont.truetype(opts.font or roboto_ttf_file, fontsize)
|
||||||
|
except Exception:
|
||||||
|
return ImageFont.truetype(roboto_ttf_file, fontsize)
|
||||||
|
|
||||||
|
|
||||||
def image_grid(imgs, batch_size=1, rows=None):
|
def image_grid(imgs, batch_size=1, rows=None):
|
||||||
if rows is None:
|
if rows is None:
|
||||||
if opts.n_rows > 0:
|
if opts.n_rows > 0:
|
||||||
@ -142,14 +149,8 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
|
|||||||
lines.append(word)
|
lines.append(word)
|
||||||
return lines
|
return lines
|
||||||
|
|
||||||
def get_font(fontsize):
|
|
||||||
try:
|
|
||||||
return ImageFont.truetype(opts.font or Roboto, fontsize)
|
|
||||||
except Exception:
|
|
||||||
return ImageFont.truetype(Roboto, fontsize)
|
|
||||||
|
|
||||||
def draw_texts(drawing, draw_x, draw_y, lines, initial_fnt, initial_fontsize):
|
def draw_texts(drawing, draw_x, draw_y, lines, initial_fnt, initial_fontsize):
|
||||||
for i, line in enumerate(lines):
|
for line in lines:
|
||||||
fnt = initial_fnt
|
fnt = initial_fnt
|
||||||
fontsize = initial_fontsize
|
fontsize = initial_fontsize
|
||||||
while drawing.multiline_textsize(line.text, font=fnt)[0] > line.allowed_width and fontsize > 0:
|
while drawing.multiline_textsize(line.text, font=fnt)[0] > line.allowed_width and fontsize > 0:
|
||||||
@ -357,6 +358,7 @@ class FilenameGenerator:
|
|||||||
'generation_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.n_iter == 1 and self.p.batch_size == 1 else self.p.iteration * self.p.batch_size + self.p.batch_index + 1,
|
'generation_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.n_iter == 1 and self.p.batch_size == 1 else self.p.iteration * self.p.batch_size + self.p.batch_index + 1,
|
||||||
'hasprompt': lambda self, *args: self.hasprompt(*args), # accepts formats:[hasprompt<prompt1|default><prompt2>..]
|
'hasprompt': lambda self, *args: self.hasprompt(*args), # accepts formats:[hasprompt<prompt1|default><prompt2>..]
|
||||||
'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
|
'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
|
||||||
|
'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
|
||||||
}
|
}
|
||||||
default_time_format = '%Y%m%d%H%M%S'
|
default_time_format = '%Y%m%d%H%M%S'
|
||||||
|
|
||||||
@ -365,7 +367,7 @@ class FilenameGenerator:
|
|||||||
self.seed = seed
|
self.seed = seed
|
||||||
self.prompt = prompt
|
self.prompt = prompt
|
||||||
self.image = image
|
self.image = image
|
||||||
|
|
||||||
def hasprompt(self, *args):
|
def hasprompt(self, *args):
|
||||||
lower = self.prompt.lower()
|
lower = self.prompt.lower()
|
||||||
if self.p is None or self.prompt is None:
|
if self.p is None or self.prompt is None:
|
||||||
@ -408,13 +410,13 @@ class FilenameGenerator:
|
|||||||
time_format = args[0] if len(args) > 0 and args[0] != "" else self.default_time_format
|
time_format = args[0] if len(args) > 0 and args[0] != "" else self.default_time_format
|
||||||
try:
|
try:
|
||||||
time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
|
time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
|
||||||
except pytz.exceptions.UnknownTimeZoneError as _:
|
except pytz.exceptions.UnknownTimeZoneError:
|
||||||
time_zone = None
|
time_zone = None
|
||||||
|
|
||||||
time_zone_time = time_datetime.astimezone(time_zone)
|
time_zone_time = time_datetime.astimezone(time_zone)
|
||||||
try:
|
try:
|
||||||
formatted_time = time_zone_time.strftime(time_format)
|
formatted_time = time_zone_time.strftime(time_format)
|
||||||
except (ValueError, TypeError) as _:
|
except (ValueError, TypeError):
|
||||||
formatted_time = time_zone_time.strftime(self.default_time_format)
|
formatted_time = time_zone_time.strftime(self.default_time_format)
|
||||||
|
|
||||||
return sanitize_filename_part(formatted_time, replace_spaces=False)
|
return sanitize_filename_part(formatted_time, replace_spaces=False)
|
||||||
@ -466,14 +468,14 @@ def get_next_sequence_number(path, basename):
|
|||||||
"""
|
"""
|
||||||
result = -1
|
result = -1
|
||||||
if basename != '':
|
if basename != '':
|
||||||
basename = basename + "-"
|
basename = f"{basename}-"
|
||||||
|
|
||||||
prefix_length = len(basename)
|
prefix_length = len(basename)
|
||||||
for p in os.listdir(path):
|
for p in os.listdir(path):
|
||||||
if p.startswith(basename):
|
if p.startswith(basename):
|
||||||
l = os.path.splitext(p[prefix_length:])[0].split('-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
|
parts = os.path.splitext(p[prefix_length:])[0].split('-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
|
||||||
try:
|
try:
|
||||||
result = max(int(l[0]), result)
|
result = max(int(parts[0]), result)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -535,7 +537,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
|||||||
add_number = opts.save_images_add_number or file_decoration == ''
|
add_number = opts.save_images_add_number or file_decoration == ''
|
||||||
|
|
||||||
if file_decoration != "" and add_number:
|
if file_decoration != "" and add_number:
|
||||||
file_decoration = "-" + file_decoration
|
file_decoration = f"-{file_decoration}"
|
||||||
|
|
||||||
file_decoration = namegen.apply(file_decoration) + suffix
|
file_decoration = namegen.apply(file_decoration) + suffix
|
||||||
|
|
||||||
@ -565,7 +567,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
|||||||
|
|
||||||
def _atomically_save_image(image_to_save, filename_without_extension, extension):
|
def _atomically_save_image(image_to_save, filename_without_extension, extension):
|
||||||
# save image with .tmp extension to avoid race condition when another process detects new image in the directory
|
# save image with .tmp extension to avoid race condition when another process detects new image in the directory
|
||||||
temp_file_path = filename_without_extension + ".tmp"
|
temp_file_path = f"{filename_without_extension}.tmp"
|
||||||
image_format = Image.registered_extensions()[extension]
|
image_format = Image.registered_extensions()[extension]
|
||||||
|
|
||||||
if extension.lower() == '.png':
|
if extension.lower() == '.png':
|
||||||
@ -625,7 +627,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
|||||||
if opts.save_txt and info is not None:
|
if opts.save_txt and info is not None:
|
||||||
txt_fullfn = f"{fullfn_without_extension}.txt"
|
txt_fullfn = f"{fullfn_without_extension}.txt"
|
||||||
with open(txt_fullfn, "w", encoding="utf8") as file:
|
with open(txt_fullfn, "w", encoding="utf8") as file:
|
||||||
file.write(info + "\n")
|
file.write(f"{info}\n")
|
||||||
else:
|
else:
|
||||||
txt_fullfn = None
|
txt_fullfn = None
|
||||||
|
|
||||||
|
@ -1,19 +1,15 @@
|
|||||||
import math
|
|
||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
|
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
|
||||||
|
|
||||||
from modules import devices, sd_samplers
|
from modules import sd_samplers
|
||||||
from modules.generation_parameters_copypaste import create_override_settings_dict
|
from modules.generation_parameters_copypaste import create_override_settings_dict
|
||||||
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
|
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
|
||||||
from modules.shared import opts, state
|
from modules.shared import opts, state
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
import modules.processing as processing
|
import modules.processing as processing
|
||||||
from modules.ui import plaintext_to_html
|
from modules.ui import plaintext_to_html
|
||||||
import modules.images as images
|
|
||||||
import modules.scripts
|
import modules.scripts
|
||||||
|
|
||||||
|
|
||||||
@ -48,7 +44,8 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
img = Image.open(image)
|
img = Image.open(image)
|
||||||
except UnidentifiedImageError:
|
except UnidentifiedImageError as e:
|
||||||
|
print(e)
|
||||||
continue
|
continue
|
||||||
# Use the EXIF orientation of photos taken by smartphones.
|
# Use the EXIF orientation of photos taken by smartphones.
|
||||||
img = ImageOps.exif_transpose(img)
|
img = ImageOps.exif_transpose(img)
|
||||||
@ -58,7 +55,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
|
|||||||
# try to find corresponding mask for an image using simple filename matching
|
# try to find corresponding mask for an image using simple filename matching
|
||||||
mask_image_path = os.path.join(inpaint_mask_dir, os.path.basename(image))
|
mask_image_path = os.path.join(inpaint_mask_dir, os.path.basename(image))
|
||||||
# if not found use first one ("same mask for all images" use-case)
|
# if not found use first one ("same mask for all images" use-case)
|
||||||
if not mask_image_path in inpaint_masks:
|
if mask_image_path not in inpaint_masks:
|
||||||
mask_image_path = inpaint_masks[0]
|
mask_image_path = inpaint_masks[0]
|
||||||
mask_image = Image.open(mask_image_path)
|
mask_image = Image.open(mask_image_path)
|
||||||
p.image_mask = mask_image
|
p.image_mask = mask_image
|
||||||
|
@ -11,7 +11,6 @@ import torch.hub
|
|||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from torchvision.transforms.functional import InterpolationMode
|
from torchvision.transforms.functional import InterpolationMode
|
||||||
|
|
||||||
import modules.shared as shared
|
|
||||||
from modules import devices, paths, shared, lowvram, modelloader, errors
|
from modules import devices, paths, shared, lowvram, modelloader, errors
|
||||||
|
|
||||||
blip_image_eval_size = 384
|
blip_image_eval_size = 384
|
||||||
@ -28,7 +27,7 @@ def category_types():
|
|||||||
def download_default_clip_interrogate_categories(content_dir):
|
def download_default_clip_interrogate_categories(content_dir):
|
||||||
print("Downloading CLIP categories...")
|
print("Downloading CLIP categories...")
|
||||||
|
|
||||||
tmpdir = content_dir + "_tmp"
|
tmpdir = f"{content_dir}_tmp"
|
||||||
category_types = ["artists", "flavors", "mediums", "movements"]
|
category_types = ["artists", "flavors", "mediums", "movements"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -160,7 +159,7 @@ class InterrogateModels:
|
|||||||
text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
|
text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
|
||||||
|
|
||||||
top_count = min(top_count, len(text_array))
|
top_count = min(top_count, len(text_array))
|
||||||
text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(devices.device_interrogate)
|
text_tokens = clip.tokenize(list(text_array), truncate=True).to(devices.device_interrogate)
|
||||||
text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
|
text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
|
||||||
text_features /= text_features.norm(dim=-1, keepdim=True)
|
text_features /= text_features.norm(dim=-1, keepdim=True)
|
||||||
|
|
||||||
@ -208,13 +207,13 @@ class InterrogateModels:
|
|||||||
|
|
||||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||||
|
|
||||||
for name, topn, items in self.categories():
|
for cat in self.categories():
|
||||||
matches = self.rank(image_features, items, top_count=topn)
|
matches = self.rank(image_features, cat.items, top_count=cat.topn)
|
||||||
for match, score in matches:
|
for match, score in matches:
|
||||||
if shared.opts.interrogate_return_ranks:
|
if shared.opts.interrogate_return_ranks:
|
||||||
res += f", ({match}:{score/100:.3f})"
|
res += f", ({match}:{score/100:.3f})"
|
||||||
else:
|
else:
|
||||||
res += ", " + match
|
res += f", {match}"
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
print("Error interrogating", file=sys.stderr)
|
print("Error interrogating", file=sys.stderr)
|
||||||
|
@ -23,7 +23,7 @@ def list_localizations(dirname):
|
|||||||
localizations[fn] = file.path
|
localizations[fn] = file.path
|
||||||
|
|
||||||
|
|
||||||
def localization_js(current_localization_name):
|
def localization_js(current_localization_name: str) -> str:
|
||||||
fn = localizations.get(current_localization_name, None)
|
fn = localizations.get(current_localization_name, None)
|
||||||
data = {}
|
data = {}
|
||||||
if fn is not None:
|
if fn is not None:
|
||||||
@ -34,4 +34,4 @@ def localization_js(current_localization_name):
|
|||||||
print(f"Error loading localization from {fn}:", file=sys.stderr)
|
print(f"Error loading localization from {fn}:", file=sys.stderr)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
|
||||||
return f"var localization = {json.dumps(data)}\n"
|
return f"window.localization = {json.dumps(data)}"
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
import torch
|
import torch
|
||||||
import platform
|
import platform
|
||||||
from modules import paths
|
|
||||||
from modules.sd_hijack_utils import CondFunc
|
from modules.sd_hijack_utils import CondFunc
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
@ -43,7 +42,7 @@ if has_mps:
|
|||||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
|
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
|
||||||
CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
|
CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
|
||||||
lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
|
lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
|
||||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/80800
|
# MPS workaround for https://github.com/pytorch/pytorch/issues/80800
|
||||||
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
|
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
|
||||||
lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
|
lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
|
||||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/90532
|
# MPS workaround for https://github.com/pytorch/pytorch/issues/90532
|
||||||
@ -54,6 +53,11 @@ if has_mps:
|
|||||||
CondFunc('torch.cumsum', cumsum_fix_func, None)
|
CondFunc('torch.cumsum', cumsum_fix_func, None)
|
||||||
CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
|
CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
|
||||||
CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
|
CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
|
||||||
if version.parse(torch.__version__) == version.parse("2.0"):
|
|
||||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/96113
|
# MPS workaround for https://github.com/pytorch/pytorch/issues/96113
|
||||||
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda *args, **kwargs: len(args) == 6)
|
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda _, input, *args, **kwargs: len(args) == 4 and input.device.type == 'mps')
|
||||||
|
|
||||||
|
# MPS workaround for https://github.com/pytorch/pytorch/issues/92311
|
||||||
|
if platform.processor() == 'i386':
|
||||||
|
for funcName in ['torch.argmax', 'torch.Tensor.argmax']:
|
||||||
|
CondFunc(funcName, lambda _, input, *args, **kwargs: torch.max(input.float() if input.dtype == torch.int64 else input, *args, **kwargs)[1], lambda _, input, *args, **kwargs: input.device.type == 'mps')
|
||||||
|
@ -4,7 +4,7 @@ from PIL import Image, ImageFilter, ImageOps
|
|||||||
def get_crop_region(mask, pad=0):
|
def get_crop_region(mask, pad=0):
|
||||||
"""finds a rectangular region that contains all masked ares in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle.
|
"""finds a rectangular region that contains all masked ares in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle.
|
||||||
For example, if a user has painted the top-right part of a 512x512 image", the result may be (256, 0, 512, 256)"""
|
For example, if a user has painted the top-right part of a 512x512 image", the result may be (256, 0, 512, 256)"""
|
||||||
|
|
||||||
h, w = mask.shape
|
h, w = mask.shape
|
||||||
|
|
||||||
crop_left = 0
|
crop_left = 0
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
import glob
|
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import importlib
|
import importlib
|
||||||
@ -22,9 +21,6 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
|
|||||||
"""
|
"""
|
||||||
output = []
|
output = []
|
||||||
|
|
||||||
if ext_filter is None:
|
|
||||||
ext_filter = []
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
places = []
|
places = []
|
||||||
|
|
||||||
@ -39,22 +35,14 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
|
|||||||
places.append(model_path)
|
places.append(model_path)
|
||||||
|
|
||||||
for place in places:
|
for place in places:
|
||||||
if os.path.exists(place):
|
for full_path in shared.walk_files(place, allowed_extensions=ext_filter):
|
||||||
for file in glob.iglob(place + '**/**', recursive=True):
|
if os.path.islink(full_path) and not os.path.exists(full_path):
|
||||||
full_path = file
|
print(f"Skipping broken symlink: {full_path}")
|
||||||
if os.path.isdir(full_path):
|
continue
|
||||||
continue
|
if ext_blacklist is not None and any(full_path.endswith(x) for x in ext_blacklist):
|
||||||
if os.path.islink(full_path) and not os.path.exists(full_path):
|
continue
|
||||||
print(f"Skipping broken symlink: {full_path}")
|
if full_path not in output:
|
||||||
continue
|
output.append(full_path)
|
||||||
if ext_blacklist is not None and any([full_path.endswith(x) for x in ext_blacklist]):
|
|
||||||
continue
|
|
||||||
if len(ext_filter) != 0:
|
|
||||||
model_name, extension = os.path.splitext(file)
|
|
||||||
if extension not in ext_filter:
|
|
||||||
continue
|
|
||||||
if file not in output:
|
|
||||||
output.append(full_path)
|
|
||||||
|
|
||||||
if model_url is not None and len(output) == 0:
|
if model_url is not None and len(output) == 0:
|
||||||
if download_name is not None:
|
if download_name is not None:
|
||||||
@ -119,32 +107,15 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
|
|||||||
print(f"Moving {file} from {src_path} to {dest_path}.")
|
print(f"Moving {file} from {src_path} to {dest_path}.")
|
||||||
try:
|
try:
|
||||||
shutil.move(fullpath, dest_path)
|
shutil.move(fullpath, dest_path)
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
if len(os.listdir(src_path)) == 0:
|
if len(os.listdir(src_path)) == 0:
|
||||||
print(f"Removing empty folder: {src_path}")
|
print(f"Removing empty folder: {src_path}")
|
||||||
shutil.rmtree(src_path, True)
|
shutil.rmtree(src_path, True)
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
builtin_upscaler_classes = []
|
|
||||||
forbidden_upscaler_classes = set()
|
|
||||||
|
|
||||||
|
|
||||||
def list_builtin_upscalers():
|
|
||||||
load_upscalers()
|
|
||||||
|
|
||||||
builtin_upscaler_classes.clear()
|
|
||||||
builtin_upscaler_classes.extend(Upscaler.__subclasses__())
|
|
||||||
|
|
||||||
|
|
||||||
def forbid_loaded_nonbuiltin_upscalers():
|
|
||||||
for cls in Upscaler.__subclasses__():
|
|
||||||
if cls not in builtin_upscaler_classes:
|
|
||||||
forbidden_upscaler_classes.add(cls)
|
|
||||||
|
|
||||||
|
|
||||||
def load_upscalers():
|
def load_upscalers():
|
||||||
# We can only do this 'magic' method to dynamically load upscalers if they are referenced,
|
# We can only do this 'magic' method to dynamically load upscalers if they are referenced,
|
||||||
# so we'll try to import any _model.py files before looking in __subclasses__
|
# so we'll try to import any _model.py files before looking in __subclasses__
|
||||||
@ -155,15 +126,22 @@ def load_upscalers():
|
|||||||
full_model = f"modules.{model_name}_model"
|
full_model = f"modules.{model_name}_model"
|
||||||
try:
|
try:
|
||||||
importlib.import_module(full_model)
|
importlib.import_module(full_model)
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
datas = []
|
datas = []
|
||||||
commandline_options = vars(shared.cmd_opts)
|
commandline_options = vars(shared.cmd_opts)
|
||||||
for cls in Upscaler.__subclasses__():
|
|
||||||
if cls in forbidden_upscaler_classes:
|
|
||||||
continue
|
|
||||||
|
|
||||||
|
# some of upscaler classes will not go away after reloading their modules, and we'll end
|
||||||
|
# up with two copies of those classes. The newest copy will always be the last in the list,
|
||||||
|
# so we go from end to beginning and ignore duplicates
|
||||||
|
used_classes = {}
|
||||||
|
for cls in reversed(Upscaler.__subclasses__()):
|
||||||
|
classname = str(cls)
|
||||||
|
if classname not in used_classes:
|
||||||
|
used_classes[classname] = cls
|
||||||
|
|
||||||
|
for cls in reversed(used_classes.values()):
|
||||||
name = cls.__name__
|
name = cls.__name__
|
||||||
cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
|
cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
|
||||||
scaler = cls(commandline_options.get(cmd_name, None))
|
scaler = cls(commandline_options.get(cmd_name, None))
|
||||||
|
@ -52,7 +52,7 @@ class DDPM(pl.LightningModule):
|
|||||||
beta_schedule="linear",
|
beta_schedule="linear",
|
||||||
loss_type="l2",
|
loss_type="l2",
|
||||||
ckpt_path=None,
|
ckpt_path=None,
|
||||||
ignore_keys=[],
|
ignore_keys=None,
|
||||||
load_only_unet=False,
|
load_only_unet=False,
|
||||||
monitor="val/loss",
|
monitor="val/loss",
|
||||||
use_ema=True,
|
use_ema=True,
|
||||||
@ -107,7 +107,7 @@ class DDPM(pl.LightningModule):
|
|||||||
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
||||||
|
|
||||||
if ckpt_path is not None:
|
if ckpt_path is not None:
|
||||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
|
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet)
|
||||||
|
|
||||||
# If initialing from EMA-only checkpoint, create EMA model after loading.
|
# If initialing from EMA-only checkpoint, create EMA model after loading.
|
||||||
if self.use_ema and not load_ema:
|
if self.use_ema and not load_ema:
|
||||||
@ -194,7 +194,9 @@ class DDPM(pl.LightningModule):
|
|||||||
if context is not None:
|
if context is not None:
|
||||||
print(f"{context}: Restored training weights")
|
print(f"{context}: Restored training weights")
|
||||||
|
|
||||||
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
|
def init_from_ckpt(self, path, ignore_keys=None, only_model=False):
|
||||||
|
ignore_keys = ignore_keys or []
|
||||||
|
|
||||||
sd = torch.load(path, map_location="cpu")
|
sd = torch.load(path, map_location="cpu")
|
||||||
if "state_dict" in list(sd.keys()):
|
if "state_dict" in list(sd.keys()):
|
||||||
sd = sd["state_dict"]
|
sd = sd["state_dict"]
|
||||||
@ -223,7 +225,7 @@ class DDPM(pl.LightningModule):
|
|||||||
for k in keys:
|
for k in keys:
|
||||||
for ik in ignore_keys:
|
for ik in ignore_keys:
|
||||||
if k.startswith(ik):
|
if k.startswith(ik):
|
||||||
print("Deleting key {} from state_dict.".format(k))
|
print(f"Deleting key {k} from state_dict.")
|
||||||
del sd[k]
|
del sd[k]
|
||||||
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
|
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
|
||||||
sd, strict=False)
|
sd, strict=False)
|
||||||
@ -386,7 +388,7 @@ class DDPM(pl.LightningModule):
|
|||||||
_, loss_dict_no_ema = self.shared_step(batch)
|
_, loss_dict_no_ema = self.shared_step(batch)
|
||||||
with self.ema_scope():
|
with self.ema_scope():
|
||||||
_, loss_dict_ema = self.shared_step(batch)
|
_, loss_dict_ema = self.shared_step(batch)
|
||||||
loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
|
loss_dict_ema = {f"{key}_ema": loss_dict_ema[key] for key in loss_dict_ema}
|
||||||
self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
|
self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
|
||||||
self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
|
self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
|
||||||
|
|
||||||
@ -403,7 +405,7 @@ class DDPM(pl.LightningModule):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
|
def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
|
||||||
log = dict()
|
log = {}
|
||||||
x = self.get_input(batch, self.first_stage_key)
|
x = self.get_input(batch, self.first_stage_key)
|
||||||
N = min(x.shape[0], N)
|
N = min(x.shape[0], N)
|
||||||
n_row = min(x.shape[0], n_row)
|
n_row = min(x.shape[0], n_row)
|
||||||
@ -411,7 +413,7 @@ class DDPM(pl.LightningModule):
|
|||||||
log["inputs"] = x
|
log["inputs"] = x
|
||||||
|
|
||||||
# get diffusion row
|
# get diffusion row
|
||||||
diffusion_row = list()
|
diffusion_row = []
|
||||||
x_start = x[:n_row]
|
x_start = x[:n_row]
|
||||||
|
|
||||||
for t in range(self.num_timesteps):
|
for t in range(self.num_timesteps):
|
||||||
@ -473,13 +475,13 @@ class LatentDiffusion(DDPM):
|
|||||||
conditioning_key = None
|
conditioning_key = None
|
||||||
ckpt_path = kwargs.pop("ckpt_path", None)
|
ckpt_path = kwargs.pop("ckpt_path", None)
|
||||||
ignore_keys = kwargs.pop("ignore_keys", [])
|
ignore_keys = kwargs.pop("ignore_keys", [])
|
||||||
super().__init__(conditioning_key=conditioning_key, *args, load_ema=load_ema, **kwargs)
|
super().__init__(*args, conditioning_key=conditioning_key, load_ema=load_ema, **kwargs)
|
||||||
self.concat_mode = concat_mode
|
self.concat_mode = concat_mode
|
||||||
self.cond_stage_trainable = cond_stage_trainable
|
self.cond_stage_trainable = cond_stage_trainable
|
||||||
self.cond_stage_key = cond_stage_key
|
self.cond_stage_key = cond_stage_key
|
||||||
try:
|
try:
|
||||||
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
|
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
|
||||||
except:
|
except Exception:
|
||||||
self.num_downs = 0
|
self.num_downs = 0
|
||||||
if not scale_by_std:
|
if not scale_by_std:
|
||||||
self.scale_factor = scale_factor
|
self.scale_factor = scale_factor
|
||||||
@ -891,16 +893,6 @@ class LatentDiffusion(DDPM):
|
|||||||
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
|
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
|
||||||
return self.p_losses(x, c, t, *args, **kwargs)
|
return self.p_losses(x, c, t, *args, **kwargs)
|
||||||
|
|
||||||
def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
|
|
||||||
def rescale_bbox(bbox):
|
|
||||||
x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
|
|
||||||
y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
|
|
||||||
w = min(bbox[2] / crop_coordinates[2], 1 - x0)
|
|
||||||
h = min(bbox[3] / crop_coordinates[3], 1 - y0)
|
|
||||||
return x0, y0, w, h
|
|
||||||
|
|
||||||
return [rescale_bbox(b) for b in bboxes]
|
|
||||||
|
|
||||||
def apply_model(self, x_noisy, t, cond, return_ids=False):
|
def apply_model(self, x_noisy, t, cond, return_ids=False):
|
||||||
|
|
||||||
if isinstance(cond, dict):
|
if isinstance(cond, dict):
|
||||||
@ -1140,7 +1132,7 @@ class LatentDiffusion(DDPM):
|
|||||||
if cond is not None:
|
if cond is not None:
|
||||||
if isinstance(cond, dict):
|
if isinstance(cond, dict):
|
||||||
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
|
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
|
||||||
list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
|
[x[:batch_size] for x in cond[key]] for key in cond}
|
||||||
else:
|
else:
|
||||||
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
|
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
|
||||||
|
|
||||||
@ -1171,8 +1163,10 @@ class LatentDiffusion(DDPM):
|
|||||||
|
|
||||||
if i % log_every_t == 0 or i == timesteps - 1:
|
if i % log_every_t == 0 or i == timesteps - 1:
|
||||||
intermediates.append(x0_partial)
|
intermediates.append(x0_partial)
|
||||||
if callback: callback(i)
|
if callback:
|
||||||
if img_callback: img_callback(img, i)
|
callback(i)
|
||||||
|
if img_callback:
|
||||||
|
img_callback(img, i)
|
||||||
return img, intermediates
|
return img, intermediates
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@ -1219,8 +1213,10 @@ class LatentDiffusion(DDPM):
|
|||||||
|
|
||||||
if i % log_every_t == 0 or i == timesteps - 1:
|
if i % log_every_t == 0 or i == timesteps - 1:
|
||||||
intermediates.append(img)
|
intermediates.append(img)
|
||||||
if callback: callback(i)
|
if callback:
|
||||||
if img_callback: img_callback(img, i)
|
callback(i)
|
||||||
|
if img_callback:
|
||||||
|
img_callback(img, i)
|
||||||
|
|
||||||
if return_intermediates:
|
if return_intermediates:
|
||||||
return img, intermediates
|
return img, intermediates
|
||||||
@ -1235,7 +1231,7 @@ class LatentDiffusion(DDPM):
|
|||||||
if cond is not None:
|
if cond is not None:
|
||||||
if isinstance(cond, dict):
|
if isinstance(cond, dict):
|
||||||
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
|
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
|
||||||
list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
|
[x[:batch_size] for x in cond[key]] for key in cond}
|
||||||
else:
|
else:
|
||||||
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
|
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
|
||||||
return self.p_sample_loop(cond,
|
return self.p_sample_loop(cond,
|
||||||
@ -1267,7 +1263,7 @@ class LatentDiffusion(DDPM):
|
|||||||
|
|
||||||
use_ddim = False
|
use_ddim = False
|
||||||
|
|
||||||
log = dict()
|
log = {}
|
||||||
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
|
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
|
||||||
return_first_stage_outputs=True,
|
return_first_stage_outputs=True,
|
||||||
force_c_encode=True,
|
force_c_encode=True,
|
||||||
@ -1295,7 +1291,7 @@ class LatentDiffusion(DDPM):
|
|||||||
|
|
||||||
if plot_diffusion_rows:
|
if plot_diffusion_rows:
|
||||||
# get diffusion row
|
# get diffusion row
|
||||||
diffusion_row = list()
|
diffusion_row = []
|
||||||
z_start = z[:n_row]
|
z_start = z[:n_row]
|
||||||
for t in range(self.num_timesteps):
|
for t in range(self.num_timesteps):
|
||||||
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
|
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
|
||||||
@ -1337,7 +1333,7 @@ class LatentDiffusion(DDPM):
|
|||||||
|
|
||||||
if inpaint:
|
if inpaint:
|
||||||
# make a simple center square
|
# make a simple center square
|
||||||
b, h, w = z.shape[0], z.shape[2], z.shape[3]
|
h, w = z.shape[2], z.shape[3]
|
||||||
mask = torch.ones(N, h, w).to(self.device)
|
mask = torch.ones(N, h, w).to(self.device)
|
||||||
# zeros will be filled in
|
# zeros will be filled in
|
||||||
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
|
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
|
||||||
@ -1439,10 +1435,10 @@ class Layout2ImgDiffusion(LatentDiffusion):
|
|||||||
# TODO: move all layout-specific hacks to this class
|
# TODO: move all layout-specific hacks to this class
|
||||||
def __init__(self, cond_stage_key, *args, **kwargs):
|
def __init__(self, cond_stage_key, *args, **kwargs):
|
||||||
assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
|
assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
|
||||||
super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
|
super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs)
|
||||||
|
|
||||||
def log_images(self, batch, N=8, *args, **kwargs):
|
def log_images(self, batch, N=8, *args, **kwargs):
|
||||||
logs = super().log_images(batch=batch, N=N, *args, **kwargs)
|
logs = super().log_images(*args, batch=batch, N=N, **kwargs)
|
||||||
|
|
||||||
key = 'train' if self.training else 'validation'
|
key = 'train' if self.training else 'validation'
|
||||||
dset = self.trainer.datamodule.datasets[key]
|
dset = self.trainer.datamodule.datasets[key]
|
||||||
|
@ -1 +1 @@
|
|||||||
from .sampler import UniPCSampler
|
from .sampler import UniPCSampler # noqa: F401
|
||||||
|
@ -54,7 +54,8 @@ class UniPCSampler(object):
|
|||||||
if conditioning is not None:
|
if conditioning is not None:
|
||||||
if isinstance(conditioning, dict):
|
if isinstance(conditioning, dict):
|
||||||
ctmp = conditioning[list(conditioning.keys())[0]]
|
ctmp = conditioning[list(conditioning.keys())[0]]
|
||||||
while isinstance(ctmp, list): ctmp = ctmp[0]
|
while isinstance(ctmp, list):
|
||||||
|
ctmp = ctmp[0]
|
||||||
cbs = ctmp.shape[0]
|
cbs = ctmp.shape[0]
|
||||||
if cbs != batch_size:
|
if cbs != batch_size:
|
||||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
|
||||||
import math
|
import math
|
||||||
from tqdm.auto import trange
|
import tqdm
|
||||||
|
|
||||||
|
|
||||||
class NoiseScheduleVP:
|
class NoiseScheduleVP:
|
||||||
@ -94,7 +93,7 @@ class NoiseScheduleVP:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
if schedule not in ['discrete', 'linear', 'cosine']:
|
if schedule not in ['discrete', 'linear', 'cosine']:
|
||||||
raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule))
|
raise ValueError(f"Unsupported noise schedule {schedule}. The schedule needs to be 'discrete' or 'linear' or 'cosine'")
|
||||||
|
|
||||||
self.schedule = schedule
|
self.schedule = schedule
|
||||||
if schedule == 'discrete':
|
if schedule == 'discrete':
|
||||||
@ -179,13 +178,13 @@ def model_wrapper(
|
|||||||
model,
|
model,
|
||||||
noise_schedule,
|
noise_schedule,
|
||||||
model_type="noise",
|
model_type="noise",
|
||||||
model_kwargs={},
|
model_kwargs=None,
|
||||||
guidance_type="uncond",
|
guidance_type="uncond",
|
||||||
#condition=None,
|
#condition=None,
|
||||||
#unconditional_condition=None,
|
#unconditional_condition=None,
|
||||||
guidance_scale=1.,
|
guidance_scale=1.,
|
||||||
classifier_fn=None,
|
classifier_fn=None,
|
||||||
classifier_kwargs={},
|
classifier_kwargs=None,
|
||||||
):
|
):
|
||||||
"""Create a wrapper function for the noise prediction model.
|
"""Create a wrapper function for the noise prediction model.
|
||||||
|
|
||||||
@ -276,6 +275,9 @@ def model_wrapper(
|
|||||||
A noise prediction model that accepts the noised data and the continuous time as the inputs.
|
A noise prediction model that accepts the noised data and the continuous time as the inputs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
model_kwargs = model_kwargs or {}
|
||||||
|
classifier_kwargs = classifier_kwargs or {}
|
||||||
|
|
||||||
def get_model_input_time(t_continuous):
|
def get_model_input_time(t_continuous):
|
||||||
"""
|
"""
|
||||||
Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
|
Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
|
||||||
@ -342,7 +344,7 @@ def model_wrapper(
|
|||||||
t_in = torch.cat([t_continuous] * 2)
|
t_in = torch.cat([t_continuous] * 2)
|
||||||
if isinstance(condition, dict):
|
if isinstance(condition, dict):
|
||||||
assert isinstance(unconditional_condition, dict)
|
assert isinstance(unconditional_condition, dict)
|
||||||
c_in = dict()
|
c_in = {}
|
||||||
for k in condition:
|
for k in condition:
|
||||||
if isinstance(condition[k], list):
|
if isinstance(condition[k], list):
|
||||||
c_in[k] = [torch.cat([
|
c_in[k] = [torch.cat([
|
||||||
@ -353,7 +355,7 @@ def model_wrapper(
|
|||||||
unconditional_condition[k],
|
unconditional_condition[k],
|
||||||
condition[k]])
|
condition[k]])
|
||||||
elif isinstance(condition, list):
|
elif isinstance(condition, list):
|
||||||
c_in = list()
|
c_in = []
|
||||||
assert isinstance(unconditional_condition, list)
|
assert isinstance(unconditional_condition, list)
|
||||||
for i in range(len(condition)):
|
for i in range(len(condition)):
|
||||||
c_in.append(torch.cat([unconditional_condition[i], condition[i]]))
|
c_in.append(torch.cat([unconditional_condition[i], condition[i]]))
|
||||||
@ -469,7 +471,7 @@ class UniPC:
|
|||||||
t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
|
t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
|
||||||
return t
|
return t
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
|
raise ValueError(f"Unsupported skip_type {skip_type}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'")
|
||||||
|
|
||||||
def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
|
def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
|
||||||
"""
|
"""
|
||||||
@ -757,40 +759,44 @@ class UniPC:
|
|||||||
vec_t = timesteps[0].expand((x.shape[0]))
|
vec_t = timesteps[0].expand((x.shape[0]))
|
||||||
model_prev_list = [self.model_fn(x, vec_t)]
|
model_prev_list = [self.model_fn(x, vec_t)]
|
||||||
t_prev_list = [vec_t]
|
t_prev_list = [vec_t]
|
||||||
# Init the first `order` values by lower order multistep DPM-Solver.
|
with tqdm.tqdm(total=steps) as pbar:
|
||||||
for init_order in range(1, order):
|
# Init the first `order` values by lower order multistep DPM-Solver.
|
||||||
vec_t = timesteps[init_order].expand(x.shape[0])
|
for init_order in range(1, order):
|
||||||
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
|
vec_t = timesteps[init_order].expand(x.shape[0])
|
||||||
if model_x is None:
|
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
|
||||||
model_x = self.model_fn(x, vec_t)
|
|
||||||
if self.after_update is not None:
|
|
||||||
self.after_update(x, model_x)
|
|
||||||
model_prev_list.append(model_x)
|
|
||||||
t_prev_list.append(vec_t)
|
|
||||||
for step in trange(order, steps + 1):
|
|
||||||
vec_t = timesteps[step].expand(x.shape[0])
|
|
||||||
if lower_order_final:
|
|
||||||
step_order = min(order, steps + 1 - step)
|
|
||||||
else:
|
|
||||||
step_order = order
|
|
||||||
#print('this step order:', step_order)
|
|
||||||
if step == steps:
|
|
||||||
#print('do not run corrector at the last step')
|
|
||||||
use_corrector = False
|
|
||||||
else:
|
|
||||||
use_corrector = True
|
|
||||||
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
|
|
||||||
if self.after_update is not None:
|
|
||||||
self.after_update(x, model_x)
|
|
||||||
for i in range(order - 1):
|
|
||||||
t_prev_list[i] = t_prev_list[i + 1]
|
|
||||||
model_prev_list[i] = model_prev_list[i + 1]
|
|
||||||
t_prev_list[-1] = vec_t
|
|
||||||
# We do not need to evaluate the final model value.
|
|
||||||
if step < steps:
|
|
||||||
if model_x is None:
|
if model_x is None:
|
||||||
model_x = self.model_fn(x, vec_t)
|
model_x = self.model_fn(x, vec_t)
|
||||||
model_prev_list[-1] = model_x
|
if self.after_update is not None:
|
||||||
|
self.after_update(x, model_x)
|
||||||
|
model_prev_list.append(model_x)
|
||||||
|
t_prev_list.append(vec_t)
|
||||||
|
pbar.update()
|
||||||
|
|
||||||
|
for step in range(order, steps + 1):
|
||||||
|
vec_t = timesteps[step].expand(x.shape[0])
|
||||||
|
if lower_order_final:
|
||||||
|
step_order = min(order, steps + 1 - step)
|
||||||
|
else:
|
||||||
|
step_order = order
|
||||||
|
#print('this step order:', step_order)
|
||||||
|
if step == steps:
|
||||||
|
#print('do not run corrector at the last step')
|
||||||
|
use_corrector = False
|
||||||
|
else:
|
||||||
|
use_corrector = True
|
||||||
|
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
|
||||||
|
if self.after_update is not None:
|
||||||
|
self.after_update(x, model_x)
|
||||||
|
for i in range(order - 1):
|
||||||
|
t_prev_list[i] = t_prev_list[i + 1]
|
||||||
|
model_prev_list[i] = model_prev_list[i + 1]
|
||||||
|
t_prev_list[-1] = vec_t
|
||||||
|
# We do not need to evaluate the final model value.
|
||||||
|
if step < steps:
|
||||||
|
if model_x is None:
|
||||||
|
model_x = self.model_fn(x, vec_t)
|
||||||
|
model_prev_list[-1] = model_x
|
||||||
|
pbar.update()
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
if denoise_to_zero:
|
if denoise_to_zero:
|
||||||
|
@ -7,13 +7,13 @@ def connect(token, port, region):
|
|||||||
else:
|
else:
|
||||||
if ':' in token:
|
if ':' in token:
|
||||||
# token = authtoken:username:password
|
# token = authtoken:username:password
|
||||||
account = token.split(':')[1] + ':' + token.split(':')[-1]
|
token, username, password = token.split(':', 2)
|
||||||
token = token.split(':')[0]
|
account = f"{username}:{password}"
|
||||||
|
|
||||||
config = conf.PyngrokConfig(
|
config = conf.PyngrokConfig(
|
||||||
auth_token=token, region=region
|
auth_token=token, region=region
|
||||||
)
|
)
|
||||||
|
|
||||||
# Guard for existing tunnels
|
# Guard for existing tunnels
|
||||||
existing = ngrok.get_tunnels(pyngrok_config=config)
|
existing = ngrok.get_tunnels(pyngrok_config=config)
|
||||||
if existing:
|
if existing:
|
||||||
@ -24,7 +24,7 @@ def connect(token, port, region):
|
|||||||
print(f'ngrok has already been connected to localhost:{port}! URL: {public_url}\n'
|
print(f'ngrok has already been connected to localhost:{port}! URL: {public_url}\n'
|
||||||
'You can use this link after the launch is complete.')
|
'You can use this link after the launch is complete.')
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if account is None:
|
if account is None:
|
||||||
public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True).public_url
|
public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True).public_url
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir
|
from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir # noqa: F401
|
||||||
|
|
||||||
import modules.safe
|
import modules.safe # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
# data_path = cmd_opts_pre.data
|
# data_path = cmd_opts_pre.data
|
||||||
@ -16,7 +16,7 @@ for possible_sd_path in possible_sd_paths:
|
|||||||
sd_path = os.path.abspath(possible_sd_path)
|
sd_path = os.path.abspath(possible_sd_path)
|
||||||
break
|
break
|
||||||
|
|
||||||
assert sd_path is not None, "Couldn't find Stable Diffusion in any of: " + str(possible_sd_paths)
|
assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possible_sd_paths}"
|
||||||
|
|
||||||
path_dirs = [
|
path_dirs = [
|
||||||
(sd_path, 'ldm', 'Stable Diffusion', []),
|
(sd_path, 'ldm', 'Stable Diffusion', []),
|
||||||
|
@ -2,8 +2,14 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
|
import shlex
|
||||||
|
|
||||||
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
|
||||||
|
sys.argv += shlex.split(commandline_args)
|
||||||
|
|
||||||
|
modules_path = os.path.dirname(os.path.realpath(__file__))
|
||||||
|
script_path = os.path.dirname(modules_path)
|
||||||
|
|
||||||
sd_configs_path = os.path.join(script_path, "configs")
|
sd_configs_path = os.path.join(script_path, "configs")
|
||||||
sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")
|
sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")
|
||||||
@ -12,7 +18,7 @@ default_sd_model_file = sd_model_file
|
|||||||
|
|
||||||
# Parse the --data-dir flag first so we can use it as a base for our other argument default values
|
# Parse the --data-dir flag first so we can use it as a base for our other argument default values
|
||||||
parser_pre = argparse.ArgumentParser(add_help=False)
|
parser_pre = argparse.ArgumentParser(add_help=False)
|
||||||
parser_pre.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
|
parser_pre.add_argument("--data-dir", type=str, default=os.path.dirname(modules_path), help="base path where all user data is stored", )
|
||||||
cmd_opts_pre = parser_pre.parse_known_args()[0]
|
cmd_opts_pre = parser_pre.parse_known_args()[0]
|
||||||
|
|
||||||
data_path = cmd_opts_pre.data_dir
|
data_path = cmd_opts_pre.data_dir
|
||||||
@ -21,3 +27,5 @@ models_path = os.path.join(data_path, "models")
|
|||||||
extensions_dir = os.path.join(data_path, "extensions")
|
extensions_dir = os.path.join(data_path, "extensions")
|
||||||
extensions_builtin_dir = os.path.join(script_path, "extensions-builtin")
|
extensions_builtin_dir = os.path.join(script_path, "extensions-builtin")
|
||||||
config_states_dir = os.path.join(script_path, "config_states")
|
config_states_dir = os.path.join(script_path, "config_states")
|
||||||
|
|
||||||
|
roboto_ttf_file = os.path.join(modules_path, 'Roboto-Regular.ttf')
|
||||||
|
@ -2,7 +2,6 @@ import json
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import warnings
|
|
||||||
import hashlib
|
import hashlib
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -11,10 +10,10 @@ from PIL import Image, ImageFilter, ImageOps
|
|||||||
import random
|
import random
|
||||||
import cv2
|
import cv2
|
||||||
from skimage import exposure
|
from skimage import exposure
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
import modules.sd_hijack
|
import modules.sd_hijack
|
||||||
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx, scripts
|
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts
|
||||||
from modules.sd_hijack import model_hijack
|
from modules.sd_hijack import model_hijack
|
||||||
from modules.shared import opts, cmd_opts, state
|
from modules.shared import opts, cmd_opts, state
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
@ -30,6 +29,13 @@ from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
|
|||||||
|
|
||||||
from einops import repeat, rearrange
|
from einops import repeat, rearrange
|
||||||
from blendmodes.blend import blendLayers, BlendType
|
from blendmodes.blend import blendLayers, BlendType
|
||||||
|
import tomesd
|
||||||
|
|
||||||
|
# add a logger for the processing module
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
# manually set output level here since there is no option to do so yet through launch options
|
||||||
|
# logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(name)s %(message)s')
|
||||||
|
|
||||||
|
|
||||||
# some of those options should not be changed at all because they would break the model, so I removed them from options.
|
# some of those options should not be changed at all because they would break the model, so I removed them from options.
|
||||||
opt_C = 4
|
opt_C = 4
|
||||||
@ -165,7 +171,7 @@ class StableDiffusionProcessing:
|
|||||||
self.all_subseeds = None
|
self.all_subseeds = None
|
||||||
self.iteration = 0
|
self.iteration = 0
|
||||||
self.is_hr_pass = False
|
self.is_hr_pass = False
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def sd_model(self):
|
def sd_model(self):
|
||||||
@ -458,10 +464,21 @@ def fix_seed(p):
|
|||||||
p.subseed = get_fixed_seed(p.subseed)
|
p.subseed = get_fixed_seed(p.subseed)
|
||||||
|
|
||||||
|
|
||||||
|
def program_version():
|
||||||
|
import launch
|
||||||
|
|
||||||
|
res = launch.git_tag()
|
||||||
|
if res == "<none>":
|
||||||
|
res = None
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0):
|
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0):
|
||||||
index = position_in_batch + iteration * p.batch_size
|
index = position_in_batch + iteration * p.batch_size
|
||||||
|
|
||||||
clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
|
clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
|
||||||
|
enable_hr = getattr(p, 'enable_hr', False)
|
||||||
|
|
||||||
generation_params = {
|
generation_params = {
|
||||||
"Steps": p.steps,
|
"Steps": p.steps,
|
||||||
@ -480,16 +497,19 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
|||||||
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
|
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
|
||||||
"Clip skip": None if clip_skip <= 1 else clip_skip,
|
"Clip skip": None if clip_skip <= 1 else clip_skip,
|
||||||
"ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
|
"ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
|
||||||
|
"Token merging ratio": None if opts.token_merging_ratio == 0 else opts.token_merging_ratio,
|
||||||
|
"Token merging ratio hr": None if not enable_hr or opts.token_merging_ratio_hr == 0 else opts.token_merging_ratio_hr,
|
||||||
"Init image hash": getattr(p, 'init_img_hash', None),
|
"Init image hash": getattr(p, 'init_img_hash', None),
|
||||||
"RNG": opts.randn_source if opts.randn_source != "GPU" else None,
|
"RNG": opts.randn_source if opts.randn_source != "GPU" else None,
|
||||||
"NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
|
"NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
|
||||||
|
"Version": program_version() if opts.add_version_to_infotext else None,
|
||||||
}
|
}
|
||||||
|
|
||||||
generation_params.update(p.extra_generation_params)
|
generation_params.update(p.extra_generation_params)
|
||||||
|
|
||||||
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
|
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
|
||||||
|
|
||||||
negative_prompt_text = "\nNegative prompt: " + p.all_negative_prompts[index] if p.all_negative_prompts[index] else ""
|
negative_prompt_text = f"\nNegative prompt: {p.all_negative_prompts[index]}" if p.all_negative_prompts[index] else ""
|
||||||
|
|
||||||
return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
|
return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
|
||||||
|
|
||||||
@ -512,9 +532,18 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||||||
if k == 'sd_vae':
|
if k == 'sd_vae':
|
||||||
sd_vae.reload_vae_weights()
|
sd_vae.reload_vae_weights()
|
||||||
|
|
||||||
|
if opts.token_merging_ratio > 0:
|
||||||
|
sd_models.apply_token_merging(sd_model=p.sd_model, hr=False)
|
||||||
|
logger.debug(f"Token merging applied to first pass. Ratio: '{opts.token_merging_ratio}'")
|
||||||
|
|
||||||
res = process_images_inner(p)
|
res = process_images_inner(p)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
|
# undo model optimizations made by tomesd
|
||||||
|
if opts.token_merging_ratio > 0:
|
||||||
|
tomesd.remove_patch(p.sd_model)
|
||||||
|
logger.debug('Token merging model optimizations removed')
|
||||||
|
|
||||||
# restore opts to original state
|
# restore opts to original state
|
||||||
if p.override_settings_restore_afterwards:
|
if p.override_settings_restore_afterwards:
|
||||||
for k, v in stored_opts.items():
|
for k, v in stored_opts.items():
|
||||||
@ -653,7 +682,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
if not shared.opts.dont_fix_second_order_samplers_schedule:
|
if not shared.opts.dont_fix_second_order_samplers_schedule:
|
||||||
try:
|
try:
|
||||||
step_multiplier = 2 if sd_samplers.all_samplers_map.get(p.sampler_name).aliases[0] in ['k_dpmpp_2s_a', 'k_dpmpp_2s_a_ka', 'k_dpmpp_sde', 'k_dpmpp_sde_ka', 'k_dpm_2', 'k_dpm_2_a', 'k_heun'] else 1
|
step_multiplier = 2 if sd_samplers.all_samplers_map.get(p.sampler_name).aliases[0] in ['k_dpmpp_2s_a', 'k_dpmpp_2s_a_ka', 'k_dpmpp_sde', 'k_dpmpp_sde_ka', 'k_dpm_2', 'k_dpm_2_a', 'k_heun'] else 1
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps * step_multiplier, cached_uc)
|
uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps * step_multiplier, cached_uc)
|
||||||
c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps * step_multiplier, cached_c)
|
c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps * step_multiplier, cached_c)
|
||||||
@ -769,7 +798,16 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
|
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)
|
res = Processed(
|
||||||
|
p,
|
||||||
|
images_list=output_images,
|
||||||
|
seed=p.all_seeds[0],
|
||||||
|
info=infotext(),
|
||||||
|
comments="".join(f"\n\n{comment}" for comment in comments),
|
||||||
|
subseed=p.all_subseeds[0],
|
||||||
|
index_of_first_image=index_of_first_image,
|
||||||
|
infotexts=infotexts,
|
||||||
|
)
|
||||||
|
|
||||||
if p.scripts is not None:
|
if p.scripts is not None:
|
||||||
p.scripts.postprocess(p, res)
|
p.scripts.postprocess(p, res)
|
||||||
@ -958,8 +996,22 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
x = None
|
x = None
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
|
# apply token merging optimizations from tomesd for high-res pass
|
||||||
|
if opts.token_merging_ratio_hr > 0:
|
||||||
|
# in case the user has used separate merge ratios
|
||||||
|
if opts.token_merging_ratio > 0:
|
||||||
|
tomesd.remove_patch(self.sd_model)
|
||||||
|
logger.debug('Adjusting token merging ratio for high-res pass')
|
||||||
|
|
||||||
|
sd_models.apply_token_merging(sd_model=self.sd_model, hr=True)
|
||||||
|
logger.debug(f"Applied token merging for high-res pass. Ratio: '{opts.token_merging_ratio_hr}'")
|
||||||
|
|
||||||
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
|
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
|
||||||
|
|
||||||
|
if opts.token_merging_ratio_hr > 0 or opts.token_merging_ratio > 0:
|
||||||
|
tomesd.remove_patch(self.sd_model)
|
||||||
|
logger.debug('Removed token merging optimizations from model')
|
||||||
|
|
||||||
self.is_hr_pass = False
|
self.is_hr_pass = False
|
||||||
|
|
||||||
return samples
|
return samples
|
||||||
|
@ -95,8 +95,16 @@ def progressapi(req: ProgressRequest):
|
|||||||
image = shared.state.current_image
|
image = shared.state.current_image
|
||||||
if image is not None:
|
if image is not None:
|
||||||
buffered = io.BytesIO()
|
buffered = io.BytesIO()
|
||||||
image.save(buffered, format="png")
|
|
||||||
live_preview = 'data:image/png;base64,' + base64.b64encode(buffered.getvalue()).decode("ascii")
|
if opts.live_previews_image_format == "png":
|
||||||
|
# using optimize for large images takes an enormous amount of time
|
||||||
|
save_kwargs = {"optimize": max(*image.size) > 256}
|
||||||
|
else:
|
||||||
|
save_kwargs = {}
|
||||||
|
|
||||||
|
image.save(buffered, format=opts.live_previews_image_format, **save_kwargs)
|
||||||
|
base64_image = base64.b64encode(buffered.getvalue()).decode('ascii')
|
||||||
|
live_preview = f"data:image/{opts.live_previews_image_format};base64,{base64_image}"
|
||||||
id_live_preview = shared.state.id_live_preview
|
id_live_preview = shared.state.id_live_preview
|
||||||
else:
|
else:
|
||||||
live_preview = None
|
live_preview = None
|
||||||
|
@ -54,18 +54,21 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def collect_steps(steps, tree):
|
def collect_steps(steps, tree):
|
||||||
l = [steps]
|
res = [steps]
|
||||||
|
|
||||||
class CollectSteps(lark.Visitor):
|
class CollectSteps(lark.Visitor):
|
||||||
def scheduled(self, tree):
|
def scheduled(self, tree):
|
||||||
tree.children[-1] = float(tree.children[-1])
|
tree.children[-1] = float(tree.children[-1])
|
||||||
if tree.children[-1] < 1:
|
if tree.children[-1] < 1:
|
||||||
tree.children[-1] *= steps
|
tree.children[-1] *= steps
|
||||||
tree.children[-1] = min(steps, int(tree.children[-1]))
|
tree.children[-1] = min(steps, int(tree.children[-1]))
|
||||||
l.append(tree.children[-1])
|
res.append(tree.children[-1])
|
||||||
|
|
||||||
def alternate(self, tree):
|
def alternate(self, tree):
|
||||||
l.extend(range(1, steps+1))
|
res.extend(range(1, steps+1))
|
||||||
|
|
||||||
CollectSteps().visit(tree)
|
CollectSteps().visit(tree)
|
||||||
return sorted(set(l))
|
return sorted(set(res))
|
||||||
|
|
||||||
def at_step(step, tree):
|
def at_step(step, tree):
|
||||||
class AtStep(lark.Transformer):
|
class AtStep(lark.Transformer):
|
||||||
@ -92,7 +95,7 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
|
|||||||
def get_schedule(prompt):
|
def get_schedule(prompt):
|
||||||
try:
|
try:
|
||||||
tree = schedule_parser.parse(prompt)
|
tree = schedule_parser.parse(prompt)
|
||||||
except lark.exceptions.LarkError as e:
|
except lark.exceptions.LarkError:
|
||||||
if 0:
|
if 0:
|
||||||
import traceback
|
import traceback
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
@ -140,7 +143,7 @@ def get_learned_conditioning(model, prompts, steps):
|
|||||||
conds = model.get_learned_conditioning(texts)
|
conds = model.get_learned_conditioning(texts)
|
||||||
|
|
||||||
cond_schedule = []
|
cond_schedule = []
|
||||||
for i, (end_at_step, text) in enumerate(prompt_schedule):
|
for i, (end_at_step, _) in enumerate(prompt_schedule):
|
||||||
cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i]))
|
cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i]))
|
||||||
|
|
||||||
cache[prompt] = cond_schedule
|
cache[prompt] = cond_schedule
|
||||||
@ -216,8 +219,8 @@ def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_s
|
|||||||
res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
|
res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
|
||||||
for i, cond_schedule in enumerate(c):
|
for i, cond_schedule in enumerate(c):
|
||||||
target_index = 0
|
target_index = 0
|
||||||
for current, (end_at, cond) in enumerate(cond_schedule):
|
for current, entry in enumerate(cond_schedule):
|
||||||
if current_step <= end_at:
|
if current_step <= entry.end_at_step:
|
||||||
target_index = current
|
target_index = current
|
||||||
break
|
break
|
||||||
res[i] = cond_schedule[target_index].cond
|
res[i] = cond_schedule[target_index].cond
|
||||||
@ -231,13 +234,13 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
|
|||||||
tensors = []
|
tensors = []
|
||||||
conds_list = []
|
conds_list = []
|
||||||
|
|
||||||
for batch_no, composable_prompts in enumerate(c.batch):
|
for composable_prompts in c.batch:
|
||||||
conds_for_batch = []
|
conds_for_batch = []
|
||||||
|
|
||||||
for cond_index, composable_prompt in enumerate(composable_prompts):
|
for composable_prompt in composable_prompts:
|
||||||
target_index = 0
|
target_index = 0
|
||||||
for current, (end_at, cond) in enumerate(composable_prompt.schedules):
|
for current, entry in enumerate(composable_prompt.schedules):
|
||||||
if current_step <= end_at:
|
if current_step <= entry.end_at_step:
|
||||||
target_index = current
|
target_index = current
|
||||||
break
|
break
|
||||||
|
|
||||||
|
@ -17,9 +17,9 @@ class UpscalerRealESRGAN(Upscaler):
|
|||||||
self.user_path = path
|
self.user_path = path
|
||||||
super().__init__()
|
super().__init__()
|
||||||
try:
|
try:
|
||||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
from basicsr.archs.rrdbnet_arch import RRDBNet # noqa: F401
|
||||||
from realesrgan import RealESRGANer
|
from realesrgan import RealESRGANer # noqa: F401
|
||||||
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
from realesrgan.archs.srvgg_arch import SRVGGNetCompact # noqa: F401
|
||||||
self.enable = True
|
self.enable = True
|
||||||
self.scalers = []
|
self.scalers = []
|
||||||
scalers = self.load_models(path)
|
scalers = self.load_models(path)
|
||||||
@ -28,9 +28,9 @@ class UpscalerRealESRGAN(Upscaler):
|
|||||||
for scaler in scalers:
|
for scaler in scalers:
|
||||||
if scaler.local_data_path.startswith("http"):
|
if scaler.local_data_path.startswith("http"):
|
||||||
filename = modelloader.friendly_name(scaler.local_data_path)
|
filename = modelloader.friendly_name(scaler.local_data_path)
|
||||||
local = next(iter([local_model for local_model in local_model_paths if local_model.endswith(filename + '.pth')]), None)
|
local_model_candidates = [local_model for local_model in local_model_paths if local_model.endswith(f"{filename}.pth")]
|
||||||
if local:
|
if local_model_candidates:
|
||||||
scaler.local_data_path = local
|
scaler.local_data_path = local_model_candidates[0]
|
||||||
|
|
||||||
if scaler.name in opts.realesrgan_enabled_models:
|
if scaler.name in opts.realesrgan_enabled_models:
|
||||||
self.scalers.append(scaler)
|
self.scalers.append(scaler)
|
||||||
@ -47,7 +47,7 @@ class UpscalerRealESRGAN(Upscaler):
|
|||||||
|
|
||||||
info = self.load_model(path)
|
info = self.load_model(path)
|
||||||
if not os.path.exists(info.local_data_path):
|
if not os.path.exists(info.local_data_path):
|
||||||
print("Unable to load RealESRGAN model: %s" % info.name)
|
print(f"Unable to load RealESRGAN model: {info.name}")
|
||||||
return img
|
return img
|
||||||
|
|
||||||
upsampler = RealESRGANer(
|
upsampler = RealESRGANer(
|
||||||
@ -134,6 +134,6 @@ def get_realesrgan_models(scaler):
|
|||||||
),
|
),
|
||||||
]
|
]
|
||||||
return models
|
return models
|
||||||
except Exception as e:
|
except Exception:
|
||||||
print("Error making Real-ESRGAN models list:", file=sys.stderr)
|
print("Error making Real-ESRGAN models list:", file=sys.stderr)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
@ -40,7 +40,7 @@ class RestrictedUnpickler(pickle.Unpickler):
|
|||||||
return getattr(collections, name)
|
return getattr(collections, name)
|
||||||
if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
|
if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
|
||||||
return getattr(torch._utils, name)
|
return getattr(torch._utils, name)
|
||||||
if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32']:
|
if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32', 'BFloat16Storage']:
|
||||||
return getattr(torch, name)
|
return getattr(torch, name)
|
||||||
if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
|
if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
|
||||||
return getattr(torch.nn.modules.container, name)
|
return getattr(torch.nn.modules.container, name)
|
||||||
@ -95,16 +95,16 @@ def check_pt(filename, extra_handler):
|
|||||||
|
|
||||||
except zipfile.BadZipfile:
|
except zipfile.BadZipfile:
|
||||||
|
|
||||||
# if it's not a zip file, it's an olf pytorch format, with five objects written to pickle
|
# if it's not a zip file, it's an old pytorch format, with five objects written to pickle
|
||||||
with open(filename, "rb") as file:
|
with open(filename, "rb") as file:
|
||||||
unpickler = RestrictedUnpickler(file)
|
unpickler = RestrictedUnpickler(file)
|
||||||
unpickler.extra_handler = extra_handler
|
unpickler.extra_handler = extra_handler
|
||||||
for i in range(5):
|
for _ in range(5):
|
||||||
unpickler.load()
|
unpickler.load()
|
||||||
|
|
||||||
|
|
||||||
def load(filename, *args, **kwargs):
|
def load(filename, *args, **kwargs):
|
||||||
return load_with_extra(filename, extra_handler=global_extra_handler, *args, **kwargs)
|
return load_with_extra(filename, *args, extra_handler=global_extra_handler, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def load_with_extra(filename, extra_handler=None, *args, **kwargs):
|
def load_with_extra(filename, extra_handler=None, *args, **kwargs):
|
||||||
|
@ -32,27 +32,42 @@ class CFGDenoiserParams:
|
|||||||
def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond):
|
def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond):
|
||||||
self.x = x
|
self.x = x
|
||||||
"""Latent image representation in the process of being denoised"""
|
"""Latent image representation in the process of being denoised"""
|
||||||
|
|
||||||
self.image_cond = image_cond
|
self.image_cond = image_cond
|
||||||
"""Conditioning image"""
|
"""Conditioning image"""
|
||||||
|
|
||||||
self.sigma = sigma
|
self.sigma = sigma
|
||||||
"""Current sigma noise step value"""
|
"""Current sigma noise step value"""
|
||||||
|
|
||||||
self.sampling_step = sampling_step
|
self.sampling_step = sampling_step
|
||||||
"""Current Sampling step number"""
|
"""Current Sampling step number"""
|
||||||
|
|
||||||
self.total_sampling_steps = total_sampling_steps
|
self.total_sampling_steps = total_sampling_steps
|
||||||
"""Total number of sampling steps planned"""
|
"""Total number of sampling steps planned"""
|
||||||
|
|
||||||
self.text_cond = text_cond
|
self.text_cond = text_cond
|
||||||
""" Encoder hidden states of text conditioning from prompt"""
|
""" Encoder hidden states of text conditioning from prompt"""
|
||||||
|
|
||||||
self.text_uncond = text_uncond
|
self.text_uncond = text_uncond
|
||||||
""" Encoder hidden states of text conditioning from negative prompt"""
|
""" Encoder hidden states of text conditioning from negative prompt"""
|
||||||
|
|
||||||
|
|
||||||
class CFGDenoisedParams:
|
class CFGDenoisedParams:
|
||||||
|
def __init__(self, x, sampling_step, total_sampling_steps, inner_model):
|
||||||
|
self.x = x
|
||||||
|
"""Latent image representation in the process of being denoised"""
|
||||||
|
|
||||||
|
self.sampling_step = sampling_step
|
||||||
|
"""Current Sampling step number"""
|
||||||
|
|
||||||
|
self.total_sampling_steps = total_sampling_steps
|
||||||
|
"""Total number of sampling steps planned"""
|
||||||
|
|
||||||
|
self.inner_model = inner_model
|
||||||
|
"""Inner model reference used for denoising"""
|
||||||
|
|
||||||
|
|
||||||
|
class AfterCFGCallbackParams:
|
||||||
def __init__(self, x, sampling_step, total_sampling_steps):
|
def __init__(self, x, sampling_step, total_sampling_steps):
|
||||||
self.x = x
|
self.x = x
|
||||||
"""Latent image representation in the process of being denoised"""
|
"""Latent image representation in the process of being denoised"""
|
||||||
@ -87,6 +102,7 @@ callback_map = dict(
|
|||||||
callbacks_image_saved=[],
|
callbacks_image_saved=[],
|
||||||
callbacks_cfg_denoiser=[],
|
callbacks_cfg_denoiser=[],
|
||||||
callbacks_cfg_denoised=[],
|
callbacks_cfg_denoised=[],
|
||||||
|
callbacks_cfg_after_cfg=[],
|
||||||
callbacks_before_component=[],
|
callbacks_before_component=[],
|
||||||
callbacks_after_component=[],
|
callbacks_after_component=[],
|
||||||
callbacks_image_grid=[],
|
callbacks_image_grid=[],
|
||||||
@ -186,6 +202,14 @@ def cfg_denoised_callback(params: CFGDenoisedParams):
|
|||||||
report_exception(c, 'cfg_denoised_callback')
|
report_exception(c, 'cfg_denoised_callback')
|
||||||
|
|
||||||
|
|
||||||
|
def cfg_after_cfg_callback(params: AfterCFGCallbackParams):
|
||||||
|
for c in callback_map['callbacks_cfg_after_cfg']:
|
||||||
|
try:
|
||||||
|
c.callback(params)
|
||||||
|
except Exception:
|
||||||
|
report_exception(c, 'cfg_after_cfg_callback')
|
||||||
|
|
||||||
|
|
||||||
def before_component_callback(component, **kwargs):
|
def before_component_callback(component, **kwargs):
|
||||||
for c in callback_map['callbacks_before_component']:
|
for c in callback_map['callbacks_before_component']:
|
||||||
try:
|
try:
|
||||||
@ -240,7 +264,7 @@ def add_callback(callbacks, fun):
|
|||||||
|
|
||||||
callbacks.append(ScriptCallback(filename, fun))
|
callbacks.append(ScriptCallback(filename, fun))
|
||||||
|
|
||||||
|
|
||||||
def remove_current_script_callbacks():
|
def remove_current_script_callbacks():
|
||||||
stack = [x for x in inspect.stack() if x.filename != __file__]
|
stack = [x for x in inspect.stack() if x.filename != __file__]
|
||||||
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
|
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
|
||||||
@ -332,6 +356,14 @@ def on_cfg_denoised(callback):
|
|||||||
add_callback(callback_map['callbacks_cfg_denoised'], callback)
|
add_callback(callback_map['callbacks_cfg_denoised'], callback)
|
||||||
|
|
||||||
|
|
||||||
|
def on_cfg_after_cfg(callback):
|
||||||
|
"""register a function to be called in the kdiffussion cfg_denoiser method after cfg calculations are completed.
|
||||||
|
The callback is called with one argument:
|
||||||
|
- params: AfterCFGCallbackParams - parameters to be passed to the script for post-processing after cfg calculation.
|
||||||
|
"""
|
||||||
|
add_callback(callback_map['callbacks_cfg_after_cfg'], callback)
|
||||||
|
|
||||||
|
|
||||||
def on_before_component(callback):
|
def on_before_component(callback):
|
||||||
"""register a function to be called before a component is created.
|
"""register a function to be called before a component is created.
|
||||||
The callback is called with arguments:
|
The callback is called with arguments:
|
||||||
|
@ -2,7 +2,6 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
import importlib.util
|
import importlib.util
|
||||||
from types import ModuleType
|
|
||||||
|
|
||||||
|
|
||||||
def load_module(path):
|
def load_module(path):
|
||||||
|
@ -163,7 +163,8 @@ class Script:
|
|||||||
"""helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id"""
|
"""helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id"""
|
||||||
|
|
||||||
need_tabname = self.show(True) == self.show(False)
|
need_tabname = self.show(True) == self.show(False)
|
||||||
tabname = ('img2img' if self.is_img2img else 'txt2txt') + "_" if need_tabname else ""
|
tabkind = 'img2img' if self.is_img2img else 'txt2txt'
|
||||||
|
tabname = f"{tabkind}_" if need_tabname else ""
|
||||||
title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower()))
|
title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower()))
|
||||||
|
|
||||||
return f'script_{tabname}{title}_{item_id}'
|
return f'script_{tabname}{title}_{item_id}'
|
||||||
@ -230,7 +231,7 @@ def load_scripts():
|
|||||||
syspath = sys.path
|
syspath = sys.path
|
||||||
|
|
||||||
def register_scripts_from_module(module):
|
def register_scripts_from_module(module):
|
||||||
for key, script_class in module.__dict__.items():
|
for script_class in module.__dict__.values():
|
||||||
if type(script_class) != type:
|
if type(script_class) != type:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -294,9 +295,9 @@ class ScriptRunner:
|
|||||||
|
|
||||||
auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()
|
auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()
|
||||||
|
|
||||||
for script_class, path, basedir, script_module in auto_processing_scripts + scripts_data:
|
for script_data in auto_processing_scripts + scripts_data:
|
||||||
script = script_class()
|
script = script_data.script_class()
|
||||||
script.filename = path
|
script.filename = script_data.path
|
||||||
script.is_txt2img = not is_img2img
|
script.is_txt2img = not is_img2img
|
||||||
script.is_img2img = is_img2img
|
script.is_img2img = is_img2img
|
||||||
|
|
||||||
@ -491,7 +492,7 @@ class ScriptRunner:
|
|||||||
module = script_loading.load_module(script.filename)
|
module = script_loading.load_module(script.filename)
|
||||||
cache[filename] = module
|
cache[filename] = module
|
||||||
|
|
||||||
for key, script_class in module.__dict__.items():
|
for script_class in module.__dict__.values():
|
||||||
if type(script_class) == type and issubclass(script_class, Script):
|
if type(script_class) == type and issubclass(script_class, Script):
|
||||||
self.scripts[si] = script_class()
|
self.scripts[si] = script_class()
|
||||||
self.scripts[si].filename = filename
|
self.scripts[si].filename = filename
|
||||||
@ -526,7 +527,7 @@ def add_classes_to_gradio_component(comp):
|
|||||||
this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others
|
this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others
|
||||||
"""
|
"""
|
||||||
|
|
||||||
comp.elem_classes = ["gradio-" + comp.get_block_name(), *(comp.elem_classes or [])]
|
comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])]
|
||||||
|
|
||||||
if getattr(comp, 'multiselect', False):
|
if getattr(comp, 'multiselect', False):
|
||||||
comp.elem_classes.append('multiselect')
|
comp.elem_classes.append('multiselect')
|
||||||
|
@ -17,7 +17,7 @@ class ScriptPostprocessingForMainUI(scripts.Script):
|
|||||||
return self.postprocessing_controls.values()
|
return self.postprocessing_controls.values()
|
||||||
|
|
||||||
def postprocess_image(self, p, script_pp, *args):
|
def postprocess_image(self, p, script_pp, *args):
|
||||||
args_dict = {k: v for k, v in zip(self.postprocessing_controls, args)}
|
args_dict = dict(zip(self.postprocessing_controls, args))
|
||||||
|
|
||||||
pp = scripts_postprocessing.PostprocessedImage(script_pp.image)
|
pp = scripts_postprocessing.PostprocessedImage(script_pp.image)
|
||||||
pp.info = {}
|
pp.info = {}
|
||||||
|
@ -66,9 +66,9 @@ class ScriptPostprocessingRunner:
|
|||||||
def initialize_scripts(self, scripts_data):
|
def initialize_scripts(self, scripts_data):
|
||||||
self.scripts = []
|
self.scripts = []
|
||||||
|
|
||||||
for script_class, path, basedir, script_module in scripts_data:
|
for script_data in scripts_data:
|
||||||
script: ScriptPostprocessing = script_class()
|
script: ScriptPostprocessing = script_data.script_class()
|
||||||
script.filename = path
|
script.filename = script_data.path
|
||||||
|
|
||||||
if script.name == "Simple Upscale":
|
if script.name == "Simple Upscale":
|
||||||
continue
|
continue
|
||||||
@ -124,7 +124,7 @@ class ScriptPostprocessingRunner:
|
|||||||
script_args = args[script.args_from:script.args_to]
|
script_args = args[script.args_from:script.args_to]
|
||||||
|
|
||||||
process_args = {}
|
process_args = {}
|
||||||
for (name, component), value in zip(script.controls.items(), script_args):
|
for (name, _component), value in zip(script.controls.items(), script_args):
|
||||||
process_args[name] = value
|
process_args[name] = value
|
||||||
|
|
||||||
script.process(pp, **process_args)
|
script.process(pp, **process_args)
|
||||||
|
@ -61,7 +61,7 @@ class DisableInitialization:
|
|||||||
if res is None:
|
if res is None:
|
||||||
res = original(url, *args, local_files_only=False, **kwargs)
|
res = original(url, *args, local_files_only=False, **kwargs)
|
||||||
return res
|
return res
|
||||||
except Exception as e:
|
except Exception:
|
||||||
return original(url, *args, local_files_only=False, **kwargs)
|
return original(url, *args, local_files_only=False, **kwargs)
|
||||||
|
|
||||||
def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs):
|
def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs):
|
||||||
|
@ -3,7 +3,7 @@ from torch.nn.functional import silu
|
|||||||
from types import MethodType
|
from types import MethodType
|
||||||
|
|
||||||
import modules.textual_inversion.textual_inversion
|
import modules.textual_inversion.textual_inversion
|
||||||
from modules import devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
|
from modules import devices, sd_hijack_optimizations, shared
|
||||||
from modules.hypernetworks import hypernetwork
|
from modules.hypernetworks import hypernetwork
|
||||||
from modules.shared import cmd_opts
|
from modules.shared import cmd_opts
|
||||||
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
|
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
|
||||||
@ -34,10 +34,10 @@ def apply_optimizations():
|
|||||||
|
|
||||||
ldm.modules.diffusionmodules.model.nonlinearity = silu
|
ldm.modules.diffusionmodules.model.nonlinearity = silu
|
||||||
ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
|
ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
|
||||||
|
|
||||||
optimization_method = None
|
optimization_method = None
|
||||||
|
|
||||||
can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(getattr(torch.nn.functional, "scaled_dot_product_attention")) # not everyone has torch 2.x to use sdp
|
can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention) # not everyone has torch 2.x to use sdp
|
||||||
|
|
||||||
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
|
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
|
||||||
print("Applying xformers cross attention optimization.")
|
print("Applying xformers cross attention optimization.")
|
||||||
@ -92,12 +92,12 @@ def fix_checkpoint():
|
|||||||
def weighted_loss(sd_model, pred, target, mean=True):
|
def weighted_loss(sd_model, pred, target, mean=True):
|
||||||
#Calculate the weight normally, but ignore the mean
|
#Calculate the weight normally, but ignore the mean
|
||||||
loss = sd_model._old_get_loss(pred, target, mean=False)
|
loss = sd_model._old_get_loss(pred, target, mean=False)
|
||||||
|
|
||||||
#Check if we have weights available
|
#Check if we have weights available
|
||||||
weight = getattr(sd_model, '_custom_loss_weight', None)
|
weight = getattr(sd_model, '_custom_loss_weight', None)
|
||||||
if weight is not None:
|
if weight is not None:
|
||||||
loss *= weight
|
loss *= weight
|
||||||
|
|
||||||
#Return the loss, as mean if specified
|
#Return the loss, as mean if specified
|
||||||
return loss.mean() if mean else loss
|
return loss.mean() if mean else loss
|
||||||
|
|
||||||
@ -105,7 +105,7 @@ def weighted_forward(sd_model, x, c, w, *args, **kwargs):
|
|||||||
try:
|
try:
|
||||||
#Temporarily append weights to a place accessible during loss calc
|
#Temporarily append weights to a place accessible during loss calc
|
||||||
sd_model._custom_loss_weight = w
|
sd_model._custom_loss_weight = w
|
||||||
|
|
||||||
#Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely
|
#Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely
|
||||||
#Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set
|
#Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set
|
||||||
if not hasattr(sd_model, '_old_get_loss'):
|
if not hasattr(sd_model, '_old_get_loss'):
|
||||||
@ -118,9 +118,9 @@ def weighted_forward(sd_model, x, c, w, *args, **kwargs):
|
|||||||
try:
|
try:
|
||||||
#Delete temporary weights if appended
|
#Delete temporary weights if appended
|
||||||
del sd_model._custom_loss_weight
|
del sd_model._custom_loss_weight
|
||||||
except AttributeError as e:
|
except AttributeError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
#If we have an old loss function, reset the loss function to the original one
|
#If we have an old loss function, reset the loss function to the original one
|
||||||
if hasattr(sd_model, '_old_get_loss'):
|
if hasattr(sd_model, '_old_get_loss'):
|
||||||
sd_model.get_loss = sd_model._old_get_loss
|
sd_model.get_loss = sd_model._old_get_loss
|
||||||
@ -133,7 +133,7 @@ def apply_weighted_forward(sd_model):
|
|||||||
def undo_weighted_forward(sd_model):
|
def undo_weighted_forward(sd_model):
|
||||||
try:
|
try:
|
||||||
del sd_model.weighted_forward
|
del sd_model.weighted_forward
|
||||||
except AttributeError as e:
|
except AttributeError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@ -184,7 +184,7 @@ class StableDiffusionModelHijack:
|
|||||||
|
|
||||||
def undo_hijack(self, m):
|
def undo_hijack(self, m):
|
||||||
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
|
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
|
||||||
m.cond_stage_model = m.cond_stage_model.wrapped
|
m.cond_stage_model = m.cond_stage_model.wrapped
|
||||||
|
|
||||||
elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
|
elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
|
||||||
m.cond_stage_model = m.cond_stage_model.wrapped
|
m.cond_stage_model = m.cond_stage_model.wrapped
|
||||||
|
@ -223,7 +223,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
|||||||
self.hijack.fixes = [x.fixes for x in batch_chunk]
|
self.hijack.fixes = [x.fixes for x in batch_chunk]
|
||||||
|
|
||||||
for fixes in self.hijack.fixes:
|
for fixes in self.hijack.fixes:
|
||||||
for position, embedding in fixes:
|
for _position, embedding in fixes:
|
||||||
used_embeddings[embedding.name] = embedding
|
used_embeddings[embedding.name] = embedding
|
||||||
|
|
||||||
z = self.process_tokens(tokens, multipliers)
|
z = self.process_tokens(tokens, multipliers)
|
||||||
|
@ -75,7 +75,8 @@ def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, text
|
|||||||
self.hijack.comments += hijack_comments
|
self.hijack.comments += hijack_comments
|
||||||
|
|
||||||
if len(used_custom_terms) > 0:
|
if len(used_custom_terms) > 0:
|
||||||
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
|
embedding_names = ", ".join(f"{word} [{checksum}]" for word, checksum in used_custom_terms)
|
||||||
|
self.hijack.comments.append(f"Used embeddings: {embedding_names}")
|
||||||
|
|
||||||
self.hijack.fixes = hijack_fixes
|
self.hijack.fixes = hijack_fixes
|
||||||
return self.process_tokens(remade_batch_tokens, batch_multipliers)
|
return self.process_tokens(remade_batch_tokens, batch_multipliers)
|
||||||
|
@ -1,16 +1,10 @@
|
|||||||
import os
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from einops import repeat
|
|
||||||
from omegaconf import ListConfig
|
|
||||||
|
|
||||||
import ldm.models.diffusion.ddpm
|
import ldm.models.diffusion.ddpm
|
||||||
import ldm.models.diffusion.ddim
|
import ldm.models.diffusion.ddim
|
||||||
import ldm.models.diffusion.plms
|
import ldm.models.diffusion.plms
|
||||||
|
|
||||||
from ldm.models.diffusion.ddpm import LatentDiffusion
|
from ldm.models.diffusion.ddim import noise_like
|
||||||
from ldm.models.diffusion.plms import PLMSSampler
|
|
||||||
from ldm.models.diffusion.ddim import DDIMSampler, noise_like
|
|
||||||
from ldm.models.diffusion.sampling_util import norm_thresholding
|
from ldm.models.diffusion.sampling_util import norm_thresholding
|
||||||
|
|
||||||
|
|
||||||
@ -29,7 +23,7 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
|
|||||||
|
|
||||||
if isinstance(c, dict):
|
if isinstance(c, dict):
|
||||||
assert isinstance(unconditional_conditioning, dict)
|
assert isinstance(unconditional_conditioning, dict)
|
||||||
c_in = dict()
|
c_in = {}
|
||||||
for k in c:
|
for k in c:
|
||||||
if isinstance(c[k], list):
|
if isinstance(c[k], list):
|
||||||
c_in[k] = [
|
c_in[k] = [
|
||||||
|
@ -1,8 +1,5 @@
|
|||||||
import collections
|
|
||||||
import os.path
|
import os.path
|
||||||
import sys
|
|
||||||
import gc
|
|
||||||
import time
|
|
||||||
|
|
||||||
def should_hijack_ip2p(checkpoint_info):
|
def should_hijack_ip2p(checkpoint_info):
|
||||||
from modules import sd_models_config
|
from modules import sd_models_config
|
||||||
@ -10,4 +7,4 @@ def should_hijack_ip2p(checkpoint_info):
|
|||||||
ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
|
ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
|
||||||
cfg_basename = os.path.basename(sd_models_config.find_checkpoint_config_near_filename(checkpoint_info)).lower()
|
cfg_basename = os.path.basename(sd_models_config.find_checkpoint_config_near_filename(checkpoint_info)).lower()
|
||||||
|
|
||||||
return "pix2pix" in ckpt_basename and not "pix2pix" in cfg_basename
|
return "pix2pix" in ckpt_basename and "pix2pix" not in cfg_basename
|
||||||
|
@ -49,7 +49,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
|
|||||||
v_in = self.to_v(context_v)
|
v_in = self.to_v(context_v)
|
||||||
del context, context_k, context_v, x
|
del context, context_k, context_v, x
|
||||||
|
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
|
q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
|
||||||
del q_in, k_in, v_in
|
del q_in, k_in, v_in
|
||||||
|
|
||||||
dtype = q.dtype
|
dtype = q.dtype
|
||||||
@ -62,10 +62,10 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
|
|||||||
end = i + 2
|
end = i + 2
|
||||||
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
|
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
|
||||||
s1 *= self.scale
|
s1 *= self.scale
|
||||||
|
|
||||||
s2 = s1.softmax(dim=-1)
|
s2 = s1.softmax(dim=-1)
|
||||||
del s1
|
del s1
|
||||||
|
|
||||||
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
|
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
|
||||||
del s2
|
del s2
|
||||||
del q, k, v
|
del q, k, v
|
||||||
@ -95,43 +95,43 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
|
|||||||
|
|
||||||
with devices.without_autocast(disable=not shared.opts.upcast_attn):
|
with devices.without_autocast(disable=not shared.opts.upcast_attn):
|
||||||
k_in = k_in * self.scale
|
k_in = k_in * self.scale
|
||||||
|
|
||||||
del context, x
|
del context, x
|
||||||
|
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
|
q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
|
||||||
del q_in, k_in, v_in
|
del q_in, k_in, v_in
|
||||||
|
|
||||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||||
|
|
||||||
mem_free_total = get_available_vram()
|
mem_free_total = get_available_vram()
|
||||||
|
|
||||||
gb = 1024 ** 3
|
gb = 1024 ** 3
|
||||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
|
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
|
||||||
modifier = 3 if q.element_size() == 2 else 2.5
|
modifier = 3 if q.element_size() == 2 else 2.5
|
||||||
mem_required = tensor_size * modifier
|
mem_required = tensor_size * modifier
|
||||||
steps = 1
|
steps = 1
|
||||||
|
|
||||||
if mem_required > mem_free_total:
|
if mem_required > mem_free_total:
|
||||||
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
|
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||||
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
|
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
|
||||||
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
|
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
|
||||||
|
|
||||||
if steps > 64:
|
if steps > 64:
|
||||||
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
|
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
|
||||||
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
|
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
|
||||||
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
|
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
|
||||||
|
|
||||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||||
for i in range(0, q.shape[1], slice_size):
|
for i in range(0, q.shape[1], slice_size):
|
||||||
end = i + slice_size
|
end = i + slice_size
|
||||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
|
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
|
||||||
|
|
||||||
s2 = s1.softmax(dim=-1, dtype=q.dtype)
|
s2 = s1.softmax(dim=-1, dtype=q.dtype)
|
||||||
del s1
|
del s1
|
||||||
|
|
||||||
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
||||||
del s2
|
del s2
|
||||||
|
|
||||||
del q, k, v
|
del q, k, v
|
||||||
|
|
||||||
r1 = r1.to(dtype)
|
r1 = r1.to(dtype)
|
||||||
@ -228,8 +228,8 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
|
|||||||
|
|
||||||
with devices.without_autocast(disable=not shared.opts.upcast_attn):
|
with devices.without_autocast(disable=not shared.opts.upcast_attn):
|
||||||
k = k * self.scale
|
k = k * self.scale
|
||||||
|
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
|
||||||
r = einsum_op(q, k, v)
|
r = einsum_op(q, k, v)
|
||||||
r = r.to(dtype)
|
r = r.to(dtype)
|
||||||
return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
|
return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
|
||||||
@ -256,6 +256,9 @@ def sub_quad_attention_forward(self, x, context=None, mask=None):
|
|||||||
k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
|
k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
|
||||||
v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
|
v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
|
||||||
|
|
||||||
|
if q.device.type == 'mps':
|
||||||
|
q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
|
||||||
|
|
||||||
dtype = q.dtype
|
dtype = q.dtype
|
||||||
if shared.opts.upcast_attn:
|
if shared.opts.upcast_attn:
|
||||||
q, k = q.float(), k.float()
|
q, k = q.float(), k.float()
|
||||||
@ -293,7 +296,6 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
|
|||||||
if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
|
if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
|
||||||
# the big matmul fits into our memory limit; do everything in 1 chunk,
|
# the big matmul fits into our memory limit; do everything in 1 chunk,
|
||||||
# i.e. send it down the unchunked fast-path
|
# i.e. send it down the unchunked fast-path
|
||||||
query_chunk_size = q_tokens
|
|
||||||
kv_chunk_size = k_tokens
|
kv_chunk_size = k_tokens
|
||||||
|
|
||||||
with devices.without_autocast(disable=q.dtype == v.dtype):
|
with devices.without_autocast(disable=q.dtype == v.dtype):
|
||||||
@ -332,7 +334,7 @@ def xformers_attention_forward(self, x, context=None, mask=None):
|
|||||||
k_in = self.to_k(context_k)
|
k_in = self.to_k(context_k)
|
||||||
v_in = self.to_v(context_v)
|
v_in = self.to_v(context_v)
|
||||||
|
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
|
q, k, v = (rearrange(t, 'b n (h d) -> b n h d', h=h) for t in (q_in, k_in, v_in))
|
||||||
del q_in, k_in, v_in
|
del q_in, k_in, v_in
|
||||||
|
|
||||||
dtype = q.dtype
|
dtype = q.dtype
|
||||||
@ -367,7 +369,7 @@ def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
|
|||||||
q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
|
q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
|
||||||
k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
|
k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
|
||||||
v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
|
v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
|
||||||
|
|
||||||
del q_in, k_in, v_in
|
del q_in, k_in, v_in
|
||||||
|
|
||||||
dtype = q.dtype
|
dtype = q.dtype
|
||||||
@ -449,7 +451,7 @@ def cross_attention_attnblock_forward(self, x):
|
|||||||
h3 += x
|
h3 += x
|
||||||
|
|
||||||
return h3
|
return h3
|
||||||
|
|
||||||
def xformers_attnblock_forward(self, x):
|
def xformers_attnblock_forward(self, x):
|
||||||
try:
|
try:
|
||||||
h_ = x
|
h_ = x
|
||||||
@ -458,7 +460,7 @@ def xformers_attnblock_forward(self, x):
|
|||||||
k = self.k(h_)
|
k = self.k(h_)
|
||||||
v = self.v(h_)
|
v = self.v(h_)
|
||||||
b, c, h, w = q.shape
|
b, c, h, w = q.shape
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
|
q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
|
||||||
dtype = q.dtype
|
dtype = q.dtype
|
||||||
if shared.opts.upcast_attn:
|
if shared.opts.upcast_attn:
|
||||||
q, k = q.float(), k.float()
|
q, k = q.float(), k.float()
|
||||||
@ -480,7 +482,7 @@ def sdp_attnblock_forward(self, x):
|
|||||||
k = self.k(h_)
|
k = self.k(h_)
|
||||||
v = self.v(h_)
|
v = self.v(h_)
|
||||||
b, c, h, w = q.shape
|
b, c, h, w = q.shape
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
|
q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
|
||||||
dtype = q.dtype
|
dtype = q.dtype
|
||||||
if shared.opts.upcast_attn:
|
if shared.opts.upcast_attn:
|
||||||
q, k = q.float(), k.float()
|
q, k = q.float(), k.float()
|
||||||
@ -504,7 +506,7 @@ def sub_quad_attnblock_forward(self, x):
|
|||||||
k = self.k(h_)
|
k = self.k(h_)
|
||||||
v = self.v(h_)
|
v = self.v(h_)
|
||||||
b, c, h, w = q.shape
|
b, c, h, w = q.shape
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
|
q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
|
||||||
q = q.contiguous()
|
q = q.contiguous()
|
||||||
k = k.contiguous()
|
k = k.contiguous()
|
||||||
v = v.contiguous()
|
v = v.contiguous()
|
||||||
|
@ -18,7 +18,7 @@ class TorchHijackForUnet:
|
|||||||
if hasattr(torch, item):
|
if hasattr(torch, item):
|
||||||
return getattr(torch, item)
|
return getattr(torch, item)
|
||||||
|
|
||||||
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
|
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
|
||||||
|
|
||||||
def cat(self, tensors, *args, **kwargs):
|
def cat(self, tensors, *args, **kwargs):
|
||||||
if len(tensors) == 2:
|
if len(tensors) == 2:
|
||||||
|
@ -1,8 +1,6 @@
|
|||||||
import open_clip.tokenizer
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from modules import sd_hijack_clip, devices
|
from modules import sd_hijack_clip, devices
|
||||||
from modules.shared import opts
|
|
||||||
|
|
||||||
|
|
||||||
class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords):
|
class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords):
|
||||||
|
@ -2,6 +2,8 @@ import collections
|
|||||||
import os.path
|
import os.path
|
||||||
import sys
|
import sys
|
||||||
import gc
|
import gc
|
||||||
|
import threading
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import re
|
import re
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
@ -13,9 +15,9 @@ import ldm.modules.midas as midas
|
|||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
|
|
||||||
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
|
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
|
||||||
from modules.paths import models_path
|
|
||||||
from modules.sd_hijack_inpainting import do_inpainting_hijack
|
from modules.sd_hijack_inpainting import do_inpainting_hijack
|
||||||
from modules.timer import Timer
|
from modules.timer import Timer
|
||||||
|
import tomesd
|
||||||
|
|
||||||
model_dir = "Stable-diffusion"
|
model_dir = "Stable-diffusion"
|
||||||
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
|
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
|
||||||
@ -45,7 +47,7 @@ class CheckpointInfo:
|
|||||||
self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
|
self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
|
||||||
self.hash = model_hash(filename)
|
self.hash = model_hash(filename)
|
||||||
|
|
||||||
self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + name)
|
self.sha256 = hashes.sha256_from_cache(self.filename, f"checkpoint/{name}")
|
||||||
self.shorthash = self.sha256[0:10] if self.sha256 else None
|
self.shorthash = self.sha256[0:10] if self.sha256 else None
|
||||||
|
|
||||||
self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
|
self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
|
||||||
@ -67,7 +69,7 @@ class CheckpointInfo:
|
|||||||
checkpoint_alisases[id] = self
|
checkpoint_alisases[id] = self
|
||||||
|
|
||||||
def calculate_shorthash(self):
|
def calculate_shorthash(self):
|
||||||
self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.name)
|
self.sha256 = hashes.sha256(self.filename, f"checkpoint/{self.name}")
|
||||||
if self.sha256 is None:
|
if self.sha256 is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -85,8 +87,7 @@ class CheckpointInfo:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
|
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
|
||||||
|
from transformers import logging, CLIPModel # noqa: F401
|
||||||
from transformers import logging, CLIPModel
|
|
||||||
|
|
||||||
logging.set_verbosity_error()
|
logging.set_verbosity_error()
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -165,7 +166,7 @@ def model_hash(filename):
|
|||||||
|
|
||||||
def select_checkpoint():
|
def select_checkpoint():
|
||||||
model_checkpoint = shared.opts.sd_model_checkpoint
|
model_checkpoint = shared.opts.sd_model_checkpoint
|
||||||
|
|
||||||
checkpoint_info = checkpoint_alisases.get(model_checkpoint, None)
|
checkpoint_info = checkpoint_alisases.get(model_checkpoint, None)
|
||||||
if checkpoint_info is not None:
|
if checkpoint_info is not None:
|
||||||
return checkpoint_info
|
return checkpoint_info
|
||||||
@ -237,7 +238,7 @@ def read_metadata_from_safetensors(filename):
|
|||||||
if isinstance(v, str) and v[0:1] == '{':
|
if isinstance(v, str) and v[0:1] == '{':
|
||||||
try:
|
try:
|
||||||
res[k] = json.loads(v)
|
res[k] = json.loads(v)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return res
|
return res
|
||||||
@ -372,7 +373,7 @@ def enable_midas_autodownload():
|
|||||||
if not os.path.exists(path):
|
if not os.path.exists(path):
|
||||||
if not os.path.exists(midas_path):
|
if not os.path.exists(midas_path):
|
||||||
mkdir(midas_path)
|
mkdir(midas_path)
|
||||||
|
|
||||||
print(f"Downloading midas model weights for {model_type} to {path}")
|
print(f"Downloading midas model weights for {model_type} to {path}")
|
||||||
request.urlretrieve(midas_urls[model_type], path)
|
request.urlretrieve(midas_urls[model_type], path)
|
||||||
print(f"{model_type} downloaded")
|
print(f"{model_type} downloaded")
|
||||||
@ -404,13 +405,39 @@ def repair_config(sd_config):
|
|||||||
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
|
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
|
||||||
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
|
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
|
||||||
|
|
||||||
def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None):
|
|
||||||
|
class SdModelData:
|
||||||
|
def __init__(self):
|
||||||
|
self.sd_model = None
|
||||||
|
self.lock = threading.Lock()
|
||||||
|
|
||||||
|
def get_sd_model(self):
|
||||||
|
if self.sd_model is None:
|
||||||
|
with self.lock:
|
||||||
|
try:
|
||||||
|
load_model()
|
||||||
|
except Exception as e:
|
||||||
|
errors.display(e, "loading stable diffusion model")
|
||||||
|
print("", file=sys.stderr)
|
||||||
|
print("Stable diffusion model failed to load", file=sys.stderr)
|
||||||
|
self.sd_model = None
|
||||||
|
|
||||||
|
return self.sd_model
|
||||||
|
|
||||||
|
def set_sd_model(self, v):
|
||||||
|
self.sd_model = v
|
||||||
|
|
||||||
|
|
||||||
|
model_data = SdModelData()
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||||
from modules import lowvram, sd_hijack
|
from modules import lowvram, sd_hijack
|
||||||
checkpoint_info = checkpoint_info or select_checkpoint()
|
checkpoint_info = checkpoint_info or select_checkpoint()
|
||||||
|
|
||||||
if shared.sd_model:
|
if model_data.sd_model:
|
||||||
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
|
sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
|
||||||
shared.sd_model = None
|
model_data.sd_model = None
|
||||||
gc.collect()
|
gc.collect()
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
@ -439,7 +466,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_
|
|||||||
try:
|
try:
|
||||||
with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd):
|
with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd):
|
||||||
sd_model = instantiate_from_config(sd_config.model)
|
sd_model = instantiate_from_config(sd_config.model)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if sd_model is None:
|
if sd_model is None:
|
||||||
@ -464,7 +491,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_
|
|||||||
timer.record("hijack")
|
timer.record("hijack")
|
||||||
|
|
||||||
sd_model.eval()
|
sd_model.eval()
|
||||||
shared.sd_model = sd_model
|
model_data.sd_model = sd_model
|
||||||
|
|
||||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
|
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
|
||||||
|
|
||||||
@ -484,7 +511,7 @@ def reload_model_weights(sd_model=None, info=None):
|
|||||||
checkpoint_info = info or select_checkpoint()
|
checkpoint_info = info or select_checkpoint()
|
||||||
|
|
||||||
if not sd_model:
|
if not sd_model:
|
||||||
sd_model = shared.sd_model
|
sd_model = model_data.sd_model
|
||||||
|
|
||||||
if sd_model is None: # previous model load failed
|
if sd_model is None: # previous model load failed
|
||||||
current_checkpoint_info = None
|
current_checkpoint_info = None
|
||||||
@ -512,11 +539,11 @@ def reload_model_weights(sd_model=None, info=None):
|
|||||||
del sd_model
|
del sd_model
|
||||||
checkpoints_loaded.clear()
|
checkpoints_loaded.clear()
|
||||||
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
|
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
|
||||||
return shared.sd_model
|
return model_data.sd_model
|
||||||
|
|
||||||
try:
|
try:
|
||||||
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
|
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
print("Failed to load checkpoint, restoring previous")
|
print("Failed to load checkpoint, restoring previous")
|
||||||
load_model_weights(sd_model, current_checkpoint_info, None, timer)
|
load_model_weights(sd_model, current_checkpoint_info, None, timer)
|
||||||
raise
|
raise
|
||||||
@ -535,17 +562,15 @@ def reload_model_weights(sd_model=None, info=None):
|
|||||||
|
|
||||||
return sd_model
|
return sd_model
|
||||||
|
|
||||||
|
|
||||||
def unload_model_weights(sd_model=None, info=None):
|
def unload_model_weights(sd_model=None, info=None):
|
||||||
from modules import lowvram, devices, sd_hijack
|
from modules import devices, sd_hijack
|
||||||
timer = Timer()
|
timer = Timer()
|
||||||
|
|
||||||
if shared.sd_model:
|
if model_data.sd_model:
|
||||||
|
model_data.sd_model.to(devices.cpu)
|
||||||
# shared.sd_model.cond_stage_model.to(devices.cpu)
|
sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
|
||||||
# shared.sd_model.first_stage_model.to(devices.cpu)
|
model_data.sd_model = None
|
||||||
shared.sd_model.to(devices.cpu)
|
|
||||||
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
|
|
||||||
shared.sd_model = None
|
|
||||||
sd_model = None
|
sd_model = None
|
||||||
gc.collect()
|
gc.collect()
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
@ -554,3 +579,25 @@ def unload_model_weights(sd_model=None, info=None):
|
|||||||
print(f"Unloaded weights {timer.summary()}.")
|
print(f"Unloaded weights {timer.summary()}.")
|
||||||
|
|
||||||
return sd_model
|
return sd_model
|
||||||
|
|
||||||
|
|
||||||
|
def apply_token_merging(sd_model, hr: bool):
|
||||||
|
"""
|
||||||
|
Applies speed and memory optimizations from tomesd.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hr (bool): True if called in the context of a high-res pass
|
||||||
|
"""
|
||||||
|
|
||||||
|
ratio = shared.opts.token_merging_ratio
|
||||||
|
if hr:
|
||||||
|
ratio = shared.opts.token_merging_ratio_hr
|
||||||
|
|
||||||
|
tomesd.apply_patch(
|
||||||
|
sd_model,
|
||||||
|
ratio=ratio,
|
||||||
|
use_rand=False, # can cause issues with some samplers
|
||||||
|
merge_attn=True,
|
||||||
|
merge_crossattn=False,
|
||||||
|
merge_mlp=False
|
||||||
|
)
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
import re
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -111,7 +110,7 @@ def find_checkpoint_config_near_filename(info):
|
|||||||
if info is None:
|
if info is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
config = os.path.splitext(info.filename)[0] + ".yaml"
|
config = f"{os.path.splitext(info.filename)[0]}.yaml"
|
||||||
if os.path.exists(config):
|
if os.path.exists(config):
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared
|
from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared
|
||||||
|
|
||||||
# imports for functions that previously were here and are used by other modules
|
# imports for functions that previously were here and are used by other modules
|
||||||
from modules.sd_samplers_common import samples_to_image_grid, sample_to_image
|
from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401
|
||||||
|
|
||||||
all_samplers = [
|
all_samplers = [
|
||||||
*sd_samplers_kdiffusion.samplers_data_k_diffusion,
|
*sd_samplers_kdiffusion.samplers_data_k_diffusion,
|
||||||
|
@ -55,7 +55,7 @@ class VanillaStableDiffusionSampler:
|
|||||||
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
|
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
|
||||||
x_dec, ts, cond, unconditional_conditioning = self.before_sample(x_dec, ts, cond, unconditional_conditioning)
|
x_dec, ts, cond, unconditional_conditioning = self.before_sample(x_dec, ts, cond, unconditional_conditioning)
|
||||||
|
|
||||||
res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
|
res = self.orig_p_sample_ddim(x_dec, cond, ts, *args, unconditional_conditioning=unconditional_conditioning, **kwargs)
|
||||||
|
|
||||||
x_dec, ts, cond, unconditional_conditioning, res = self.after_sample(x_dec, ts, cond, unconditional_conditioning, res)
|
x_dec, ts, cond, unconditional_conditioning, res = self.after_sample(x_dec, ts, cond, unconditional_conditioning, res)
|
||||||
|
|
||||||
@ -83,7 +83,7 @@ class VanillaStableDiffusionSampler:
|
|||||||
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
||||||
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
|
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
|
||||||
|
|
||||||
assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers'
|
assert all(len(conds) == 1 for conds in conds_list), 'composition via AND is not supported for DDIM/PLMS samplers'
|
||||||
cond = tensor
|
cond = tensor
|
||||||
|
|
||||||
# for DDIM, shapes must match, we can't just process cond and uncond independently;
|
# for DDIM, shapes must match, we can't just process cond and uncond independently;
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
from collections import deque
|
from collections import deque
|
||||||
import torch
|
import torch
|
||||||
import inspect
|
import inspect
|
||||||
import einops
|
|
||||||
import k_diffusion.sampling
|
import k_diffusion.sampling
|
||||||
from modules import prompt_parser, devices, sd_samplers_common
|
from modules import prompt_parser, devices, sd_samplers_common
|
||||||
|
|
||||||
@ -9,6 +8,7 @@ from modules.shared import opts, state
|
|||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
|
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
|
||||||
from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
|
from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
|
||||||
|
from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
|
||||||
|
|
||||||
samplers_k_diffusion = [
|
samplers_k_diffusion = [
|
||||||
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
|
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
|
||||||
@ -87,17 +87,17 @@ class CFGDenoiser(torch.nn.Module):
|
|||||||
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
||||||
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
|
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
|
||||||
|
|
||||||
assert not is_edit_model or all([len(conds) == 1 for conds in conds_list]), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
|
assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
|
||||||
|
|
||||||
batch_size = len(conds_list)
|
batch_size = len(conds_list)
|
||||||
repeats = [len(conds_list[i]) for i in range(batch_size)]
|
repeats = [len(conds_list[i]) for i in range(batch_size)]
|
||||||
|
|
||||||
if shared.sd_model.model.conditioning_key == "crossattn-adm":
|
if shared.sd_model.model.conditioning_key == "crossattn-adm":
|
||||||
image_uncond = torch.zeros_like(image_cond)
|
image_uncond = torch.zeros_like(image_cond)
|
||||||
make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm}
|
make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm}
|
||||||
else:
|
else:
|
||||||
image_uncond = image_cond
|
image_uncond = image_cond
|
||||||
make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]}
|
make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]}
|
||||||
|
|
||||||
if not is_edit_model:
|
if not is_edit_model:
|
||||||
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
|
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
|
||||||
@ -161,7 +161,7 @@ class CFGDenoiser(torch.nn.Module):
|
|||||||
fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
|
fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
|
||||||
x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
|
x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
|
||||||
|
|
||||||
denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps)
|
denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
|
||||||
cfg_denoised_callback(denoised_params)
|
cfg_denoised_callback(denoised_params)
|
||||||
|
|
||||||
devices.test_for_nans(x_out, "unet")
|
devices.test_for_nans(x_out, "unet")
|
||||||
@ -181,6 +181,10 @@ class CFGDenoiser(torch.nn.Module):
|
|||||||
if self.mask is not None:
|
if self.mask is not None:
|
||||||
denoised = self.init_latent * self.mask + self.nmask * denoised
|
denoised = self.init_latent * self.mask + self.nmask * denoised
|
||||||
|
|
||||||
|
after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
|
||||||
|
cfg_after_cfg_callback(after_cfg_callback_params)
|
||||||
|
denoised = after_cfg_callback_params.x
|
||||||
|
|
||||||
self.step += 1
|
self.step += 1
|
||||||
return denoised
|
return denoised
|
||||||
|
|
||||||
@ -198,7 +202,7 @@ class TorchHijack:
|
|||||||
if hasattr(torch, item):
|
if hasattr(torch, item):
|
||||||
return getattr(torch, item)
|
return getattr(torch, item)
|
||||||
|
|
||||||
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
|
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
|
||||||
|
|
||||||
def randn_like(self, x):
|
def randn_like(self, x):
|
||||||
if self.sampler_noises:
|
if self.sampler_noises:
|
||||||
@ -317,7 +321,7 @@ class KDiffusionSampler:
|
|||||||
|
|
||||||
sigma_sched = sigmas[steps - t_enc - 1:]
|
sigma_sched = sigmas[steps - t_enc - 1:]
|
||||||
xi = x + noise * sigma_sched[0]
|
xi = x + noise * sigma_sched[0]
|
||||||
|
|
||||||
extra_params_kwargs = self.initialize(p)
|
extra_params_kwargs = self.initialize(p)
|
||||||
parameters = inspect.signature(self.func).parameters
|
parameters = inspect.signature(self.func).parameters
|
||||||
|
|
||||||
@ -340,9 +344,9 @@ class KDiffusionSampler:
|
|||||||
self.model_wrap_cfg.init_latent = x
|
self.model_wrap_cfg.init_latent = x
|
||||||
self.last_latent = x
|
self.last_latent = x
|
||||||
extra_args={
|
extra_args={
|
||||||
'cond': conditioning,
|
'cond': conditioning,
|
||||||
'image_cond': image_conditioning,
|
'image_cond': image_conditioning,
|
||||||
'uncond': unconditional_conditioning,
|
'uncond': unconditional_conditioning,
|
||||||
'cond_scale': p.cfg_scale,
|
'cond_scale': p.cfg_scale,
|
||||||
's_min_uncond': self.s_min_uncond
|
's_min_uncond': self.s_min_uncond
|
||||||
}
|
}
|
||||||
@ -375,9 +379,9 @@ class KDiffusionSampler:
|
|||||||
|
|
||||||
self.last_latent = x
|
self.last_latent = x
|
||||||
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
|
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
|
||||||
'cond': conditioning,
|
'cond': conditioning,
|
||||||
'image_cond': image_conditioning,
|
'image_cond': image_conditioning,
|
||||||
'uncond': unconditional_conditioning,
|
'uncond': unconditional_conditioning,
|
||||||
'cond_scale': p.cfg_scale,
|
'cond_scale': p.cfg_scale,
|
||||||
's_min_uncond': self.s_min_uncond
|
's_min_uncond': self.s_min_uncond
|
||||||
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||||
|
@ -1,8 +1,5 @@
|
|||||||
import torch
|
|
||||||
import safetensors.torch
|
|
||||||
import os
|
import os
|
||||||
import collections
|
import collections
|
||||||
from collections import namedtuple
|
|
||||||
from modules import paths, shared, devices, script_callbacks, sd_models
|
from modules import paths, shared, devices, script_callbacks, sd_models
|
||||||
import glob
|
import glob
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
@ -1,12 +1,9 @@
|
|||||||
import argparse
|
|
||||||
import datetime
|
import datetime
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import requests
|
|
||||||
|
|
||||||
from PIL import Image
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import tqdm
|
import tqdm
|
||||||
|
|
||||||
@ -15,7 +12,8 @@ import modules.memmon
|
|||||||
import modules.styles
|
import modules.styles
|
||||||
import modules.devices as devices
|
import modules.devices as devices
|
||||||
from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args
|
from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args
|
||||||
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir
|
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
|
||||||
|
from ldm.models.diffusion.ddpm import LatentDiffusion
|
||||||
|
|
||||||
demo = None
|
demo = None
|
||||||
|
|
||||||
@ -201,8 +199,9 @@ interrogator = modules.interrogate.InterrogateModels("interrogate")
|
|||||||
|
|
||||||
face_restorers = []
|
face_restorers = []
|
||||||
|
|
||||||
|
|
||||||
class OptionInfo:
|
class OptionInfo:
|
||||||
def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None):
|
def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after=''):
|
||||||
self.default = default
|
self.default = default
|
||||||
self.label = label
|
self.label = label
|
||||||
self.component = component
|
self.component = component
|
||||||
@ -211,9 +210,33 @@ class OptionInfo:
|
|||||||
self.section = section
|
self.section = section
|
||||||
self.refresh = refresh
|
self.refresh = refresh
|
||||||
|
|
||||||
|
self.comment_before = comment_before
|
||||||
|
"""HTML text that will be added after label in UI"""
|
||||||
|
|
||||||
|
self.comment_after = comment_after
|
||||||
|
"""HTML text that will be added before label in UI"""
|
||||||
|
|
||||||
|
def link(self, label, url):
|
||||||
|
self.comment_before += f"[<a href='{url}' target='_blank'>{label}</a>]"
|
||||||
|
return self
|
||||||
|
|
||||||
|
def js(self, label, js_func):
|
||||||
|
self.comment_before += f"[<a onclick='{js_func}(); return false'>{label}</a>]"
|
||||||
|
return self
|
||||||
|
|
||||||
|
def info(self, info):
|
||||||
|
self.comment_after += f"<span class='info'>({info})</span>"
|
||||||
|
return self
|
||||||
|
|
||||||
|
def needs_restart(self):
|
||||||
|
self.comment_after += " <span class='info'>(requires restart)</span>"
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def options_section(section_identifier, options_dict):
|
def options_section(section_identifier, options_dict):
|
||||||
for k, v in options_dict.items():
|
for v in options_dict.values():
|
||||||
v.section = section_identifier
|
v.section = section_identifier
|
||||||
|
|
||||||
return options_dict
|
return options_dict
|
||||||
@ -242,7 +265,7 @@ options_templates = {}
|
|||||||
options_templates.update(options_section(('saving-images', "Saving images/grids"), {
|
options_templates.update(options_section(('saving-images', "Saving images/grids"), {
|
||||||
"samples_save": OptionInfo(True, "Always save all generated images"),
|
"samples_save": OptionInfo(True, "Always save all generated images"),
|
||||||
"samples_format": OptionInfo('png', 'File format for images'),
|
"samples_format": OptionInfo('png', 'File format for images'),
|
||||||
"samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs),
|
"samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
|
||||||
"save_images_add_number": OptionInfo(True, "Add number to filename when saving", component_args=hide_dirs),
|
"save_images_add_number": OptionInfo(True, "Add number to filename when saving", component_args=hide_dirs),
|
||||||
|
|
||||||
"grid_save": OptionInfo(True, "Always save all generated image grids"),
|
"grid_save": OptionInfo(True, "Always save all generated image grids"),
|
||||||
@ -261,10 +284,10 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
|
|||||||
"save_mask_composite": OptionInfo(False, "For inpainting, save a masked composite"),
|
"save_mask_composite": OptionInfo(False, "For inpainting, save a masked composite"),
|
||||||
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
|
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
|
||||||
"webp_lossless": OptionInfo(False, "Use lossless compression for webp images"),
|
"webp_lossless": OptionInfo(False, "Use lossless compression for webp images"),
|
||||||
"export_for_4chan": OptionInfo(True, "If the saved image file size is above the limit, or its either width or height are above the limit, save a downscaled copy as JPG"),
|
"export_for_4chan": OptionInfo(True, "Save copy of large images as JPG").info("if the file size is above the limit, or either width or height are above the limit"),
|
||||||
"img_downscale_threshold": OptionInfo(4.0, "File size limit for the above option, MB", gr.Number),
|
"img_downscale_threshold": OptionInfo(4.0, "File size limit for the above option, MB", gr.Number),
|
||||||
"target_side_length": OptionInfo(4000, "Width/height limit for the above option, in pixels", gr.Number),
|
"target_side_length": OptionInfo(4000, "Width/height limit for the above option, in pixels", gr.Number),
|
||||||
"img_max_size_mp": OptionInfo(200, "Maximum image size, in megapixels", gr.Number),
|
"img_max_size_mp": OptionInfo(200, "Maximum image size", gr.Number).info("in megapixels"),
|
||||||
|
|
||||||
"use_original_name_batch": OptionInfo(True, "Use original name for output filename during batch process in extras tab"),
|
"use_original_name_batch": OptionInfo(True, "Use original name for output filename during batch process in extras tab"),
|
||||||
"use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
|
"use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
|
||||||
@ -292,28 +315,26 @@ options_templates.update(options_section(('saving-to-dirs', "Saving to a directo
|
|||||||
"save_to_dirs": OptionInfo(True, "Save images to a subdirectory"),
|
"save_to_dirs": OptionInfo(True, "Save images to a subdirectory"),
|
||||||
"grid_save_to_dirs": OptionInfo(True, "Save grids to a subdirectory"),
|
"grid_save_to_dirs": OptionInfo(True, "Save grids to a subdirectory"),
|
||||||
"use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
|
"use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
|
||||||
"directories_filename_pattern": OptionInfo("[date]", "Directory name pattern", component_args=hide_dirs),
|
"directories_filename_pattern": OptionInfo("[date]", "Directory name pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
|
||||||
"directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}),
|
"directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('upscaling', "Upscaling"), {
|
options_templates.update(options_section(('upscaling', "Upscaling"), {
|
||||||
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
|
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"),
|
||||||
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
|
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"),
|
||||||
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
|
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
|
||||||
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
|
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
|
||||||
"SCUNET_tile": OptionInfo(256, "Tile size for SCUNET upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
|
|
||||||
"SCUNET_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SCUNET upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}),
|
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('face-restoration', "Face restoration"), {
|
options_templates.update(options_section(('face-restoration', "Face restoration"), {
|
||||||
"face_restoration_model": OptionInfo("CodeFormer", "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
|
"face_restoration_model": OptionInfo("CodeFormer", "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
|
||||||
"code_former_weight": OptionInfo(0.5, "CodeFormer weight parameter; 0 = maximum effect; 1 = minimum effect", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
|
"code_former_weight": OptionInfo(0.5, "CodeFormer weight", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}).info("0 = maximum effect; 1 = minimum effect"),
|
||||||
"face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
|
"face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('system', "System"), {
|
options_templates.update(options_section(('system', "System"), {
|
||||||
"show_warnings": OptionInfo(False, "Show warnings in console."),
|
"show_warnings": OptionInfo(False, "Show warnings in console."),
|
||||||
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}),
|
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"),
|
||||||
"samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
|
"samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
|
||||||
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
|
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
|
||||||
"print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
|
"print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
|
||||||
@ -338,20 +359,22 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
|||||||
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
|
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
|
||||||
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||||
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||||
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list),
|
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
|
||||||
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
|
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
|
||||||
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||||
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
|
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
|
||||||
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
|
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
|
||||||
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
|
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies.").info("normally you'd do less with less denoising"),
|
||||||
"img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", ui_components.FormColorPicker, {}),
|
"img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", ui_components.FormColorPicker, {}),
|
||||||
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
|
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
|
||||||
"enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
|
"enable_emphasis": OptionInfo(True, "Enable emphasis").info("use (text) to make model pay more attention to text and [text] to make it pay less attention"),
|
||||||
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
|
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
|
||||||
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
|
"comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"),
|
||||||
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
|
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP nrtwork; 1 ignores none, 2 ignores one layer"),
|
||||||
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
|
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
|
||||||
"randn_source": OptionInfo("GPU", "Random number generator source. Changes seeds drastically. Use CPU to produce the same picture across different vidocard vendors.", gr.Radio, {"choices": ["GPU", "CPU"]}),
|
"randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different vidocard vendors"),
|
||||||
|
"token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
|
||||||
|
"token_merging_ratio_hr": OptionInfo(0.0, "Togen merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('compatibility', "Compatibility"), {
|
options_templates.update(options_section(('compatibility', "Compatibility"), {
|
||||||
@ -363,80 +386,87 @@ options_templates.update(options_section(('compatibility', "Compatibility"), {
|
|||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
|
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
|
||||||
"interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
|
"interrogate_keep_models_in_memory": OptionInfo(False, "Keep models in VRAM"),
|
||||||
"interrogate_return_ranks": OptionInfo(False, "Interrogate: include ranks of model tags matches in results (Has no effect on caption-based interrogators)."),
|
"interrogate_return_ranks": OptionInfo(False, "Include ranks of model tags matches in results.").info("booru only"),
|
||||||
"interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
|
"interrogate_clip_num_beams": OptionInfo(1, "BLIP: num_beams", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
|
||||||
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
|
"interrogate_clip_min_length": OptionInfo(24, "BLIP: minimum description length", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
|
||||||
"interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
|
"interrogate_clip_max_length": OptionInfo(48, "BLIP: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
|
||||||
"interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file (0 = No limit)"),
|
"interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file").info("0 = No limit"),
|
||||||
"interrogate_clip_skip_categories": OptionInfo([], "CLIP: skip inquire categories", gr.CheckboxGroup, lambda: {"choices": modules.interrogate.category_types()}, refresh=modules.interrogate.category_types),
|
"interrogate_clip_skip_categories": OptionInfo([], "CLIP: skip inquire categories", gr.CheckboxGroup, lambda: {"choices": modules.interrogate.category_types()}, refresh=modules.interrogate.category_types),
|
||||||
"interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
|
"interrogate_deepbooru_score_threshold": OptionInfo(0.5, "deepbooru: score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
|
||||||
"deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"),
|
"deepbooru_sort_alpha": OptionInfo(True, "deepbooru: sort tags alphabetically").info("if not: sort by score"),
|
||||||
"deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"),
|
"deepbooru_use_spaces": OptionInfo(True, "deepbooru: use spaces in tags").info("if not: use underscores"),
|
||||||
"deepbooru_escape": OptionInfo(True, "escape (\\) brackets in deepbooru (so they are used as literal brackets and not for emphasis)"),
|
"deepbooru_escape": OptionInfo(True, "deepbooru: escape (\\) brackets").info("so they are used as literal brackets and not for emphasis"),
|
||||||
"deepbooru_filter_tags": OptionInfo("", "filter out those tags from deepbooru output (separated by comma)"),
|
"deepbooru_filter_tags": OptionInfo("", "deepbooru: filter out those tags").info("separate by comma"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('extra_networks', "Extra Networks"), {
|
options_templates.update(options_section(('extra_networks', "Extra Networks"), {
|
||||||
"extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, {"choices": ["cards", "thumbs"]}),
|
"extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, {"choices": ["cards", "thumbs"]}),
|
||||||
"extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
"extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||||
"extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks (px)"),
|
"extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks").info("in pixels"),
|
||||||
"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks (px)"),
|
"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"),
|
||||||
"extra_networks_add_text_separator": OptionInfo(" ", "Extra text to add before <...> when adding extra network to prompt"),
|
"extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"),
|
||||||
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
|
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *hypernetworks]}, refresh=reload_hypernetworks),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('ui', "User interface"), {
|
options_templates.update(options_section(('ui', "User interface"), {
|
||||||
|
"localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_restart(),
|
||||||
|
"gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + gradio_hf_hub_themes}).needs_restart(),
|
||||||
"return_grid": OptionInfo(True, "Show grid in results for web"),
|
"return_grid": OptionInfo(True, "Show grid in results for web"),
|
||||||
"return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
|
"return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
|
||||||
"return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
|
"return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
|
||||||
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
|
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
|
||||||
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
|
|
||||||
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
|
|
||||||
"disable_weights_auto_swap": OptionInfo(True, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
|
|
||||||
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
|
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
|
||||||
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
|
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
|
||||||
"font": OptionInfo("", "Font for image grids that have text"),
|
"font": OptionInfo("", "Font for image grids that have text"),
|
||||||
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
|
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
|
||||||
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
|
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
|
||||||
|
"js_modal_lightbox_gamepad": OptionInfo(True, "Navigate image viewer with gamepad"),
|
||||||
|
"js_modal_lightbox_gamepad_repeat": OptionInfo(250, "Gamepad repeat period, in milliseconds"),
|
||||||
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
|
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
|
||||||
"samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"),
|
"samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group").needs_restart(),
|
||||||
"dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row"),
|
"dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row").needs_restart(),
|
||||||
"keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
"keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
||||||
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
||||||
"keyedit_delimiters": OptionInfo(".,\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
|
"keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
|
||||||
"quicksettings": OptionInfo("sd_model_checkpoint", "Quicksettings list"),
|
"quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_restart(),
|
||||||
"hidden_tabs": OptionInfo([], "Hidden UI tabs (requires restart)", ui_components.DropdownMulti, lambda: {"choices": [x for x in tab_names]}),
|
"hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
|
||||||
"ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"),
|
"ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"),
|
||||||
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order"),
|
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_restart(),
|
||||||
"localization": OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
|
}))
|
||||||
"gradio_theme": OptionInfo("Default", "Gradio theme (requires restart)", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + gradio_hf_hub_themes})
|
|
||||||
|
options_templates.update(options_section(('infotext', "Infotext"), {
|
||||||
|
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
|
||||||
|
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
|
||||||
|
"add_version_to_infotext": OptionInfo(True, "Add program version to generation information"),
|
||||||
|
"disable_weights_auto_swap": OptionInfo(True, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('ui', "Live previews"), {
|
options_templates.update(options_section(('ui', "Live previews"), {
|
||||||
"show_progressbar": OptionInfo(True, "Show progressbar"),
|
"show_progressbar": OptionInfo(True, "Show progressbar"),
|
||||||
"live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
|
"live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
|
||||||
|
"live_previews_image_format": OptionInfo("png", "Live preview file format", gr.Radio, {"choices": ["jpeg", "png", "webp"]}),
|
||||||
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
|
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
|
||||||
"show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
|
"show_progress_every_n_steps": OptionInfo(10, "Live preview display period", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}).info("in sampling steps - show new live preview image every N sampling steps; -1 = only show after completion of batch"),
|
||||||
"show_progress_type": OptionInfo("Approx NN", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}),
|
"show_progress_type": OptionInfo("Approx NN", "Live preview method", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}).info("Full = slow but pretty; Approx NN = fast but low quality; Approx cheap = super fast but terrible otherwise"),
|
||||||
"live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
|
"live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
|
||||||
"live_preview_refresh_period": OptionInfo(1000, "Progressbar/preview update period, in milliseconds")
|
"live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
|
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
|
||||||
"hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}),
|
"hide_samplers": OptionInfo([], "Hide samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}).needs_restart(),
|
||||||
"eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
"eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; higher = more unperdictable results"),
|
||||||
"eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
"eta_ancestral": OptionInfo(1.0, "Eta for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; applies to Euler a and other samplers that have a in them"),
|
||||||
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
|
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
|
||||||
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||||
's_min_uncond': OptionInfo(0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}),
|
's_min_uncond': OptionInfo(0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}),
|
||||||
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||||
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||||
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}),
|
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}).info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"),
|
||||||
'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma"),
|
'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma").link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044"),
|
||||||
'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}),
|
'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}),
|
||||||
'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}),
|
'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}),
|
||||||
'uni_pc_order': OptionInfo(3, "UniPC order (must be < sampling steps)", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}),
|
'uni_pc_order': OptionInfo(3, "UniPC order", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}).info("must be < sampling steps"),
|
||||||
'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final"),
|
'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
@ -453,6 +483,7 @@ options_templates.update(options_section((None, "Hidden options"), {
|
|||||||
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
|
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
|
||||||
options_templates.update()
|
options_templates.update()
|
||||||
|
|
||||||
|
|
||||||
@ -542,6 +573,10 @@ class Options:
|
|||||||
with open(filename, "r", encoding="utf8") as file:
|
with open(filename, "r", encoding="utf8") as file:
|
||||||
self.data = json.load(file)
|
self.data = json.load(file)
|
||||||
|
|
||||||
|
# 1.1.1 quicksettings list migration
|
||||||
|
if self.data.get('quicksettings') is not None and self.data.get('quicksettings_list') is None:
|
||||||
|
self.data['quicksettings_list'] = [i.strip() for i in self.data.get('quicksettings').split(',')]
|
||||||
|
|
||||||
bad_settings = 0
|
bad_settings = 0
|
||||||
for k, v in self.data.items():
|
for k, v in self.data.items():
|
||||||
info = self.data_labels.get(k, None)
|
info = self.data_labels.get(k, None)
|
||||||
@ -560,7 +595,9 @@ class Options:
|
|||||||
func()
|
func()
|
||||||
|
|
||||||
def dumpjson(self):
|
def dumpjson(self):
|
||||||
d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()}
|
d = {k: self.data.get(k, v.default) for k, v in self.data_labels.items()}
|
||||||
|
d["_comments_before"] = {k: v.comment_before for k, v in self.data_labels.items() if v.comment_before is not None}
|
||||||
|
d["_comments_after"] = {k: v.comment_after for k, v in self.data_labels.items() if v.comment_after is not None}
|
||||||
return json.dumps(d)
|
return json.dumps(d)
|
||||||
|
|
||||||
def add_option(self, key, info):
|
def add_option(self, key, info):
|
||||||
@ -571,11 +608,11 @@ class Options:
|
|||||||
|
|
||||||
section_ids = {}
|
section_ids = {}
|
||||||
settings_items = self.data_labels.items()
|
settings_items = self.data_labels.items()
|
||||||
for k, item in settings_items:
|
for _, item in settings_items:
|
||||||
if item.section not in section_ids:
|
if item.section not in section_ids:
|
||||||
section_ids[item.section] = len(section_ids)
|
section_ids[item.section] = len(section_ids)
|
||||||
|
|
||||||
self.data_labels = {k: v for k, v in sorted(settings_items, key=lambda x: section_ids[x[1].section])}
|
self.data_labels = dict(sorted(settings_items, key=lambda x: section_ids[x[1].section]))
|
||||||
|
|
||||||
def cast_value(self, key, value):
|
def cast_value(self, key, value):
|
||||||
"""casts an arbitrary to the same type as this setting's value with key
|
"""casts an arbitrary to the same type as this setting's value with key
|
||||||
@ -600,13 +637,37 @@ class Options:
|
|||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
opts = Options()
|
opts = Options()
|
||||||
if os.path.exists(config_filename):
|
if os.path.exists(config_filename):
|
||||||
opts.load(config_filename)
|
opts.load(config_filename)
|
||||||
|
|
||||||
|
|
||||||
|
class Shared(sys.modules[__name__].__class__):
|
||||||
|
"""
|
||||||
|
this class is here to provide sd_model field as a property, so that it can be created and loaded on demand rather than
|
||||||
|
at program startup.
|
||||||
|
"""
|
||||||
|
|
||||||
|
sd_model_val = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sd_model(self):
|
||||||
|
import modules.sd_models
|
||||||
|
|
||||||
|
return modules.sd_models.model_data.get_sd_model()
|
||||||
|
|
||||||
|
@sd_model.setter
|
||||||
|
def sd_model(self, value):
|
||||||
|
import modules.sd_models
|
||||||
|
|
||||||
|
modules.sd_models.model_data.set_sd_model(value)
|
||||||
|
|
||||||
|
|
||||||
|
sd_model: LatentDiffusion = None # this var is here just for IDE's type checking; it cannot be accessed because the class field above will be accessed instead
|
||||||
|
sys.modules[__name__].__class__ = Shared
|
||||||
|
|
||||||
settings_components = None
|
settings_components = None
|
||||||
"""assinged from ui.py, a mapping on setting anmes to gradio components repsponsible for those settings"""
|
"""assinged from ui.py, a mapping on setting names to gradio components repsponsible for those settings"""
|
||||||
|
|
||||||
latent_upscale_default_mode = "Latent"
|
latent_upscale_default_mode = "Latent"
|
||||||
latent_upscale_modes = {
|
latent_upscale_modes = {
|
||||||
@ -620,8 +681,6 @@ latent_upscale_modes = {
|
|||||||
|
|
||||||
sd_upscalers = []
|
sd_upscalers = []
|
||||||
|
|
||||||
sd_model = None
|
|
||||||
|
|
||||||
clip_model = None
|
clip_model = None
|
||||||
|
|
||||||
progress_print_out = sys.stdout
|
progress_print_out = sys.stdout
|
||||||
@ -634,14 +693,19 @@ def reload_gradio_theme(theme_name=None):
|
|||||||
if not theme_name:
|
if not theme_name:
|
||||||
theme_name = opts.gradio_theme
|
theme_name = opts.gradio_theme
|
||||||
|
|
||||||
|
default_theme_args = dict(
|
||||||
|
font=["Source Sans Pro", 'ui-sans-serif', 'system-ui', 'sans-serif'],
|
||||||
|
font_mono=['IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace'],
|
||||||
|
)
|
||||||
|
|
||||||
if theme_name == "Default":
|
if theme_name == "Default":
|
||||||
gradio_theme = gr.themes.Default()
|
gradio_theme = gr.themes.Default(**default_theme_args)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
gradio_theme = gr.themes.ThemeClass.from_hub(theme_name)
|
gradio_theme = gr.themes.ThemeClass.from_hub(theme_name)
|
||||||
except requests.exceptions.ConnectionError:
|
except Exception as e:
|
||||||
print("Can't access HuggingFace Hub, falling back to default Gradio theme")
|
errors.display(e, "changing gradio theme")
|
||||||
gradio_theme = gr.themes.Default()
|
gradio_theme = gr.themes.Default(**default_theme_args)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -701,3 +765,20 @@ def html(filename):
|
|||||||
return file.read()
|
return file.read()
|
||||||
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def walk_files(path, allowed_extensions=None):
|
||||||
|
if not os.path.exists(path):
|
||||||
|
return
|
||||||
|
|
||||||
|
if allowed_extensions is not None:
|
||||||
|
allowed_extensions = set(allowed_extensions)
|
||||||
|
|
||||||
|
for root, _, files in os.walk(path, followlinks=True):
|
||||||
|
for filename in files:
|
||||||
|
if allowed_extensions is not None:
|
||||||
|
_, ext = os.path.splitext(filename)
|
||||||
|
if ext not in allowed_extensions:
|
||||||
|
continue
|
||||||
|
|
||||||
|
yield os.path.join(root, filename)
|
||||||
|
@ -1,18 +1,9 @@
|
|||||||
# We need this so Python doesn't complain about the unknown StableDiffusionProcessing-typehint at runtime
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import csv
|
import csv
|
||||||
import os
|
import os
|
||||||
import os.path
|
import os.path
|
||||||
import typing
|
import typing
|
||||||
import collections.abc as abc
|
|
||||||
import tempfile
|
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
|
||||||
# Only import this when code is being type-checked, it doesn't have any effect at runtime
|
|
||||||
from .processing import StableDiffusionProcessing
|
|
||||||
|
|
||||||
|
|
||||||
class PromptStyle(typing.NamedTuple):
|
class PromptStyle(typing.NamedTuple):
|
||||||
name: str
|
name: str
|
||||||
@ -74,7 +65,7 @@ class StyleDatabase:
|
|||||||
def save_styles(self, path: str) -> None:
|
def save_styles(self, path: str) -> None:
|
||||||
# Always keep a backup file around
|
# Always keep a backup file around
|
||||||
if os.path.exists(path):
|
if os.path.exists(path):
|
||||||
shutil.copy(path, path + ".bak")
|
shutil.copy(path, f"{path}.bak")
|
||||||
|
|
||||||
fd = os.open(path, os.O_RDWR|os.O_CREAT)
|
fd = os.open(path, os.O_RDWR|os.O_CREAT)
|
||||||
with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file:
|
with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file:
|
||||||
|
@ -179,7 +179,7 @@ def efficient_dot_product_attention(
|
|||||||
chunk_idx,
|
chunk_idx,
|
||||||
min(query_chunk_size, q_tokens)
|
min(query_chunk_size, q_tokens)
|
||||||
)
|
)
|
||||||
|
|
||||||
summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale)
|
summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale)
|
||||||
summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
|
summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
|
||||||
compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
|
compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
|
||||||
@ -201,14 +201,15 @@ def efficient_dot_product_attention(
|
|||||||
key=key,
|
key=key,
|
||||||
value=value,
|
value=value,
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
|
res = torch.zeros_like(query)
|
||||||
# and pass slices to be mutated, instead of torch.cat()ing the returned slices
|
for i in range(math.ceil(q_tokens / query_chunk_size)):
|
||||||
res = torch.cat([
|
attn_scores = compute_query_chunk_attn(
|
||||||
compute_query_chunk_attn(
|
|
||||||
query=get_query_chunk(i * query_chunk_size),
|
query=get_query_chunk(i * query_chunk_size),
|
||||||
key=key,
|
key=key,
|
||||||
value=value,
|
value=value,
|
||||||
) for i in range(math.ceil(q_tokens / query_chunk_size))
|
)
|
||||||
], dim=1)
|
|
||||||
|
res[:, i * query_chunk_size:i * query_chunk_size + attn_scores.shape[1], :] = attn_scores
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
@ -1,10 +1,8 @@
|
|||||||
import cv2
|
import cv2
|
||||||
import requests
|
import requests
|
||||||
import os
|
import os
|
||||||
from collections import defaultdict
|
|
||||||
from math import log, sqrt
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image, ImageDraw
|
from PIL import ImageDraw
|
||||||
|
|
||||||
GREEN = "#0F0"
|
GREEN = "#0F0"
|
||||||
BLUE = "#00F"
|
BLUE = "#00F"
|
||||||
@ -12,63 +10,64 @@ RED = "#F00"
|
|||||||
|
|
||||||
|
|
||||||
def crop_image(im, settings):
|
def crop_image(im, settings):
|
||||||
""" Intelligently crop an image to the subject matter """
|
""" Intelligently crop an image to the subject matter """
|
||||||
|
|
||||||
scale_by = 1
|
scale_by = 1
|
||||||
if is_landscape(im.width, im.height):
|
if is_landscape(im.width, im.height):
|
||||||
scale_by = settings.crop_height / im.height
|
scale_by = settings.crop_height / im.height
|
||||||
elif is_portrait(im.width, im.height):
|
elif is_portrait(im.width, im.height):
|
||||||
scale_by = settings.crop_width / im.width
|
scale_by = settings.crop_width / im.width
|
||||||
elif is_square(im.width, im.height):
|
elif is_square(im.width, im.height):
|
||||||
if is_square(settings.crop_width, settings.crop_height):
|
if is_square(settings.crop_width, settings.crop_height):
|
||||||
scale_by = settings.crop_width / im.width
|
scale_by = settings.crop_width / im.width
|
||||||
elif is_landscape(settings.crop_width, settings.crop_height):
|
elif is_landscape(settings.crop_width, settings.crop_height):
|
||||||
scale_by = settings.crop_width / im.width
|
scale_by = settings.crop_width / im.width
|
||||||
elif is_portrait(settings.crop_width, settings.crop_height):
|
elif is_portrait(settings.crop_width, settings.crop_height):
|
||||||
scale_by = settings.crop_height / im.height
|
scale_by = settings.crop_height / im.height
|
||||||
|
|
||||||
im = im.resize((int(im.width * scale_by), int(im.height * scale_by)))
|
|
||||||
im_debug = im.copy()
|
|
||||||
|
|
||||||
focus = focal_point(im_debug, settings)
|
im = im.resize((int(im.width * scale_by), int(im.height * scale_by)))
|
||||||
|
im_debug = im.copy()
|
||||||
|
|
||||||
# take the focal point and turn it into crop coordinates that try to center over the focal
|
focus = focal_point(im_debug, settings)
|
||||||
# point but then get adjusted back into the frame
|
|
||||||
y_half = int(settings.crop_height / 2)
|
|
||||||
x_half = int(settings.crop_width / 2)
|
|
||||||
|
|
||||||
x1 = focus.x - x_half
|
# take the focal point and turn it into crop coordinates that try to center over the focal
|
||||||
if x1 < 0:
|
# point but then get adjusted back into the frame
|
||||||
x1 = 0
|
y_half = int(settings.crop_height / 2)
|
||||||
elif x1 + settings.crop_width > im.width:
|
x_half = int(settings.crop_width / 2)
|
||||||
x1 = im.width - settings.crop_width
|
|
||||||
|
|
||||||
y1 = focus.y - y_half
|
x1 = focus.x - x_half
|
||||||
if y1 < 0:
|
if x1 < 0:
|
||||||
y1 = 0
|
x1 = 0
|
||||||
elif y1 + settings.crop_height > im.height:
|
elif x1 + settings.crop_width > im.width:
|
||||||
y1 = im.height - settings.crop_height
|
x1 = im.width - settings.crop_width
|
||||||
|
|
||||||
x2 = x1 + settings.crop_width
|
y1 = focus.y - y_half
|
||||||
y2 = y1 + settings.crop_height
|
if y1 < 0:
|
||||||
|
y1 = 0
|
||||||
|
elif y1 + settings.crop_height > im.height:
|
||||||
|
y1 = im.height - settings.crop_height
|
||||||
|
|
||||||
crop = [x1, y1, x2, y2]
|
x2 = x1 + settings.crop_width
|
||||||
|
y2 = y1 + settings.crop_height
|
||||||
|
|
||||||
results = []
|
crop = [x1, y1, x2, y2]
|
||||||
|
|
||||||
results.append(im.crop(tuple(crop)))
|
results = []
|
||||||
|
|
||||||
if settings.annotate_image:
|
results.append(im.crop(tuple(crop)))
|
||||||
d = ImageDraw.Draw(im_debug)
|
|
||||||
rect = list(crop)
|
|
||||||
rect[2] -= 1
|
|
||||||
rect[3] -= 1
|
|
||||||
d.rectangle(rect, outline=GREEN)
|
|
||||||
results.append(im_debug)
|
|
||||||
if settings.destop_view_image:
|
|
||||||
im_debug.show()
|
|
||||||
|
|
||||||
return results
|
if settings.annotate_image:
|
||||||
|
d = ImageDraw.Draw(im_debug)
|
||||||
|
rect = list(crop)
|
||||||
|
rect[2] -= 1
|
||||||
|
rect[3] -= 1
|
||||||
|
d.rectangle(rect, outline=GREEN)
|
||||||
|
results.append(im_debug)
|
||||||
|
if settings.destop_view_image:
|
||||||
|
im_debug.show()
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
def focal_point(im, settings):
|
def focal_point(im, settings):
|
||||||
corner_points = image_corner_points(im, settings) if settings.corner_points_weight > 0 else []
|
corner_points = image_corner_points(im, settings) if settings.corner_points_weight > 0 else []
|
||||||
@ -88,7 +87,7 @@ def focal_point(im, settings):
|
|||||||
corner_centroid = None
|
corner_centroid = None
|
||||||
if len(corner_points) > 0:
|
if len(corner_points) > 0:
|
||||||
corner_centroid = centroid(corner_points)
|
corner_centroid = centroid(corner_points)
|
||||||
corner_centroid.weight = settings.corner_points_weight / weight_pref_total
|
corner_centroid.weight = settings.corner_points_weight / weight_pref_total
|
||||||
pois.append(corner_centroid)
|
pois.append(corner_centroid)
|
||||||
|
|
||||||
entropy_centroid = None
|
entropy_centroid = None
|
||||||
@ -100,7 +99,7 @@ def focal_point(im, settings):
|
|||||||
face_centroid = None
|
face_centroid = None
|
||||||
if len(face_points) > 0:
|
if len(face_points) > 0:
|
||||||
face_centroid = centroid(face_points)
|
face_centroid = centroid(face_points)
|
||||||
face_centroid.weight = settings.face_points_weight / weight_pref_total
|
face_centroid.weight = settings.face_points_weight / weight_pref_total
|
||||||
pois.append(face_centroid)
|
pois.append(face_centroid)
|
||||||
|
|
||||||
average_point = poi_average(pois, settings)
|
average_point = poi_average(pois, settings)
|
||||||
@ -111,7 +110,7 @@ def focal_point(im, settings):
|
|||||||
if corner_centroid is not None:
|
if corner_centroid is not None:
|
||||||
color = BLUE
|
color = BLUE
|
||||||
box = corner_centroid.bounding(max_size * corner_centroid.weight)
|
box = corner_centroid.bounding(max_size * corner_centroid.weight)
|
||||||
d.text((box[0], box[1]-15), "Edge: %.02f" % corner_centroid.weight, fill=color)
|
d.text((box[0], box[1]-15), f"Edge: {corner_centroid.weight:.02f}", fill=color)
|
||||||
d.ellipse(box, outline=color)
|
d.ellipse(box, outline=color)
|
||||||
if len(corner_points) > 1:
|
if len(corner_points) > 1:
|
||||||
for f in corner_points:
|
for f in corner_points:
|
||||||
@ -119,7 +118,7 @@ def focal_point(im, settings):
|
|||||||
if entropy_centroid is not None:
|
if entropy_centroid is not None:
|
||||||
color = "#ff0"
|
color = "#ff0"
|
||||||
box = entropy_centroid.bounding(max_size * entropy_centroid.weight)
|
box = entropy_centroid.bounding(max_size * entropy_centroid.weight)
|
||||||
d.text((box[0], box[1]-15), "Entropy: %.02f" % entropy_centroid.weight, fill=color)
|
d.text((box[0], box[1]-15), f"Entropy: {entropy_centroid.weight:.02f}", fill=color)
|
||||||
d.ellipse(box, outline=color)
|
d.ellipse(box, outline=color)
|
||||||
if len(entropy_points) > 1:
|
if len(entropy_points) > 1:
|
||||||
for f in entropy_points:
|
for f in entropy_points:
|
||||||
@ -127,14 +126,14 @@ def focal_point(im, settings):
|
|||||||
if face_centroid is not None:
|
if face_centroid is not None:
|
||||||
color = RED
|
color = RED
|
||||||
box = face_centroid.bounding(max_size * face_centroid.weight)
|
box = face_centroid.bounding(max_size * face_centroid.weight)
|
||||||
d.text((box[0], box[1]-15), "Face: %.02f" % face_centroid.weight, fill=color)
|
d.text((box[0], box[1]-15), f"Face: {face_centroid.weight:.02f}", fill=color)
|
||||||
d.ellipse(box, outline=color)
|
d.ellipse(box, outline=color)
|
||||||
if len(face_points) > 1:
|
if len(face_points) > 1:
|
||||||
for f in face_points:
|
for f in face_points:
|
||||||
d.rectangle(f.bounding(4), outline=color)
|
d.rectangle(f.bounding(4), outline=color)
|
||||||
|
|
||||||
d.ellipse(average_point.bounding(max_size), outline=GREEN)
|
d.ellipse(average_point.bounding(max_size), outline=GREEN)
|
||||||
|
|
||||||
return average_point
|
return average_point
|
||||||
|
|
||||||
|
|
||||||
@ -185,7 +184,7 @@ def image_face_points(im, settings):
|
|||||||
try:
|
try:
|
||||||
faces = classifier.detectMultiScale(gray, scaleFactor=1.1,
|
faces = classifier.detectMultiScale(gray, scaleFactor=1.1,
|
||||||
minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE)
|
minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE)
|
||||||
except:
|
except Exception:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if len(faces) > 0:
|
if len(faces) > 0:
|
||||||
@ -262,10 +261,11 @@ def image_entropy(im):
|
|||||||
hist = hist[hist > 0]
|
hist = hist[hist > 0]
|
||||||
return -np.log2(hist / hist.sum()).sum()
|
return -np.log2(hist / hist.sum()).sum()
|
||||||
|
|
||||||
|
|
||||||
def centroid(pois):
|
def centroid(pois):
|
||||||
x = [poi.x for poi in pois]
|
x = [poi.x for poi in pois]
|
||||||
y = [poi.y for poi in pois]
|
y = [poi.y for poi in pois]
|
||||||
return PointOfInterest(sum(x)/len(pois), sum(y)/len(pois))
|
return PointOfInterest(sum(x) / len(pois), sum(y) / len(pois))
|
||||||
|
|
||||||
|
|
||||||
def poi_average(pois, settings):
|
def poi_average(pois, settings):
|
||||||
@ -283,59 +283,59 @@ def poi_average(pois, settings):
|
|||||||
|
|
||||||
|
|
||||||
def is_landscape(w, h):
|
def is_landscape(w, h):
|
||||||
return w > h
|
return w > h
|
||||||
|
|
||||||
|
|
||||||
def is_portrait(w, h):
|
def is_portrait(w, h):
|
||||||
return h > w
|
return h > w
|
||||||
|
|
||||||
|
|
||||||
def is_square(w, h):
|
def is_square(w, h):
|
||||||
return w == h
|
return w == h
|
||||||
|
|
||||||
|
|
||||||
def download_and_cache_models(dirname):
|
def download_and_cache_models(dirname):
|
||||||
download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'
|
download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'
|
||||||
model_file_name = 'face_detection_yunet.onnx'
|
model_file_name = 'face_detection_yunet.onnx'
|
||||||
|
|
||||||
if not os.path.exists(dirname):
|
if not os.path.exists(dirname):
|
||||||
os.makedirs(dirname)
|
os.makedirs(dirname)
|
||||||
|
|
||||||
cache_file = os.path.join(dirname, model_file_name)
|
cache_file = os.path.join(dirname, model_file_name)
|
||||||
if not os.path.exists(cache_file):
|
if not os.path.exists(cache_file):
|
||||||
print(f"downloading face detection model from '{download_url}' to '{cache_file}'")
|
print(f"downloading face detection model from '{download_url}' to '{cache_file}'")
|
||||||
response = requests.get(download_url)
|
response = requests.get(download_url)
|
||||||
with open(cache_file, "wb") as f:
|
with open(cache_file, "wb") as f:
|
||||||
f.write(response.content)
|
f.write(response.content)
|
||||||
|
|
||||||
if os.path.exists(cache_file):
|
if os.path.exists(cache_file):
|
||||||
return cache_file
|
return cache_file
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
class PointOfInterest:
|
class PointOfInterest:
|
||||||
def __init__(self, x, y, weight=1.0, size=10):
|
def __init__(self, x, y, weight=1.0, size=10):
|
||||||
self.x = x
|
self.x = x
|
||||||
self.y = y
|
self.y = y
|
||||||
self.weight = weight
|
self.weight = weight
|
||||||
self.size = size
|
self.size = size
|
||||||
|
|
||||||
def bounding(self, size):
|
def bounding(self, size):
|
||||||
return [
|
return [
|
||||||
self.x - size//2,
|
self.x - size // 2,
|
||||||
self.y - size//2,
|
self.y - size // 2,
|
||||||
self.x + size//2,
|
self.x + size // 2,
|
||||||
self.y + size//2
|
self.y + size // 2
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class Settings:
|
class Settings:
|
||||||
def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False, dnn_model_path=None):
|
def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False, dnn_model_path=None):
|
||||||
self.crop_width = crop_width
|
self.crop_width = crop_width
|
||||||
self.crop_height = crop_height
|
self.crop_height = crop_height
|
||||||
self.corner_points_weight = corner_points_weight
|
self.corner_points_weight = corner_points_weight
|
||||||
self.entropy_points_weight = entropy_points_weight
|
self.entropy_points_weight = entropy_points_weight
|
||||||
self.face_points_weight = face_points_weight
|
self.face_points_weight = face_points_weight
|
||||||
self.annotate_image = annotate_image
|
self.annotate_image = annotate_image
|
||||||
self.destop_view_image = False
|
self.destop_view_image = False
|
||||||
self.dnn_model_path = dnn_model_path
|
self.dnn_model_path = dnn_model_path
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user