Merge branch 'dev' into startup-profile
This commit is contained in:
commit
b3390a9840
@ -50,13 +50,14 @@ module.exports = {
|
|||||||
globals: {
|
globals: {
|
||||||
//script.js
|
//script.js
|
||||||
gradioApp: "readonly",
|
gradioApp: "readonly",
|
||||||
|
executeCallbacks: "readonly",
|
||||||
|
onAfterUiUpdate: "readonly",
|
||||||
|
onOptionsChanged: "readonly",
|
||||||
onUiLoaded: "readonly",
|
onUiLoaded: "readonly",
|
||||||
onUiUpdate: "readonly",
|
onUiUpdate: "readonly",
|
||||||
onOptionsChanged: "readonly",
|
|
||||||
uiCurrentTab: "writable",
|
uiCurrentTab: "writable",
|
||||||
uiElementIsVisible: "readonly",
|
|
||||||
uiElementInSight: "readonly",
|
uiElementInSight: "readonly",
|
||||||
executeCallbacks: "readonly",
|
uiElementIsVisible: "readonly",
|
||||||
//ui.js
|
//ui.js
|
||||||
opts: "writable",
|
opts: "writable",
|
||||||
all_gallery_buttons: "readonly",
|
all_gallery_buttons: "readonly",
|
||||||
@ -84,5 +85,7 @@ module.exports = {
|
|||||||
// imageviewer.js
|
// imageviewer.js
|
||||||
modalPrevImage: "readonly",
|
modalPrevImage: "readonly",
|
||||||
modalNextImage: "readonly",
|
modalNextImage: "readonly",
|
||||||
|
// token-counters.js
|
||||||
|
setupTokenCounters: "readonly",
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
4
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
4
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
@ -43,8 +43,8 @@ body:
|
|||||||
- type: input
|
- type: input
|
||||||
id: commit
|
id: commit
|
||||||
attributes:
|
attributes:
|
||||||
label: Commit where the problem happens
|
label: Version or Commit where the problem happens
|
||||||
description: Which commit are you running ? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.)
|
description: "Which webui version or commit are you running ? (Do not write *Latest Version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Version: v1.2.3** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.)"
|
||||||
validations:
|
validations:
|
||||||
required: true
|
required: true
|
||||||
- type: dropdown
|
- type: dropdown
|
||||||
|
12
CHANGELOG.md
12
CHANGELOG.md
@ -1,4 +1,4 @@
|
|||||||
## Upcoming 1.3.0
|
## 1.3.0
|
||||||
|
|
||||||
### Features:
|
### Features:
|
||||||
* add UI to edit defaults
|
* add UI to edit defaults
|
||||||
@ -8,6 +8,10 @@
|
|||||||
* update extensions table: show branch, show date in separate column, and show version from tags if available
|
* update extensions table: show branch, show date in separate column, and show version from tags if available
|
||||||
* TAESD - another option for cheap live previews
|
* TAESD - another option for cheap live previews
|
||||||
* allow choosing sampler and prompts for second pass of hires fix - hidden by default, enabled in settings
|
* allow choosing sampler and prompts for second pass of hires fix - hidden by default, enabled in settings
|
||||||
|
* calculate hashes for Lora
|
||||||
|
* add lora hashes to infotext
|
||||||
|
* when pasting infotext, use infotext's lora hashes to find local loras for `<lora:xxx:1>` entries whose hashes match loras the user has
|
||||||
|
* select cross attention optimization from UI
|
||||||
|
|
||||||
### Minor:
|
### Minor:
|
||||||
* bump Gradio to 3.31.0
|
* bump Gradio to 3.31.0
|
||||||
@ -26,6 +30,8 @@
|
|||||||
* switch from pyngrok to ngrok-py
|
* switch from pyngrok to ngrok-py
|
||||||
* lazy-load images in extra networks UI
|
* lazy-load images in extra networks UI
|
||||||
* set "Navigate image viewer with gamepad" option to false by default, by request
|
* set "Navigate image viewer with gamepad" option to false by default, by request
|
||||||
|
* change upscalers to download models into user-specified directory (from commandline args) rather than the default models/<...>
|
||||||
|
* allow hiding buttons in ui-config.json
|
||||||
|
|
||||||
### Extensions:
|
### Extensions:
|
||||||
* add /sdapi/v1/script-info api
|
* add /sdapi/v1/script-info api
|
||||||
@ -35,6 +41,8 @@
|
|||||||
* add command and endpoint for graceful server stopping
|
* add command and endpoint for graceful server stopping
|
||||||
* add some locals (prompts/seeds/etc) from processing function into the Processing class as fields
|
* add some locals (prompts/seeds/etc) from processing function into the Processing class as fields
|
||||||
* rework quoting for infotext items that have commas in them to use JSON (should be backwards compatible except for cases where it didn't work previously)
|
* rework quoting for infotext items that have commas in them to use JSON (should be backwards compatible except for cases where it didn't work previously)
|
||||||
|
* add /sdapi/v1/refresh-loras api checkpoint post request
|
||||||
|
* tests overhaul
|
||||||
|
|
||||||
### Bug Fixes:
|
### Bug Fixes:
|
||||||
* fix an issue preventing the program from starting if the user specifies a bad Gradio theme
|
* fix an issue preventing the program from starting if the user specifies a bad Gradio theme
|
||||||
@ -46,6 +54,8 @@
|
|||||||
* fix inability to merge checkpoint without adding metadata
|
* fix inability to merge checkpoint without adding metadata
|
||||||
* fix extra networks' save preview image not adding infotext for jpeg/webm
|
* fix extra networks' save preview image not adding infotext for jpeg/webm
|
||||||
* remove blinking effect from text in hires fix and scale resolution preview
|
* remove blinking effect from text in hires fix and scale resolution preview
|
||||||
|
* make links to `http://<...>.git` extensions work in the extension tab
|
||||||
|
* fix bug with webui hanging at startup due to hanging git process
|
||||||
|
|
||||||
|
|
||||||
## 1.2.1
|
## 1.2.1
|
||||||
|
@ -15,7 +15,7 @@ A browser interface based on Gradio library for Stable Diffusion.
|
|||||||
- Attention, specify parts of text that the model should pay more attention to
|
- Attention, specify parts of text that the model should pay more attention to
|
||||||
- a man in a `((tuxedo))` - will pay more attention to tuxedo
|
- a man in a `((tuxedo))` - will pay more attention to tuxedo
|
||||||
- a man in a `(tuxedo:1.21)` - alternative syntax
|
- a man in a `(tuxedo:1.21)` - alternative syntax
|
||||||
- select text and press `Ctrl+Up` or `Ctrl+Down` to automatically adjust attention to selected text (code contributed by anonymous user)
|
- select text and press `Ctrl+Up` or `Ctrl+Down` (or `Command+Up` or `Command+Down` if you're on a MacOS) to automatically adjust attention to selected text (code contributed by anonymous user)
|
||||||
- Loopback, run img2img processing multiple times
|
- Loopback, run img2img processing multiple times
|
||||||
- X/Y/Z plot, a way to draw a 3 dimensional plot of images with different parameters
|
- X/Y/Z plot, a way to draw a 3 dimensional plot of images with different parameters
|
||||||
- Textual Inversion
|
- Textual Inversion
|
||||||
|
@ -1,12 +1,10 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
from basicsr.utils.download_util import load_file_from_url
|
from basicsr.utils.download_util import load_file_from_url
|
||||||
|
|
||||||
from modules.upscaler import Upscaler, UpscalerData
|
from modules.upscaler import Upscaler, UpscalerData
|
||||||
from ldsr_model_arch import LDSR
|
from ldsr_model_arch import LDSR
|
||||||
from modules import shared, script_callbacks
|
from modules import shared, script_callbacks, errors
|
||||||
import sd_hijack_autoencoder # noqa: F401
|
import sd_hijack_autoencoder # noqa: F401
|
||||||
import sd_hijack_ddpm_v1 # noqa: F401
|
import sd_hijack_ddpm_v1 # noqa: F401
|
||||||
|
|
||||||
@ -51,10 +49,8 @@ class UpscalerLDSR(Upscaler):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
return LDSR(model, yaml)
|
return LDSR(model, yaml)
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
print("Error importing LDSR:", file=sys.stderr)
|
errors.report("Error importing LDSR", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def do_upscale(self, img, path):
|
def do_upscale(self, img, path):
|
||||||
|
@ -10,7 +10,7 @@ from contextlib import contextmanager
|
|||||||
from torch.optim.lr_scheduler import LambdaLR
|
from torch.optim.lr_scheduler import LambdaLR
|
||||||
|
|
||||||
from ldm.modules.ema import LitEma
|
from ldm.modules.ema import LitEma
|
||||||
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
|
from vqvae_quantize import VectorQuantizer2 as VectorQuantizer
|
||||||
from ldm.modules.diffusionmodules.model import Encoder, Decoder
|
from ldm.modules.diffusionmodules.model import Encoder, Decoder
|
||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
|
|
||||||
|
147
extensions-builtin/LDSR/vqvae_quantize.py
Normal file
147
extensions-builtin/LDSR/vqvae_quantize.py
Normal file
@ -0,0 +1,147 @@
|
|||||||
|
# Vendored from https://raw.githubusercontent.com/CompVis/taming-transformers/24268930bf1dce879235a7fddd0b2355b84d7ea6/taming/modules/vqvae/quantize.py,
|
||||||
|
# where the license is as follows:
|
||||||
|
#
|
||||||
|
# Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
# of this software and associated documentation files (the "Software"), to deal
|
||||||
|
# in the Software without restriction, including without limitation the rights
|
||||||
|
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
# copies of the Software, and to permit persons to whom the Software is
|
||||||
|
# furnished to do so, subject to the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be included in all
|
||||||
|
# copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||||
|
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||||
|
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||||
|
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
|
||||||
|
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
|
||||||
|
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
|
||||||
|
# OR OTHER DEALINGS IN THE SOFTWARE./
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import numpy as np
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
|
||||||
|
class VectorQuantizer2(nn.Module):
|
||||||
|
"""
|
||||||
|
Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
|
||||||
|
avoids costly matrix multiplications and allows for post-hoc remapping of indices.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# NOTE: due to a bug the beta term was applied to the wrong term. for
|
||||||
|
# backwards compatibility we use the buggy version by default, but you can
|
||||||
|
# specify legacy=False to fix it.
|
||||||
|
def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random",
|
||||||
|
sane_index_shape=False, legacy=True):
|
||||||
|
super().__init__()
|
||||||
|
self.n_e = n_e
|
||||||
|
self.e_dim = e_dim
|
||||||
|
self.beta = beta
|
||||||
|
self.legacy = legacy
|
||||||
|
|
||||||
|
self.embedding = nn.Embedding(self.n_e, self.e_dim)
|
||||||
|
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
|
||||||
|
|
||||||
|
self.remap = remap
|
||||||
|
if self.remap is not None:
|
||||||
|
self.register_buffer("used", torch.tensor(np.load(self.remap)))
|
||||||
|
self.re_embed = self.used.shape[0]
|
||||||
|
self.unknown_index = unknown_index # "random" or "extra" or integer
|
||||||
|
if self.unknown_index == "extra":
|
||||||
|
self.unknown_index = self.re_embed
|
||||||
|
self.re_embed = self.re_embed + 1
|
||||||
|
print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
|
||||||
|
f"Using {self.unknown_index} for unknown indices.")
|
||||||
|
else:
|
||||||
|
self.re_embed = n_e
|
||||||
|
|
||||||
|
self.sane_index_shape = sane_index_shape
|
||||||
|
|
||||||
|
def remap_to_used(self, inds):
|
||||||
|
ishape = inds.shape
|
||||||
|
assert len(ishape) > 1
|
||||||
|
inds = inds.reshape(ishape[0], -1)
|
||||||
|
used = self.used.to(inds)
|
||||||
|
match = (inds[:, :, None] == used[None, None, ...]).long()
|
||||||
|
new = match.argmax(-1)
|
||||||
|
unknown = match.sum(2) < 1
|
||||||
|
if self.unknown_index == "random":
|
||||||
|
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
|
||||||
|
else:
|
||||||
|
new[unknown] = self.unknown_index
|
||||||
|
return new.reshape(ishape)
|
||||||
|
|
||||||
|
def unmap_to_all(self, inds):
|
||||||
|
ishape = inds.shape
|
||||||
|
assert len(ishape) > 1
|
||||||
|
inds = inds.reshape(ishape[0], -1)
|
||||||
|
used = self.used.to(inds)
|
||||||
|
if self.re_embed > self.used.shape[0]: # extra token
|
||||||
|
inds[inds >= self.used.shape[0]] = 0 # simply set to zero
|
||||||
|
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
|
||||||
|
return back.reshape(ishape)
|
||||||
|
|
||||||
|
def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
|
||||||
|
assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
|
||||||
|
assert rescale_logits is False, "Only for interface compatible with Gumbel"
|
||||||
|
assert return_logits is False, "Only for interface compatible with Gumbel"
|
||||||
|
# reshape z -> (batch, height, width, channel) and flatten
|
||||||
|
z = rearrange(z, 'b c h w -> b h w c').contiguous()
|
||||||
|
z_flattened = z.view(-1, self.e_dim)
|
||||||
|
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
||||||
|
|
||||||
|
d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
|
||||||
|
torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \
|
||||||
|
torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
|
||||||
|
|
||||||
|
min_encoding_indices = torch.argmin(d, dim=1)
|
||||||
|
z_q = self.embedding(min_encoding_indices).view(z.shape)
|
||||||
|
perplexity = None
|
||||||
|
min_encodings = None
|
||||||
|
|
||||||
|
# compute loss for embedding
|
||||||
|
if not self.legacy:
|
||||||
|
loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + \
|
||||||
|
torch.mean((z_q - z.detach()) ** 2)
|
||||||
|
else:
|
||||||
|
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * \
|
||||||
|
torch.mean((z_q - z.detach()) ** 2)
|
||||||
|
|
||||||
|
# preserve gradients
|
||||||
|
z_q = z + (z_q - z).detach()
|
||||||
|
|
||||||
|
# reshape back to match original input shape
|
||||||
|
z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
|
||||||
|
|
||||||
|
if self.remap is not None:
|
||||||
|
min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
|
||||||
|
min_encoding_indices = self.remap_to_used(min_encoding_indices)
|
||||||
|
min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
|
||||||
|
|
||||||
|
if self.sane_index_shape:
|
||||||
|
min_encoding_indices = min_encoding_indices.reshape(
|
||||||
|
z_q.shape[0], z_q.shape[2], z_q.shape[3])
|
||||||
|
|
||||||
|
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
|
||||||
|
|
||||||
|
def get_codebook_entry(self, indices, shape):
|
||||||
|
# shape specifying (batch, height, width, channel)
|
||||||
|
if self.remap is not None:
|
||||||
|
indices = indices.reshape(shape[0], -1) # add batch axis
|
||||||
|
indices = self.unmap_to_all(indices)
|
||||||
|
indices = indices.reshape(-1) # flatten again
|
||||||
|
|
||||||
|
# get quantized latent vectors
|
||||||
|
z_q = self.embedding(indices)
|
||||||
|
|
||||||
|
if shape is not None:
|
||||||
|
z_q = z_q.view(shape)
|
||||||
|
# reshape back to match original input shape
|
||||||
|
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
||||||
|
|
||||||
|
return z_q
|
@ -1,6 +1,5 @@
|
|||||||
import os.path
|
import os.path
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
|
||||||
|
|
||||||
import PIL.Image
|
import PIL.Image
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -10,8 +9,9 @@ from tqdm import tqdm
|
|||||||
from basicsr.utils.download_util import load_file_from_url
|
from basicsr.utils.download_util import load_file_from_url
|
||||||
|
|
||||||
import modules.upscaler
|
import modules.upscaler
|
||||||
from modules import devices, modelloader, script_callbacks
|
from modules import devices, modelloader, script_callbacks, errors
|
||||||
from scunet_model_arch import SCUNet as net
|
from scunet_model_arch import SCUNet as net
|
||||||
|
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
|
|
||||||
|
|
||||||
@ -38,8 +38,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
|
|||||||
scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
|
scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
|
||||||
scalers.append(scaler_data)
|
scalers.append(scaler_data)
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error loading ScuNET model: {file}", file=sys.stderr)
|
errors.report(f"Error loading ScuNET model: {file}", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
if add_model2:
|
if add_model2:
|
||||||
scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self)
|
scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self)
|
||||||
scalers.append(scaler_data2)
|
scalers.append(scaler_data2)
|
||||||
|
431
extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js
Normal file
431
extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js
Normal file
@ -0,0 +1,431 @@
|
|||||||
|
// Main
|
||||||
|
|
||||||
|
// Helper functions
|
||||||
|
// Get active tab
|
||||||
|
function getActiveTab(elements, all = false) {
|
||||||
|
const tabs = elements.img2imgTabs.querySelectorAll("button");
|
||||||
|
|
||||||
|
if (all) return tabs;
|
||||||
|
|
||||||
|
for (let tab of tabs) {
|
||||||
|
if (tab.classList.contains("selected")) {
|
||||||
|
return tab;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
onUiLoaded(async() => {
|
||||||
|
const hotkeysConfig = {
|
||||||
|
resetZoom: "KeyR",
|
||||||
|
fitToScreen: "KeyS",
|
||||||
|
moveKey: "KeyF",
|
||||||
|
overlap: "KeyO"
|
||||||
|
};
|
||||||
|
|
||||||
|
let isMoving = false;
|
||||||
|
let mouseX, mouseY;
|
||||||
|
|
||||||
|
const elementIDs = {
|
||||||
|
sketch: "#img2img_sketch",
|
||||||
|
inpaint: "#img2maskimg",
|
||||||
|
inpaintSketch: "#inpaint_sketch",
|
||||||
|
img2imgTabs: "#mode_img2img .tab-nav"
|
||||||
|
};
|
||||||
|
|
||||||
|
async function getElements() {
|
||||||
|
const elements = await Promise.all(
|
||||||
|
Object.values(elementIDs).map(id => document.querySelector(id))
|
||||||
|
);
|
||||||
|
return Object.fromEntries(
|
||||||
|
Object.keys(elementIDs).map((key, index) => [key, elements[index]])
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
const elements = await getElements();
|
||||||
|
|
||||||
|
function applyZoomAndPan(targetElement, elemId) {
|
||||||
|
targetElement.style.transformOrigin = "0 0";
|
||||||
|
let [zoomLevel, panX, panY] = [1, 0, 0];
|
||||||
|
let fullScreenMode = false;
|
||||||
|
|
||||||
|
// In the course of research, it was found that the tag img is very harmful when zooming and creates white canvases. This hack allows you to almost never think about this problem, it has no effect on webui.
|
||||||
|
function fixCanvas() {
|
||||||
|
const activeTab = getActiveTab(elements).textContent.trim();
|
||||||
|
|
||||||
|
if (activeTab !== "img2img") {
|
||||||
|
const img = targetElement.querySelector(`${elemId} img`);
|
||||||
|
|
||||||
|
if (img && img.style.display !== "none") {
|
||||||
|
img.style.display = "none";
|
||||||
|
img.style.visibility = "hidden";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset the zoom level and pan position of the target element to their initial values
|
||||||
|
function resetZoom() {
|
||||||
|
zoomLevel = 1;
|
||||||
|
panX = 0;
|
||||||
|
panY = 0;
|
||||||
|
|
||||||
|
fixCanvas();
|
||||||
|
targetElement.style.transform = `scale(${zoomLevel}) translate(${panX}px, ${panY}px)`;
|
||||||
|
|
||||||
|
const canvas = gradioApp().querySelector(
|
||||||
|
`${elemId} canvas[key="interface"]`
|
||||||
|
);
|
||||||
|
|
||||||
|
toggleOverlap("off");
|
||||||
|
fullScreenMode = false;
|
||||||
|
|
||||||
|
if (
|
||||||
|
canvas &&
|
||||||
|
parseFloat(canvas.style.width) > 865 &&
|
||||||
|
parseFloat(targetElement.style.width) > 865
|
||||||
|
) {
|
||||||
|
fitToElement();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
targetElement.style.width = "";
|
||||||
|
if (canvas) {
|
||||||
|
targetElement.style.height = canvas.style.height;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Toggle the zIndex of the target element between two values, allowing it to overlap or be overlapped by other elements
|
||||||
|
function toggleOverlap(forced = "") {
|
||||||
|
const zIndex1 = "0";
|
||||||
|
const zIndex2 = "998";
|
||||||
|
|
||||||
|
targetElement.style.zIndex =
|
||||||
|
targetElement.style.zIndex !== zIndex2 ? zIndex2 : zIndex1;
|
||||||
|
|
||||||
|
if (forced === "off") {
|
||||||
|
targetElement.style.zIndex = zIndex1;
|
||||||
|
} else if (forced === "on") {
|
||||||
|
targetElement.style.zIndex = zIndex2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Adjust the brush size based on the deltaY value from a mouse wheel event
|
||||||
|
function adjustBrushSize(
|
||||||
|
elemId,
|
||||||
|
deltaY,
|
||||||
|
withoutValue = false,
|
||||||
|
percentage = 5
|
||||||
|
) {
|
||||||
|
const input =
|
||||||
|
gradioApp().querySelector(
|
||||||
|
`${elemId} input[aria-label='Brush radius']`
|
||||||
|
) ||
|
||||||
|
gradioApp().querySelector(
|
||||||
|
`${elemId} button[aria-label="Use brush"]`
|
||||||
|
);
|
||||||
|
|
||||||
|
if (input) {
|
||||||
|
input.click();
|
||||||
|
if (!withoutValue) {
|
||||||
|
const maxValue =
|
||||||
|
parseFloat(input.getAttribute("max")) || 100;
|
||||||
|
const changeAmount = maxValue * (percentage / 100);
|
||||||
|
const newValue =
|
||||||
|
parseFloat(input.value) +
|
||||||
|
(deltaY > 0 ? -changeAmount : changeAmount);
|
||||||
|
input.value = Math.min(Math.max(newValue, 0), maxValue);
|
||||||
|
input.dispatchEvent(new Event("change"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset zoom when uploading a new image
|
||||||
|
const fileInput = gradioApp().querySelector(
|
||||||
|
`${elemId} input[type="file"][accept="image/*"].svelte-116rqfv`
|
||||||
|
);
|
||||||
|
fileInput.addEventListener("click", resetZoom);
|
||||||
|
|
||||||
|
// Update the zoom level and pan position of the target element based on the values of the zoomLevel, panX and panY variables
|
||||||
|
function updateZoom(newZoomLevel, mouseX, mouseY) {
|
||||||
|
newZoomLevel = Math.max(0.5, Math.min(newZoomLevel, 15));
|
||||||
|
panX += mouseX - (mouseX * newZoomLevel) / zoomLevel;
|
||||||
|
panY += mouseY - (mouseY * newZoomLevel) / zoomLevel;
|
||||||
|
|
||||||
|
targetElement.style.transformOrigin = "0 0";
|
||||||
|
targetElement.style.transform = `translate(${panX}px, ${panY}px) scale(${newZoomLevel})`;
|
||||||
|
|
||||||
|
toggleOverlap("on");
|
||||||
|
return newZoomLevel;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Change the zoom level based on user interaction
|
||||||
|
function changeZoomLevel(operation, e) {
|
||||||
|
if (e.shiftKey) {
|
||||||
|
e.preventDefault();
|
||||||
|
|
||||||
|
let zoomPosX, zoomPosY;
|
||||||
|
let delta = 0.2;
|
||||||
|
if (zoomLevel > 7) {
|
||||||
|
delta = 0.9;
|
||||||
|
} else if (zoomLevel > 2) {
|
||||||
|
delta = 0.6;
|
||||||
|
}
|
||||||
|
|
||||||
|
zoomPosX = e.clientX;
|
||||||
|
zoomPosY = e.clientY;
|
||||||
|
|
||||||
|
fullScreenMode = false;
|
||||||
|
zoomLevel = updateZoom(
|
||||||
|
zoomLevel + (operation === "+" ? delta : -delta),
|
||||||
|
zoomPosX - targetElement.getBoundingClientRect().left,
|
||||||
|
zoomPosY - targetElement.getBoundingClientRect().top
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This function fits the target element to the screen by calculating
|
||||||
|
* the required scale and offsets. It also updates the global variables
|
||||||
|
* zoomLevel, panX, and panY to reflect the new state.
|
||||||
|
*/
|
||||||
|
|
||||||
|
function fitToElement() {
|
||||||
|
//Reset Zoom
|
||||||
|
targetElement.style.transform = `translate(${0}px, ${0}px) scale(${1})`;
|
||||||
|
|
||||||
|
// Get element and screen dimensions
|
||||||
|
const elementWidth = targetElement.offsetWidth;
|
||||||
|
const elementHeight = targetElement.offsetHeight;
|
||||||
|
const parentElement = targetElement.parentElement;
|
||||||
|
const screenWidth = parentElement.clientWidth;
|
||||||
|
const screenHeight = parentElement.clientHeight;
|
||||||
|
|
||||||
|
// Get element's coordinates relative to the parent element
|
||||||
|
const elementRect = targetElement.getBoundingClientRect();
|
||||||
|
const parentRect = parentElement.getBoundingClientRect();
|
||||||
|
const elementX = elementRect.x - parentRect.x;
|
||||||
|
|
||||||
|
// Calculate scale and offsets
|
||||||
|
const scaleX = screenWidth / elementWidth;
|
||||||
|
const scaleY = screenHeight / elementHeight;
|
||||||
|
const scale = Math.min(scaleX, scaleY);
|
||||||
|
|
||||||
|
const transformOrigin =
|
||||||
|
window.getComputedStyle(targetElement).transformOrigin;
|
||||||
|
const [originX, originY] = transformOrigin.split(" ");
|
||||||
|
const originXValue = parseFloat(originX);
|
||||||
|
const originYValue = parseFloat(originY);
|
||||||
|
|
||||||
|
const offsetX =
|
||||||
|
(screenWidth - elementWidth * scale) / 2 -
|
||||||
|
originXValue * (1 - scale);
|
||||||
|
const offsetY =
|
||||||
|
(screenHeight - elementHeight * scale) / 2.5 -
|
||||||
|
originYValue * (1 - scale);
|
||||||
|
|
||||||
|
// Apply scale and offsets to the element
|
||||||
|
targetElement.style.transform = `translate(${offsetX}px, ${offsetY}px) scale(${scale})`;
|
||||||
|
|
||||||
|
// Update global variables
|
||||||
|
zoomLevel = scale;
|
||||||
|
panX = offsetX;
|
||||||
|
panY = offsetY;
|
||||||
|
|
||||||
|
fullScreenMode = false;
|
||||||
|
toggleOverlap("off");
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This function fits the target element to the screen by calculating
|
||||||
|
* the required scale and offsets. It also updates the global variables
|
||||||
|
* zoomLevel, panX, and panY to reflect the new state.
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Fullscreen mode
|
||||||
|
function fitToScreen() {
|
||||||
|
const canvas = gradioApp().querySelector(
|
||||||
|
`${elemId} canvas[key="interface"]`
|
||||||
|
);
|
||||||
|
|
||||||
|
if (!canvas) return;
|
||||||
|
|
||||||
|
if (canvas.offsetWidth > 862) {
|
||||||
|
targetElement.style.width = canvas.offsetWidth + "px";
|
||||||
|
}
|
||||||
|
|
||||||
|
if (fullScreenMode) {
|
||||||
|
resetZoom();
|
||||||
|
fullScreenMode = false;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
//Reset Zoom
|
||||||
|
targetElement.style.transform = `translate(${0}px, ${0}px) scale(${1})`;
|
||||||
|
|
||||||
|
// Get scrollbar width to right-align the image
|
||||||
|
const scrollbarWidth = window.innerWidth - document.documentElement.clientWidth;
|
||||||
|
|
||||||
|
// Get element and screen dimensions
|
||||||
|
const elementWidth = targetElement.offsetWidth;
|
||||||
|
const elementHeight = targetElement.offsetHeight;
|
||||||
|
const screenWidth = window.innerWidth - scrollbarWidth;
|
||||||
|
const screenHeight = window.innerHeight;
|
||||||
|
|
||||||
|
// Get element's coordinates relative to the page
|
||||||
|
const elementRect = targetElement.getBoundingClientRect();
|
||||||
|
const elementY = elementRect.y;
|
||||||
|
const elementX = elementRect.x;
|
||||||
|
|
||||||
|
// Calculate scale and offsets
|
||||||
|
const scaleX = screenWidth / elementWidth;
|
||||||
|
const scaleY = screenHeight / elementHeight;
|
||||||
|
const scale = Math.min(scaleX, scaleY);
|
||||||
|
|
||||||
|
// Get the current transformOrigin
|
||||||
|
const computedStyle = window.getComputedStyle(targetElement);
|
||||||
|
const transformOrigin = computedStyle.transformOrigin;
|
||||||
|
const [originX, originY] = transformOrigin.split(" ");
|
||||||
|
const originXValue = parseFloat(originX);
|
||||||
|
const originYValue = parseFloat(originY);
|
||||||
|
|
||||||
|
// Calculate offsets with respect to the transformOrigin
|
||||||
|
const offsetX =
|
||||||
|
(screenWidth - elementWidth * scale) / 2 -
|
||||||
|
elementX -
|
||||||
|
originXValue * (1 - scale);
|
||||||
|
const offsetY =
|
||||||
|
(screenHeight - elementHeight * scale) / 2 -
|
||||||
|
elementY -
|
||||||
|
originYValue * (1 - scale);
|
||||||
|
|
||||||
|
// Apply scale and offsets to the element
|
||||||
|
targetElement.style.transform = `translate(${offsetX}px, ${offsetY}px) scale(${scale})`;
|
||||||
|
|
||||||
|
// Update global variables
|
||||||
|
zoomLevel = scale;
|
||||||
|
panX = offsetX;
|
||||||
|
panY = offsetY;
|
||||||
|
|
||||||
|
fullScreenMode = true;
|
||||||
|
toggleOverlap("on");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle keydown events
|
||||||
|
function handleKeyDown(event) {
|
||||||
|
const hotkeyActions = {
|
||||||
|
[hotkeysConfig.resetZoom]: resetZoom,
|
||||||
|
[hotkeysConfig.overlap]: toggleOverlap,
|
||||||
|
[hotkeysConfig.fitToScreen]: fitToScreen
|
||||||
|
// [hotkeysConfig.moveKey] : moveCanvas,
|
||||||
|
};
|
||||||
|
|
||||||
|
const action = hotkeyActions[event.code];
|
||||||
|
if (action) {
|
||||||
|
event.preventDefault();
|
||||||
|
action(event);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get Mouse position
|
||||||
|
function getMousePosition(e) {
|
||||||
|
mouseX = e.offsetX;
|
||||||
|
mouseY = e.offsetY;
|
||||||
|
}
|
||||||
|
|
||||||
|
targetElement.addEventListener("mousemove", getMousePosition);
|
||||||
|
|
||||||
|
// Handle events only inside the targetElement
|
||||||
|
let isKeyDownHandlerAttached = false;
|
||||||
|
|
||||||
|
function handleMouseMove() {
|
||||||
|
if (!isKeyDownHandlerAttached) {
|
||||||
|
document.addEventListener("keydown", handleKeyDown);
|
||||||
|
isKeyDownHandlerAttached = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function handleMouseLeave() {
|
||||||
|
if (isKeyDownHandlerAttached) {
|
||||||
|
document.removeEventListener("keydown", handleKeyDown);
|
||||||
|
isKeyDownHandlerAttached = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add mouse event handlers
|
||||||
|
targetElement.addEventListener("mousemove", handleMouseMove);
|
||||||
|
targetElement.addEventListener("mouseleave", handleMouseLeave);
|
||||||
|
|
||||||
|
// Reset zoom when click on another tab
|
||||||
|
elements.img2imgTabs.addEventListener("click", resetZoom);
|
||||||
|
elements.img2imgTabs.addEventListener("click", () => {
|
||||||
|
// targetElement.style.width = "";
|
||||||
|
if (parseInt(targetElement.style.width) > 865) {
|
||||||
|
setTimeout(fitToElement, 0);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
targetElement.addEventListener("wheel", e => {
|
||||||
|
// change zoom level
|
||||||
|
const operation = e.deltaY > 0 ? "-" : "+";
|
||||||
|
changeZoomLevel(operation, e);
|
||||||
|
|
||||||
|
// Handle brush size adjustment with ctrl key pressed
|
||||||
|
if (e.ctrlKey || e.metaKey) {
|
||||||
|
e.preventDefault();
|
||||||
|
|
||||||
|
// Increase or decrease brush size based on scroll direction
|
||||||
|
adjustBrushSize(elemId, e.deltaY);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handle the move event for pan functionality. Updates the panX and panY variables and applies the new transform to the target element.
|
||||||
|
* @param {MouseEvent} e - The mouse event.
|
||||||
|
*/
|
||||||
|
function handleMoveKeyDown(e) {
|
||||||
|
if (e.code === hotkeysConfig.moveKey) {
|
||||||
|
if (!e.ctrlKey && !e.metaKey) {
|
||||||
|
isMoving = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function handleMoveKeyUp(e) {
|
||||||
|
if (e.code === hotkeysConfig.moveKey) {
|
||||||
|
isMoving = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
document.addEventListener("keydown", handleMoveKeyDown);
|
||||||
|
document.addEventListener("keyup", handleMoveKeyUp);
|
||||||
|
|
||||||
|
// Detect zoom level and update the pan speed.
|
||||||
|
function updatePanPosition(movementX, movementY) {
|
||||||
|
let panSpeed = 1.5;
|
||||||
|
|
||||||
|
if (zoomLevel > 8) {
|
||||||
|
panSpeed = 2.5;
|
||||||
|
}
|
||||||
|
|
||||||
|
panX = panX + movementX * panSpeed;
|
||||||
|
panY = panY + movementY * panSpeed;
|
||||||
|
|
||||||
|
targetElement.style.transform = `translate(${panX}px, ${panY}px) scale(${zoomLevel})`;
|
||||||
|
toggleOverlap("on");
|
||||||
|
}
|
||||||
|
|
||||||
|
function handleMoveByKey(e) {
|
||||||
|
if (isMoving) {
|
||||||
|
updatePanPosition(e.movementX, e.movementY);
|
||||||
|
targetElement.style.pointerEvents = "none";
|
||||||
|
} else {
|
||||||
|
targetElement.style.pointerEvents = "auto";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
gradioApp().addEventListener("mousemove", handleMoveByKey);
|
||||||
|
}
|
||||||
|
|
||||||
|
applyZoomAndPan(elements.sketch, elementIDs.sketch);
|
||||||
|
applyZoomAndPan(elements.inpaint, elementIDs.inpaint);
|
||||||
|
applyZoomAndPan(elements.inpaintSketch, elementIDs.inpaintSketch);
|
||||||
|
});
|
@ -0,0 +1,48 @@
|
|||||||
|
import gradio as gr
|
||||||
|
from modules import scripts, shared, ui_components, ui_settings
|
||||||
|
from modules.ui_components import FormColumn
|
||||||
|
|
||||||
|
|
||||||
|
class ExtraOptionsSection(scripts.Script):
|
||||||
|
section = "extra_options"
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.comps = None
|
||||||
|
self.setting_names = None
|
||||||
|
|
||||||
|
def title(self):
|
||||||
|
return "Extra options"
|
||||||
|
|
||||||
|
def show(self, is_img2img):
|
||||||
|
return scripts.AlwaysVisible
|
||||||
|
|
||||||
|
def ui(self, is_img2img):
|
||||||
|
self.comps = []
|
||||||
|
self.setting_names = []
|
||||||
|
|
||||||
|
with gr.Blocks() as interface:
|
||||||
|
with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and len(shared.opts.extra_options) > 0 else gr.Group(), gr.Row():
|
||||||
|
for setting_name in shared.opts.extra_options:
|
||||||
|
with FormColumn():
|
||||||
|
comp = ui_settings.create_setting_component(setting_name)
|
||||||
|
|
||||||
|
self.comps.append(comp)
|
||||||
|
self.setting_names.append(setting_name)
|
||||||
|
|
||||||
|
def get_settings_values():
|
||||||
|
return [ui_settings.get_value_for_setting(key) for key in self.setting_names]
|
||||||
|
|
||||||
|
interface.load(fn=get_settings_values, inputs=[], outputs=self.comps, queue=False, show_progress=False)
|
||||||
|
|
||||||
|
return self.comps
|
||||||
|
|
||||||
|
def before_process(self, p, *args):
|
||||||
|
for name, value in zip(self.setting_names, args):
|
||||||
|
if name not in p.override_settings:
|
||||||
|
p.override_settings[name] = value
|
||||||
|
|
||||||
|
|
||||||
|
shared.options_templates.update(shared.options_section(('ui', "User interface"), {
|
||||||
|
"extra_options": shared.OptionInfo([], "Options in main UI", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img/img2img interfaces").needs_restart(),
|
||||||
|
"extra_options_accordion": shared.OptionInfo(False, "Place options in main UI into an accordion")
|
||||||
|
}))
|
@ -81,7 +81,7 @@ function dimensionChange(e, is_width, is_height) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
onUiUpdate(function() {
|
onAfterUiUpdate(function() {
|
||||||
var arPreviewRect = gradioApp().querySelector('#imageARPreview');
|
var arPreviewRect = gradioApp().querySelector('#imageARPreview');
|
||||||
if (arPreviewRect) {
|
if (arPreviewRect) {
|
||||||
arPreviewRect.style.display = 'none';
|
arPreviewRect.style.display = 'none';
|
||||||
|
@ -167,6 +167,4 @@ var addContextMenuEventListener = initResponse[2];
|
|||||||
})();
|
})();
|
||||||
//End example Context Menu Items
|
//End example Context Menu Items
|
||||||
|
|
||||||
onUiUpdate(function() {
|
onAfterUiUpdate(addContextMenuEventListener);
|
||||||
addContextMenuEventListener();
|
|
||||||
});
|
|
||||||
|
55
javascript/dragdrop.js
vendored
55
javascript/dragdrop.js
vendored
@ -48,12 +48,27 @@ function dropReplaceImage(imgWrap, files) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function eventHasFiles(e) {
|
||||||
|
if (!e.dataTransfer || !e.dataTransfer.files) return false;
|
||||||
|
if (e.dataTransfer.files.length > 0) return true;
|
||||||
|
if (e.dataTransfer.items.length > 0 && e.dataTransfer.items[0].kind == "file") return true;
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
function dragDropTargetIsPrompt(target) {
|
||||||
|
if (target?.placeholder && target?.placeholder.indexOf("Prompt") >= 0) return true;
|
||||||
|
if (target?.parentNode?.parentNode?.className?.indexOf("prompt") > 0) return true;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
window.document.addEventListener('dragover', e => {
|
window.document.addEventListener('dragover', e => {
|
||||||
const target = e.composedPath()[0];
|
const target = e.composedPath()[0];
|
||||||
const imgWrap = target.closest('[data-testid="image"]');
|
if (!eventHasFiles(e)) return;
|
||||||
if (!imgWrap && target.placeholder && target.placeholder.indexOf("Prompt") == -1) {
|
|
||||||
return;
|
var targetImage = target.closest('[data-testid="image"]');
|
||||||
}
|
if (!dragDropTargetIsPrompt(target) && !targetImage) return;
|
||||||
|
|
||||||
e.stopPropagation();
|
e.stopPropagation();
|
||||||
e.preventDefault();
|
e.preventDefault();
|
||||||
e.dataTransfer.dropEffect = 'copy';
|
e.dataTransfer.dropEffect = 'copy';
|
||||||
@ -61,17 +76,31 @@ window.document.addEventListener('dragover', e => {
|
|||||||
|
|
||||||
window.document.addEventListener('drop', e => {
|
window.document.addEventListener('drop', e => {
|
||||||
const target = e.composedPath()[0];
|
const target = e.composedPath()[0];
|
||||||
if (target.placeholder.indexOf("Prompt") == -1) {
|
if (!eventHasFiles(e)) return;
|
||||||
|
|
||||||
|
if (dragDropTargetIsPrompt(target)) {
|
||||||
|
e.stopPropagation();
|
||||||
|
e.preventDefault();
|
||||||
|
|
||||||
|
let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image";
|
||||||
|
|
||||||
|
const imgParent = gradioApp().getElementById(prompt_target);
|
||||||
|
const files = e.dataTransfer.files;
|
||||||
|
const fileInput = imgParent.querySelector('input[type="file"]');
|
||||||
|
if (fileInput) {
|
||||||
|
fileInput.files = files;
|
||||||
|
fileInput.dispatchEvent(new Event('change'));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var targetImage = target.closest('[data-testid="image"]');
|
||||||
|
if (targetImage) {
|
||||||
|
e.stopPropagation();
|
||||||
|
e.preventDefault();
|
||||||
|
const files = e.dataTransfer.files;
|
||||||
|
dropReplaceImage(targetImage, files);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const imgWrap = target.closest('[data-testid="image"]');
|
|
||||||
if (!imgWrap) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
e.stopPropagation();
|
|
||||||
e.preventDefault();
|
|
||||||
const files = e.dataTransfer.files;
|
|
||||||
dropReplaceImage(imgWrap, files);
|
|
||||||
});
|
});
|
||||||
|
|
||||||
window.addEventListener('paste', e => {
|
window.addEventListener('paste', e => {
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
function keyupEditAttention(event) {
|
function keyupEditAttention(event) {
|
||||||
let target = event.originalTarget || event.composedPath()[0];
|
let target = event.originalTarget || event.composedPath()[0];
|
||||||
if (!target.matches("[id*='_toprow'] [id*='_prompt'] textarea")) return;
|
if (!target.matches("*:is([id*='_toprow'] [id*='_prompt'], .prompt) textarea")) return;
|
||||||
if (!(event.metaKey || event.ctrlKey)) return;
|
if (!(event.metaKey || event.ctrlKey)) return;
|
||||||
|
|
||||||
let isPlus = event.key == "ArrowUp";
|
let isPlus = event.key == "ArrowUp";
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
// attaches listeners to the txt2img and img2img galleries to update displayed generation param text when the image changes
|
// attaches listeners to the txt2img and img2img galleries to update displayed generation param text when the image changes
|
||||||
|
|
||||||
let txt2img_gallery, img2img_gallery, modal = undefined;
|
let txt2img_gallery, img2img_gallery, modal = undefined;
|
||||||
onUiUpdate(function() {
|
onAfterUiUpdate(function() {
|
||||||
if (!txt2img_gallery) {
|
if (!txt2img_gallery) {
|
||||||
txt2img_gallery = attachGalleryListeners("txt2img");
|
txt2img_gallery = attachGalleryListeners("txt2img");
|
||||||
}
|
}
|
||||||
|
@ -116,17 +116,25 @@ var titles = {
|
|||||||
"Negative Guidance minimum sigma": "Skip negative prompt for steps where image is already mostly denoised; the higher this value, the more skips there will be; provides increased performance in exchange for minor quality reduction."
|
"Negative Guidance minimum sigma": "Skip negative prompt for steps where image is already mostly denoised; the higher this value, the more skips there will be; provides increased performance in exchange for minor quality reduction."
|
||||||
};
|
};
|
||||||
|
|
||||||
function updateTooltipForSpan(span) {
|
function updateTooltip(element) {
|
||||||
if (span.title) return; // already has a title
|
if (element.title) return; // already has a title
|
||||||
|
|
||||||
let tooltip = localization[titles[span.textContent]] || titles[span.textContent];
|
let text = element.textContent;
|
||||||
|
let tooltip = localization[titles[text]] || titles[text];
|
||||||
|
|
||||||
if (!tooltip) {
|
if (!tooltip) {
|
||||||
tooltip = localization[titles[span.value]] || titles[span.value];
|
let value = element.value;
|
||||||
|
if (value) tooltip = localization[titles[value]] || titles[value];
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!tooltip) {
|
if (!tooltip) {
|
||||||
for (const c of span.classList) {
|
// Gradio dropdown options have `data-value`.
|
||||||
|
let dataValue = element.dataset.value;
|
||||||
|
if (dataValue) tooltip = localization[titles[dataValue]] || titles[dataValue];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!tooltip) {
|
||||||
|
for (const c of element.classList) {
|
||||||
if (c in titles) {
|
if (c in titles) {
|
||||||
tooltip = localization[titles[c]] || titles[c];
|
tooltip = localization[titles[c]] || titles[c];
|
||||||
break;
|
break;
|
||||||
@ -135,34 +143,53 @@ function updateTooltipForSpan(span) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (tooltip) {
|
if (tooltip) {
|
||||||
span.title = tooltip;
|
element.title = tooltip;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function updateTooltipForSelect(select) {
|
// Nodes to check for adding tooltips.
|
||||||
if (select.onchange != null) return;
|
const tooltipCheckNodes = new Set();
|
||||||
|
// Timer for debouncing tooltip check.
|
||||||
|
let tooltipCheckTimer = null;
|
||||||
|
|
||||||
select.onchange = function() {
|
function processTooltipCheckNodes() {
|
||||||
select.title = localization[titles[select.value]] || titles[select.value] || "";
|
for (const node of tooltipCheckNodes) {
|
||||||
};
|
updateTooltip(node);
|
||||||
|
}
|
||||||
|
tooltipCheckNodes.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
var observedTooltipElements = {SPAN: 1, BUTTON: 1, SELECT: 1, P: 1};
|
onUiUpdate(function(mutationRecords) {
|
||||||
|
for (const record of mutationRecords) {
|
||||||
onUiUpdate(function(m) {
|
if (record.type === "childList" && record.target.classList.contains("options")) {
|
||||||
m.forEach(function(record) {
|
// This smells like a Gradio dropdown menu having changed,
|
||||||
record.addedNodes.forEach(function(node) {
|
// so let's enqueue an update for the input element that shows the current value.
|
||||||
if (observedTooltipElements[node.tagName]) {
|
let wrap = record.target.parentNode;
|
||||||
updateTooltipForSpan(node);
|
let input = wrap?.querySelector("input");
|
||||||
|
if (input) {
|
||||||
|
input.title = ""; // So we'll even have a chance to update it.
|
||||||
|
tooltipCheckNodes.add(input);
|
||||||
}
|
}
|
||||||
if (node.tagName == "SELECT") {
|
}
|
||||||
updateTooltipForSelect(node);
|
for (const node of record.addedNodes) {
|
||||||
|
if (node.nodeType === Node.ELEMENT_NODE && !node.classList.contains("hide")) {
|
||||||
|
if (!node.title) {
|
||||||
|
if (
|
||||||
|
node.tagName === "SPAN" ||
|
||||||
|
node.tagName === "BUTTON" ||
|
||||||
|
node.tagName === "P" ||
|
||||||
|
node.tagName === "INPUT" ||
|
||||||
|
(node.tagName === "LI" && node.classList.contains("item")) // Gradio dropdown item
|
||||||
|
) {
|
||||||
|
tooltipCheckNodes.add(node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node.querySelectorAll('span, button, p').forEach(n => tooltipCheckNodes.add(n));
|
||||||
}
|
}
|
||||||
|
}
|
||||||
if (node.querySelectorAll) {
|
}
|
||||||
node.querySelectorAll('span, button, select, p').forEach(updateTooltipForSpan);
|
if (tooltipCheckNodes.size) {
|
||||||
node.querySelectorAll('select').forEach(updateTooltipForSelect);
|
clearTimeout(tooltipCheckTimer);
|
||||||
}
|
tooltipCheckTimer = setTimeout(processTooltipCheckNodes, 1000);
|
||||||
});
|
}
|
||||||
});
|
|
||||||
});
|
});
|
||||||
|
@ -39,5 +39,5 @@ function imageMaskResize() {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
onUiUpdate(imageMaskResize);
|
onAfterUiUpdate(imageMaskResize);
|
||||||
window.addEventListener('resize', imageMaskResize);
|
window.addEventListener('resize', imageMaskResize);
|
||||||
|
@ -1,18 +0,0 @@
|
|||||||
window.onload = (function() {
|
|
||||||
window.addEventListener('drop', e => {
|
|
||||||
const target = e.composedPath()[0];
|
|
||||||
if (target.placeholder.indexOf("Prompt") == -1) return;
|
|
||||||
|
|
||||||
let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image";
|
|
||||||
|
|
||||||
e.stopPropagation();
|
|
||||||
e.preventDefault();
|
|
||||||
const imgParent = gradioApp().getElementById(prompt_target);
|
|
||||||
const files = e.dataTransfer.files;
|
|
||||||
const fileInput = imgParent.querySelector('input[type="file"]');
|
|
||||||
if (fileInput) {
|
|
||||||
fileInput.files = files;
|
|
||||||
fileInput.dispatchEvent(new Event('change'));
|
|
||||||
}
|
|
||||||
});
|
|
||||||
});
|
|
@ -170,7 +170,7 @@ function modalTileImageToggle(event) {
|
|||||||
event.stopPropagation();
|
event.stopPropagation();
|
||||||
}
|
}
|
||||||
|
|
||||||
onUiUpdate(function() {
|
onAfterUiUpdate(function() {
|
||||||
var fullImg_preview = gradioApp().querySelectorAll('.gradio-gallery > div > img');
|
var fullImg_preview = gradioApp().querySelectorAll('.gradio-gallery > div > img');
|
||||||
if (fullImg_preview != null) {
|
if (fullImg_preview != null) {
|
||||||
fullImg_preview.forEach(setupImageForLightbox);
|
fullImg_preview.forEach(setupImageForLightbox);
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
|
let gamepads = [];
|
||||||
|
|
||||||
window.addEventListener('gamepadconnected', (e) => {
|
window.addEventListener('gamepadconnected', (e) => {
|
||||||
const index = e.gamepad.index;
|
const index = e.gamepad.index;
|
||||||
let isWaiting = false;
|
let isWaiting = false;
|
||||||
setInterval(async() => {
|
gamepads[index] = setInterval(async() => {
|
||||||
if (!opts.js_modal_lightbox_gamepad || isWaiting) return;
|
if (!opts.js_modal_lightbox_gamepad || isWaiting) return;
|
||||||
const gamepad = navigator.getGamepads()[index];
|
const gamepad = navigator.getGamepads()[index];
|
||||||
const xValue = gamepad.axes[0];
|
const xValue = gamepad.axes[0];
|
||||||
@ -24,6 +26,10 @@ window.addEventListener('gamepadconnected', (e) => {
|
|||||||
}, 10);
|
}, 10);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
window.addEventListener('gamepaddisconnected', (e) => {
|
||||||
|
clearInterval(gamepads[e.gamepad.index]);
|
||||||
|
});
|
||||||
|
|
||||||
/*
|
/*
|
||||||
Primarily for vr controller type pointer devices.
|
Primarily for vr controller type pointer devices.
|
||||||
I use the wheel event because there's currently no way to do it properly with web xr.
|
I use the wheel event because there's currently no way to do it properly with web xr.
|
||||||
|
@ -4,7 +4,7 @@ let lastHeadImg = null;
|
|||||||
|
|
||||||
let notificationButton = null;
|
let notificationButton = null;
|
||||||
|
|
||||||
onUiUpdate(function() {
|
onAfterUiUpdate(function() {
|
||||||
if (notificationButton == null) {
|
if (notificationButton == null) {
|
||||||
notificationButton = gradioApp().getElementById('request_notifications');
|
notificationButton = gradioApp().getElementById('request_notifications');
|
||||||
|
|
||||||
|
83
javascript/token-counters.js
Normal file
83
javascript/token-counters.js
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
let promptTokenCountDebounceTime = 800;
|
||||||
|
let promptTokenCountTimeouts = {};
|
||||||
|
var promptTokenCountUpdateFunctions = {};
|
||||||
|
|
||||||
|
function update_txt2img_tokens(...args) {
|
||||||
|
// Called from Gradio
|
||||||
|
update_token_counter("txt2img_token_button");
|
||||||
|
if (args.length == 2) {
|
||||||
|
return args[0];
|
||||||
|
}
|
||||||
|
return args;
|
||||||
|
}
|
||||||
|
|
||||||
|
function update_img2img_tokens(...args) {
|
||||||
|
// Called from Gradio
|
||||||
|
update_token_counter("img2img_token_button");
|
||||||
|
if (args.length == 2) {
|
||||||
|
return args[0];
|
||||||
|
}
|
||||||
|
return args;
|
||||||
|
}
|
||||||
|
|
||||||
|
function update_token_counter(button_id) {
|
||||||
|
if (opts.disable_token_counters) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (promptTokenCountTimeouts[button_id]) {
|
||||||
|
clearTimeout(promptTokenCountTimeouts[button_id]);
|
||||||
|
}
|
||||||
|
promptTokenCountTimeouts[button_id] = setTimeout(
|
||||||
|
() => gradioApp().getElementById(button_id)?.click(),
|
||||||
|
promptTokenCountDebounceTime,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
function recalculatePromptTokens(name) {
|
||||||
|
promptTokenCountUpdateFunctions[name]?.();
|
||||||
|
}
|
||||||
|
|
||||||
|
function recalculate_prompts_txt2img() {
|
||||||
|
// Called from Gradio
|
||||||
|
recalculatePromptTokens('txt2img_prompt');
|
||||||
|
recalculatePromptTokens('txt2img_neg_prompt');
|
||||||
|
return Array.from(arguments);
|
||||||
|
}
|
||||||
|
|
||||||
|
function recalculate_prompts_img2img() {
|
||||||
|
// Called from Gradio
|
||||||
|
recalculatePromptTokens('img2img_prompt');
|
||||||
|
recalculatePromptTokens('img2img_neg_prompt');
|
||||||
|
return Array.from(arguments);
|
||||||
|
}
|
||||||
|
|
||||||
|
function setupTokenCounting(id, id_counter, id_button) {
|
||||||
|
var prompt = gradioApp().getElementById(id);
|
||||||
|
var counter = gradioApp().getElementById(id_counter);
|
||||||
|
var textarea = gradioApp().querySelector(`#${id} > label > textarea`);
|
||||||
|
|
||||||
|
if (opts.disable_token_counters) {
|
||||||
|
counter.style.display = "none";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (counter.parentElement == prompt.parentElement) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt.parentElement.insertBefore(counter, prompt);
|
||||||
|
prompt.parentElement.style.position = "relative";
|
||||||
|
|
||||||
|
promptTokenCountUpdateFunctions[id] = function() {
|
||||||
|
update_token_counter(id_button);
|
||||||
|
};
|
||||||
|
textarea.addEventListener("input", promptTokenCountUpdateFunctions[id]);
|
||||||
|
}
|
||||||
|
|
||||||
|
function setupTokenCounters() {
|
||||||
|
setupTokenCounting('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button');
|
||||||
|
setupTokenCounting('txt2img_neg_prompt', 'txt2img_negative_token_counter', 'txt2img_negative_token_button');
|
||||||
|
setupTokenCounting('img2img_prompt', 'img2img_token_counter', 'img2img_token_button');
|
||||||
|
setupTokenCounting('img2img_neg_prompt', 'img2img_negative_token_counter', 'img2img_negative_token_button');
|
||||||
|
}
|
@ -248,29 +248,8 @@ function confirm_clear_prompt(prompt, negative_prompt) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
var promptTokecountUpdateFuncs = {};
|
|
||||||
|
|
||||||
function recalculatePromptTokens(name) {
|
|
||||||
if (promptTokecountUpdateFuncs[name]) {
|
|
||||||
promptTokecountUpdateFuncs[name]();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
function recalculate_prompts_txt2img() {
|
|
||||||
recalculatePromptTokens('txt2img_prompt');
|
|
||||||
recalculatePromptTokens('txt2img_neg_prompt');
|
|
||||||
return Array.from(arguments);
|
|
||||||
}
|
|
||||||
|
|
||||||
function recalculate_prompts_img2img() {
|
|
||||||
recalculatePromptTokens('img2img_prompt');
|
|
||||||
recalculatePromptTokens('img2img_neg_prompt');
|
|
||||||
return Array.from(arguments);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
var opts = {};
|
var opts = {};
|
||||||
onUiUpdate(function() {
|
onAfterUiUpdate(function() {
|
||||||
if (Object.keys(opts).length != 0) return;
|
if (Object.keys(opts).length != 0) return;
|
||||||
|
|
||||||
var json_elem = gradioApp().getElementById('settings_json');
|
var json_elem = gradioApp().getElementById('settings_json');
|
||||||
@ -302,28 +281,7 @@ onUiUpdate(function() {
|
|||||||
|
|
||||||
json_elem.parentElement.style.display = "none";
|
json_elem.parentElement.style.display = "none";
|
||||||
|
|
||||||
function registerTextarea(id, id_counter, id_button) {
|
setupTokenCounters();
|
||||||
var prompt = gradioApp().getElementById(id);
|
|
||||||
var counter = gradioApp().getElementById(id_counter);
|
|
||||||
var textarea = gradioApp().querySelector("#" + id + " > label > textarea");
|
|
||||||
|
|
||||||
if (counter.parentElement == prompt.parentElement) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
prompt.parentElement.insertBefore(counter, prompt);
|
|
||||||
prompt.parentElement.style.position = "relative";
|
|
||||||
|
|
||||||
promptTokecountUpdateFuncs[id] = function() {
|
|
||||||
update_token_counter(id_button);
|
|
||||||
};
|
|
||||||
textarea.addEventListener("input", promptTokecountUpdateFuncs[id]);
|
|
||||||
}
|
|
||||||
|
|
||||||
registerTextarea('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button');
|
|
||||||
registerTextarea('txt2img_neg_prompt', 'txt2img_negative_token_counter', 'txt2img_negative_token_button');
|
|
||||||
registerTextarea('img2img_prompt', 'img2img_token_counter', 'img2img_token_button');
|
|
||||||
registerTextarea('img2img_neg_prompt', 'img2img_negative_token_counter', 'img2img_negative_token_button');
|
|
||||||
|
|
||||||
var show_all_pages = gradioApp().getElementById('settings_show_all_pages');
|
var show_all_pages = gradioApp().getElementById('settings_show_all_pages');
|
||||||
var settings_tabs = gradioApp().querySelector('#settings div');
|
var settings_tabs = gradioApp().querySelector('#settings div');
|
||||||
@ -354,33 +312,6 @@ onOptionsChanged(function() {
|
|||||||
});
|
});
|
||||||
|
|
||||||
let txt2img_textarea, img2img_textarea = undefined;
|
let txt2img_textarea, img2img_textarea = undefined;
|
||||||
let wait_time = 800;
|
|
||||||
let token_timeouts = {};
|
|
||||||
|
|
||||||
function update_txt2img_tokens(...args) {
|
|
||||||
update_token_counter("txt2img_token_button");
|
|
||||||
if (args.length == 2) {
|
|
||||||
return args[0];
|
|
||||||
}
|
|
||||||
return args;
|
|
||||||
}
|
|
||||||
|
|
||||||
function update_img2img_tokens(...args) {
|
|
||||||
update_token_counter(
|
|
||||||
"img2img_token_button"
|
|
||||||
);
|
|
||||||
if (args.length == 2) {
|
|
||||||
return args[0];
|
|
||||||
}
|
|
||||||
return args;
|
|
||||||
}
|
|
||||||
|
|
||||||
function update_token_counter(button_id) {
|
|
||||||
if (token_timeouts[button_id]) {
|
|
||||||
clearTimeout(token_timeouts[button_id]);
|
|
||||||
}
|
|
||||||
token_timeouts[button_id] = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
|
|
||||||
}
|
|
||||||
|
|
||||||
function restart_reload() {
|
function restart_reload() {
|
||||||
document.body.innerHTML = '<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>';
|
document.body.innerHTML = '<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>';
|
||||||
|
358
launch.py
358
launch.py
@ -1,337 +1,27 @@
|
|||||||
# this scripts installs necessary requirements and launches main program in webui.py
|
from modules import launch_utils
|
||||||
import subprocess
|
|
||||||
import os
|
|
||||||
import sys
|
args = launch_utils.args
|
||||||
import importlib.util
|
python = launch_utils.python
|
||||||
import platform
|
git = launch_utils.git
|
||||||
import json
|
index_url = launch_utils.index_url
|
||||||
from functools import lru_cache
|
dir_repos = launch_utils.dir_repos
|
||||||
|
|
||||||
from modules import cmd_args
|
commit_hash = launch_utils.commit_hash
|
||||||
from modules.paths_internal import script_path, extensions_dir
|
git_tag = launch_utils.git_tag
|
||||||
|
|
||||||
args, _ = cmd_args.parser.parse_known_args()
|
run = launch_utils.run
|
||||||
|
is_installed = launch_utils.is_installed
|
||||||
python = sys.executable
|
repo_dir = launch_utils.repo_dir
|
||||||
git = os.environ.get('GIT', "git")
|
|
||||||
index_url = os.environ.get('INDEX_URL', "")
|
run_pip = launch_utils.run_pip
|
||||||
dir_repos = "repositories"
|
check_run_python = launch_utils.check_run_python
|
||||||
|
git_clone = launch_utils.git_clone
|
||||||
# Whether to default to printing command output
|
git_pull_recursive = launch_utils.git_pull_recursive
|
||||||
default_command_live = (os.environ.get('WEBUI_LAUNCH_LIVE_OUTPUT') == "1")
|
run_extension_installer = launch_utils.run_extension_installer
|
||||||
|
prepare_environment = launch_utils.prepare_environment
|
||||||
if 'GRADIO_ANALYTICS_ENABLED' not in os.environ:
|
configure_for_tests = launch_utils.configure_for_tests
|
||||||
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
|
start = launch_utils.start
|
||||||
|
|
||||||
|
|
||||||
def check_python_version():
|
|
||||||
is_windows = platform.system() == "Windows"
|
|
||||||
major = sys.version_info.major
|
|
||||||
minor = sys.version_info.minor
|
|
||||||
micro = sys.version_info.micro
|
|
||||||
|
|
||||||
if is_windows:
|
|
||||||
supported_minors = [10]
|
|
||||||
else:
|
|
||||||
supported_minors = [7, 8, 9, 10, 11]
|
|
||||||
|
|
||||||
if not (major == 3 and minor in supported_minors):
|
|
||||||
import modules.errors
|
|
||||||
|
|
||||||
modules.errors.print_error_explanation(f"""
|
|
||||||
INCOMPATIBLE PYTHON VERSION
|
|
||||||
|
|
||||||
This program is tested with 3.10.6 Python, but you have {major}.{minor}.{micro}.
|
|
||||||
If you encounter an error with "RuntimeError: Couldn't install torch." message,
|
|
||||||
or any other error regarding unsuccessful package (library) installation,
|
|
||||||
please downgrade (or upgrade) to the latest version of 3.10 Python
|
|
||||||
and delete current Python and "venv" folder in WebUI's directory.
|
|
||||||
|
|
||||||
You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3106/
|
|
||||||
|
|
||||||
{"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases" if is_windows else ""}
|
|
||||||
|
|
||||||
Use --skip-python-version-check to suppress this warning.
|
|
||||||
""")
|
|
||||||
|
|
||||||
|
|
||||||
@lru_cache()
|
|
||||||
def commit_hash():
|
|
||||||
try:
|
|
||||||
return subprocess.check_output([git, "rev-parse", "HEAD"], shell=False, encoding='utf8').strip()
|
|
||||||
except Exception:
|
|
||||||
return "<none>"
|
|
||||||
|
|
||||||
|
|
||||||
@lru_cache()
|
|
||||||
def git_tag():
|
|
||||||
try:
|
|
||||||
return subprocess.check_output([git, "describe", "--tags"], shell=False, encoding='utf8').strip()
|
|
||||||
except Exception:
|
|
||||||
return "<none>"
|
|
||||||
|
|
||||||
|
|
||||||
def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live) -> str:
|
|
||||||
if desc is not None:
|
|
||||||
print(desc)
|
|
||||||
|
|
||||||
run_kwargs = {
|
|
||||||
"args": command,
|
|
||||||
"shell": True,
|
|
||||||
"env": os.environ if custom_env is None else custom_env,
|
|
||||||
"encoding": 'utf8',
|
|
||||||
"errors": 'ignore',
|
|
||||||
}
|
|
||||||
|
|
||||||
if not live:
|
|
||||||
run_kwargs["stdout"] = run_kwargs["stderr"] = subprocess.PIPE
|
|
||||||
|
|
||||||
result = subprocess.run(**run_kwargs)
|
|
||||||
|
|
||||||
if result.returncode != 0:
|
|
||||||
error_bits = [
|
|
||||||
f"{errdesc or 'Error running command'}.",
|
|
||||||
f"Command: {command}",
|
|
||||||
f"Error code: {result.returncode}",
|
|
||||||
]
|
|
||||||
if result.stdout:
|
|
||||||
error_bits.append(f"stdout: {result.stdout}")
|
|
||||||
if result.stderr:
|
|
||||||
error_bits.append(f"stderr: {result.stderr}")
|
|
||||||
raise RuntimeError("\n".join(error_bits))
|
|
||||||
|
|
||||||
return (result.stdout or "")
|
|
||||||
|
|
||||||
|
|
||||||
def is_installed(package):
|
|
||||||
try:
|
|
||||||
spec = importlib.util.find_spec(package)
|
|
||||||
except ModuleNotFoundError:
|
|
||||||
return False
|
|
||||||
|
|
||||||
return spec is not None
|
|
||||||
|
|
||||||
|
|
||||||
def repo_dir(name):
|
|
||||||
return os.path.join(script_path, dir_repos, name)
|
|
||||||
|
|
||||||
|
|
||||||
def run_pip(command, desc=None, live=default_command_live):
|
|
||||||
if args.skip_install:
|
|
||||||
return
|
|
||||||
|
|
||||||
index_url_line = f' --index-url {index_url}' if index_url != '' else ''
|
|
||||||
return run(f'"{python}" -m pip {command} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}", live=live)
|
|
||||||
|
|
||||||
|
|
||||||
def check_run_python(code: str) -> bool:
|
|
||||||
result = subprocess.run([python, "-c", code], capture_output=True, shell=False)
|
|
||||||
return result.returncode == 0
|
|
||||||
|
|
||||||
|
|
||||||
def git_clone(url, dir, name, commithash=None):
|
|
||||||
# TODO clone into temporary dir and move if successful
|
|
||||||
|
|
||||||
if os.path.exists(dir):
|
|
||||||
if commithash is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip()
|
|
||||||
if current_hash == commithash:
|
|
||||||
return
|
|
||||||
|
|
||||||
run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
|
|
||||||
run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
|
|
||||||
return
|
|
||||||
|
|
||||||
run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}")
|
|
||||||
|
|
||||||
if commithash is not None:
|
|
||||||
run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
|
|
||||||
|
|
||||||
|
|
||||||
def git_pull_recursive(dir):
|
|
||||||
for subdir, _, _ in os.walk(dir):
|
|
||||||
if os.path.exists(os.path.join(subdir, '.git')):
|
|
||||||
try:
|
|
||||||
output = subprocess.check_output([git, '-C', subdir, 'pull', '--autostash'])
|
|
||||||
print(f"Pulled changes for repository in '{subdir}':\n{output.decode('utf-8').strip()}\n")
|
|
||||||
except subprocess.CalledProcessError as e:
|
|
||||||
print(f"Couldn't perform 'git pull' on repository in '{subdir}':\n{e.output.decode('utf-8').strip()}\n")
|
|
||||||
|
|
||||||
|
|
||||||
def version_check(commit):
|
|
||||||
try:
|
|
||||||
import requests
|
|
||||||
commits = requests.get('https://api.github.com/repos/AUTOMATIC1111/stable-diffusion-webui/branches/master').json()
|
|
||||||
if commit != "<none>" and commits['commit']['sha'] != commit:
|
|
||||||
print("--------------------------------------------------------")
|
|
||||||
print("| You are not up to date with the most recent release. |")
|
|
||||||
print("| Consider running `git pull` to update. |")
|
|
||||||
print("--------------------------------------------------------")
|
|
||||||
elif commits['commit']['sha'] == commit:
|
|
||||||
print("You are up to date with the most recent release.")
|
|
||||||
else:
|
|
||||||
print("Not a git clone, can't perform version check.")
|
|
||||||
except Exception as e:
|
|
||||||
print("version check failed", e)
|
|
||||||
|
|
||||||
|
|
||||||
def run_extension_installer(extension_dir):
|
|
||||||
path_installer = os.path.join(extension_dir, "install.py")
|
|
||||||
if not os.path.isfile(path_installer):
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
env = os.environ.copy()
|
|
||||||
env['PYTHONPATH'] = os.path.abspath(".")
|
|
||||||
|
|
||||||
print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env))
|
|
||||||
except Exception as e:
|
|
||||||
print(e, file=sys.stderr)
|
|
||||||
|
|
||||||
|
|
||||||
def list_extensions(settings_file):
|
|
||||||
settings = {}
|
|
||||||
|
|
||||||
try:
|
|
||||||
if os.path.isfile(settings_file):
|
|
||||||
with open(settings_file, "r", encoding="utf8") as file:
|
|
||||||
settings = json.load(file)
|
|
||||||
except Exception as e:
|
|
||||||
print(e, file=sys.stderr)
|
|
||||||
|
|
||||||
disabled_extensions = set(settings.get('disabled_extensions', []))
|
|
||||||
disable_all_extensions = settings.get('disable_all_extensions', 'none')
|
|
||||||
|
|
||||||
if disable_all_extensions != 'none':
|
|
||||||
return []
|
|
||||||
|
|
||||||
return [x for x in os.listdir(extensions_dir) if x not in disabled_extensions]
|
|
||||||
|
|
||||||
|
|
||||||
def run_extensions_installers(settings_file):
|
|
||||||
if not os.path.isdir(extensions_dir):
|
|
||||||
return
|
|
||||||
|
|
||||||
for dirname_extension in list_extensions(settings_file):
|
|
||||||
run_extension_installer(os.path.join(extensions_dir, dirname_extension))
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_environment():
|
|
||||||
torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118")
|
|
||||||
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}")
|
|
||||||
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
|
|
||||||
|
|
||||||
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.17')
|
|
||||||
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "https://github.com/TencentARC/GFPGAN/archive/8d2447a2d918f8eba5a4a01463fd48e45126a379.zip")
|
|
||||||
clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
|
|
||||||
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
|
|
||||||
|
|
||||||
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
|
|
||||||
taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git")
|
|
||||||
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
|
|
||||||
codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
|
|
||||||
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
|
|
||||||
|
|
||||||
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
|
|
||||||
taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
|
|
||||||
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "5b3af030dd83e0297272d861c19477735d0317ec")
|
|
||||||
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
|
|
||||||
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
|
|
||||||
|
|
||||||
if not args.skip_python_version_check:
|
|
||||||
check_python_version()
|
|
||||||
|
|
||||||
commit = commit_hash()
|
|
||||||
tag = git_tag()
|
|
||||||
|
|
||||||
print(f"Python {sys.version}")
|
|
||||||
print(f"Version: {tag}")
|
|
||||||
print(f"Commit hash: {commit}")
|
|
||||||
|
|
||||||
if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
|
|
||||||
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
|
|
||||||
|
|
||||||
if not args.skip_torch_cuda_test and not check_run_python("import torch; assert torch.cuda.is_available()"):
|
|
||||||
raise RuntimeError(
|
|
||||||
'Torch is not able to use GPU; '
|
|
||||||
'add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'
|
|
||||||
)
|
|
||||||
|
|
||||||
if not is_installed("gfpgan"):
|
|
||||||
run_pip(f"install {gfpgan_package}", "gfpgan")
|
|
||||||
|
|
||||||
if not is_installed("clip"):
|
|
||||||
run_pip(f"install {clip_package}", "clip")
|
|
||||||
|
|
||||||
if not is_installed("open_clip"):
|
|
||||||
run_pip(f"install {openclip_package}", "open_clip")
|
|
||||||
|
|
||||||
if (not is_installed("xformers") or args.reinstall_xformers) and args.xformers:
|
|
||||||
if platform.system() == "Windows":
|
|
||||||
if platform.python_version().startswith("3.10"):
|
|
||||||
run_pip(f"install -U -I --no-deps {xformers_package}", "xformers", live=True)
|
|
||||||
else:
|
|
||||||
print("Installation of xformers is not supported in this version of Python.")
|
|
||||||
print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness")
|
|
||||||
if not is_installed("xformers"):
|
|
||||||
exit(0)
|
|
||||||
elif platform.system() == "Linux":
|
|
||||||
run_pip(f"install -U -I --no-deps {xformers_package}", "xformers")
|
|
||||||
|
|
||||||
if not is_installed("ngrok") and args.ngrok:
|
|
||||||
run_pip("install ngrok", "ngrok")
|
|
||||||
|
|
||||||
os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
|
|
||||||
|
|
||||||
git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
|
|
||||||
git_clone(taming_transformers_repo, repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash)
|
|
||||||
git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
|
|
||||||
git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
|
|
||||||
git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
|
|
||||||
|
|
||||||
if not is_installed("lpips"):
|
|
||||||
run_pip(f"install -r \"{os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}\"", "requirements for CodeFormer")
|
|
||||||
|
|
||||||
if not os.path.isfile(requirements_file):
|
|
||||||
requirements_file = os.path.join(script_path, requirements_file)
|
|
||||||
run_pip(f"install -r \"{requirements_file}\"", "requirements")
|
|
||||||
|
|
||||||
run_extensions_installers(settings_file=args.ui_settings_file)
|
|
||||||
|
|
||||||
if args.update_check:
|
|
||||||
version_check(commit)
|
|
||||||
|
|
||||||
if args.update_all_extensions:
|
|
||||||
git_pull_recursive(extensions_dir)
|
|
||||||
|
|
||||||
if "--exit" in sys.argv:
|
|
||||||
print("Exiting because of --exit argument")
|
|
||||||
exit(0)
|
|
||||||
|
|
||||||
|
|
||||||
def configure_for_tests():
|
|
||||||
if "--api" not in sys.argv:
|
|
||||||
sys.argv.append("--api")
|
|
||||||
if "--ckpt" not in sys.argv:
|
|
||||||
sys.argv.append("--ckpt")
|
|
||||||
sys.argv.append(os.path.join(script_path, "test/test_files/empty.pt"))
|
|
||||||
if "--skip-torch-cuda-test" not in sys.argv:
|
|
||||||
sys.argv.append("--skip-torch-cuda-test")
|
|
||||||
if "--disable-nan-check" not in sys.argv:
|
|
||||||
sys.argv.append("--disable-nan-check")
|
|
||||||
|
|
||||||
os.environ['COMMANDLINE_ARGS'] = ""
|
|
||||||
|
|
||||||
|
|
||||||
def start():
|
|
||||||
print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {' '.join(sys.argv[1:])}")
|
|
||||||
import webui
|
|
||||||
if '--nowebui' in sys.argv:
|
|
||||||
webui.api_only()
|
|
||||||
else:
|
|
||||||
webui.webui()
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
@ -14,7 +14,7 @@ from fastapi.encoders import jsonable_encoder
|
|||||||
from secrets import compare_digest
|
from secrets import compare_digest
|
||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing
|
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors
|
||||||
from modules.api import models
|
from modules.api import models
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
||||||
@ -23,6 +23,7 @@ from modules.textual_inversion.preprocess import preprocess
|
|||||||
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
||||||
from PIL import PngImagePlugin,Image
|
from PIL import PngImagePlugin,Image
|
||||||
from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights
|
from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights
|
||||||
|
from modules.sd_vae import vae_dict
|
||||||
from modules.sd_models_config import find_checkpoint_config_near_filename
|
from modules.sd_models_config import find_checkpoint_config_near_filename
|
||||||
from modules.realesrgan_model import get_realesrgan_models
|
from modules.realesrgan_model import get_realesrgan_models
|
||||||
from modules import devices
|
from modules import devices
|
||||||
@ -108,7 +109,6 @@ def api_middleware(app: FastAPI):
|
|||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
console = Console()
|
console = Console()
|
||||||
except Exception:
|
except Exception:
|
||||||
import traceback
|
|
||||||
rich_available = False
|
rich_available = False
|
||||||
|
|
||||||
@app.middleware("http")
|
@app.middleware("http")
|
||||||
@ -139,11 +139,12 @@ def api_middleware(app: FastAPI):
|
|||||||
"errors": str(e),
|
"errors": str(e),
|
||||||
}
|
}
|
||||||
if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions
|
if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions
|
||||||
print(f"API error: {request.method}: {request.url} {err}")
|
message = f"API error: {request.method}: {request.url} {err}"
|
||||||
if rich_available:
|
if rich_available:
|
||||||
|
print(message)
|
||||||
console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200]))
|
console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200]))
|
||||||
else:
|
else:
|
||||||
traceback.print_exc()
|
errors.report(message, exc_info=True)
|
||||||
return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err))
|
return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err))
|
||||||
|
|
||||||
@app.middleware("http")
|
@app.middleware("http")
|
||||||
@ -189,6 +190,7 @@ class Api:
|
|||||||
self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
|
self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
|
||||||
self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
|
self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
|
||||||
self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
|
self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
|
||||||
|
self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem])
|
||||||
self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
|
self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
|
||||||
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem])
|
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem])
|
||||||
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem])
|
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem])
|
||||||
@ -541,6 +543,9 @@ class Api:
|
|||||||
def get_sd_models(self):
|
def get_sd_models(self):
|
||||||
return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()]
|
return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()]
|
||||||
|
|
||||||
|
def get_sd_vaes(self):
|
||||||
|
return [{"model_name": x, "filename": vae_dict[x]} for x in vae_dict.keys()]
|
||||||
|
|
||||||
def get_hypernetworks(self):
|
def get_hypernetworks(self):
|
||||||
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
|
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
|
||||||
|
|
||||||
@ -700,4 +705,4 @@ class Api:
|
|||||||
|
|
||||||
def launch(self, server_name, port):
|
def launch(self, server_name, port):
|
||||||
self.app.include_router(self.router)
|
self.app.include_router(self.router)
|
||||||
uvicorn.run(self.app, host=server_name, port=port)
|
uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=0)
|
||||||
|
@ -249,6 +249,10 @@ class SDModelItem(BaseModel):
|
|||||||
filename: str = Field(title="Filename")
|
filename: str = Field(title="Filename")
|
||||||
config: Optional[str] = Field(title="Config file")
|
config: Optional[str] = Field(title="Config file")
|
||||||
|
|
||||||
|
class SDVaeItem(BaseModel):
|
||||||
|
model_name: str = Field(title="Model Name")
|
||||||
|
filename: str = Field(title="Filename")
|
||||||
|
|
||||||
class HypernetworkItem(BaseModel):
|
class HypernetworkItem(BaseModel):
|
||||||
name: str = Field(title="Name")
|
name: str = Field(title="Name")
|
||||||
path: Optional[str] = Field(title="Path")
|
path: Optional[str] = Field(title="Path")
|
||||||
|
@ -1,10 +1,8 @@
|
|||||||
import html
|
import html
|
||||||
import sys
|
|
||||||
import threading
|
import threading
|
||||||
import traceback
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from modules import shared, progress
|
from modules import shared, progress, errors
|
||||||
|
|
||||||
queue_lock = threading.Lock()
|
queue_lock = threading.Lock()
|
||||||
|
|
||||||
@ -56,16 +54,14 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
|
|||||||
try:
|
try:
|
||||||
res = list(func(*args, **kwargs))
|
res = list(func(*args, **kwargs))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# When printing out our debug argument list, do not print out more than a MB of text
|
# When printing out our debug argument list,
|
||||||
max_debug_str_len = 131072 # (1024*1024)/8
|
# do not print out more than a 100 KB of text
|
||||||
|
max_debug_str_len = 131072
|
||||||
print("Error completing request", file=sys.stderr)
|
message = "Error completing request"
|
||||||
argStr = f"Arguments: {args} {kwargs}"
|
arg_str = f"Arguments: {args} {kwargs}"[:max_debug_str_len]
|
||||||
print(argStr[:max_debug_str_len], file=sys.stderr)
|
if len(arg_str) > max_debug_str_len:
|
||||||
if len(argStr) > max_debug_str_len:
|
arg_str += f" (Argument list truncated at {max_debug_str_len}/{len(arg_str)} characters)"
|
||||||
print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
|
errors.report(f"{message}\n{arg_str}", exc_info=True)
|
||||||
|
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
shared.state.job = ""
|
shared.state.job = ""
|
||||||
shared.state.job_count = 0
|
shared.state.job_count = 0
|
||||||
@ -108,4 +104,3 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
|
|||||||
return tuple(res)
|
return tuple(res)
|
||||||
|
|
||||||
return f
|
return f
|
||||||
|
|
||||||
|
@ -11,7 +11,7 @@ parser.add_argument("--skip-python-version-check", action='store_true', help="la
|
|||||||
parser.add_argument("--skip-torch-cuda-test", action='store_true', help="launch.py argument: do not check if CUDA is able to work properly")
|
parser.add_argument("--skip-torch-cuda-test", action='store_true', help="launch.py argument: do not check if CUDA is able to work properly")
|
||||||
parser.add_argument("--reinstall-xformers", action='store_true', help="launch.py argument: install the appropriate version of xformers even if you have some version already installed")
|
parser.add_argument("--reinstall-xformers", action='store_true', help="launch.py argument: install the appropriate version of xformers even if you have some version already installed")
|
||||||
parser.add_argument("--reinstall-torch", action='store_true', help="launch.py argument: install the appropriate version of torch even if you have some version already installed")
|
parser.add_argument("--reinstall-torch", action='store_true', help="launch.py argument: install the appropriate version of torch even if you have some version already installed")
|
||||||
parser.add_argument("--update-check", action='store_true', help="launch.py argument: chck for updates at startup")
|
parser.add_argument("--update-check", action='store_true', help="launch.py argument: check for updates at startup")
|
||||||
parser.add_argument("--test-server", action='store_true', help="launch.py argument: configure server for testing")
|
parser.add_argument("--test-server", action='store_true', help="launch.py argument: configure server for testing")
|
||||||
parser.add_argument("--skip-prepare-environment", action='store_true', help="launch.py argument: skip all environment preparation")
|
parser.add_argument("--skip-prepare-environment", action='store_true', help="launch.py argument: skip all environment preparation")
|
||||||
parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages")
|
parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages")
|
||||||
@ -62,7 +62,7 @@ parser.add_argument("--opt-split-attention-invokeai", action='store_true', help=
|
|||||||
parser.add_argument("--opt-split-attention-v1", action='store_true', help="prefer older version of split attention optimization for automatic choice of optimization")
|
parser.add_argument("--opt-split-attention-v1", action='store_true', help="prefer older version of split attention optimization for automatic choice of optimization")
|
||||||
parser.add_argument("--opt-sdp-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization for automatic choice of optimization; requires PyTorch 2.*")
|
parser.add_argument("--opt-sdp-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization for automatic choice of optimization; requires PyTorch 2.*")
|
||||||
parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization without memory efficient attention for automatic choice of optimization, makes image generation deterministic; requires PyTorch 2.*")
|
parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization without memory efficient attention for automatic choice of optimization, makes image generation deterministic; requires PyTorch 2.*")
|
||||||
parser.add_argument("--disable-opt-split-attention", action='store_true', help="does not do anything")
|
parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization")
|
||||||
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
|
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
|
||||||
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
|
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
|
||||||
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
|
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
|
||||||
|
@ -1,13 +1,11 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import modules.face_restoration
|
import modules.face_restoration
|
||||||
import modules.shared
|
import modules.shared
|
||||||
from modules import shared, devices, modelloader
|
from modules import shared, devices, modelloader, errors
|
||||||
from modules.paths import models_path
|
from modules.paths import models_path
|
||||||
|
|
||||||
# codeformer people made a choice to include modified basicsr library to their project which makes
|
# codeformer people made a choice to include modified basicsr library to their project which makes
|
||||||
@ -105,8 +103,8 @@ def setup_model(dirname):
|
|||||||
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
|
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
|
||||||
del output
|
del output
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
except Exception as error:
|
except Exception:
|
||||||
print(f'\tFailed inference for CodeFormer: {error}', file=sys.stderr)
|
errors.report('Failed inference for CodeFormer', exc_info=True)
|
||||||
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
|
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
|
||||||
|
|
||||||
restored_face = restored_face.astype('uint8')
|
restored_face = restored_face.astype('uint8')
|
||||||
@ -135,7 +133,6 @@ def setup_model(dirname):
|
|||||||
shared.face_restorers.append(codeformer)
|
shared.face_restorers.append(codeformer)
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
print("Error setting up CodeFormer:", file=sys.stderr)
|
errors.report("Error setting up CodeFormer", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
# sys.path = stored_sys_path
|
# sys.path = stored_sys_path
|
||||||
|
@ -3,8 +3,6 @@ Supports saving and restoring webui and extensions from a known working set of c
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import tqdm
|
import tqdm
|
||||||
@ -13,7 +11,7 @@ from datetime import datetime
|
|||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
import git
|
import git
|
||||||
|
|
||||||
from modules import shared, extensions
|
from modules import shared, extensions, errors
|
||||||
from modules.paths_internal import script_path, config_states_dir
|
from modules.paths_internal import script_path, config_states_dir
|
||||||
|
|
||||||
|
|
||||||
@ -53,8 +51,7 @@ def get_webui_config():
|
|||||||
if os.path.exists(os.path.join(script_path, ".git")):
|
if os.path.exists(os.path.join(script_path, ".git")):
|
||||||
webui_repo = git.Repo(script_path)
|
webui_repo = git.Repo(script_path)
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error reading webui git info from {script_path}:", file=sys.stderr)
|
errors.report(f"Error reading webui git info from {script_path}", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
webui_remote = None
|
webui_remote = None
|
||||||
webui_commit_hash = None
|
webui_commit_hash = None
|
||||||
@ -134,8 +131,7 @@ def restore_webui_config(config):
|
|||||||
if os.path.exists(os.path.join(script_path, ".git")):
|
if os.path.exists(os.path.join(script_path, ".git")):
|
||||||
webui_repo = git.Repo(script_path)
|
webui_repo = git.Repo(script_path)
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error reading webui git info from {script_path}:", file=sys.stderr)
|
errors.report(f"Error reading webui git info from {script_path}", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -143,8 +139,7 @@ def restore_webui_config(config):
|
|||||||
webui_repo.git.reset(webui_commit_hash, hard=True)
|
webui_repo.git.reset(webui_commit_hash, hard=True)
|
||||||
print(f"* Restored webui to commit {webui_commit_hash}.")
|
print(f"* Restored webui to commit {webui_commit_hash}.")
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error restoring webui to commit {webui_commit_hash}:", file=sys.stderr)
|
errors.report(f"Error restoring webui to commit{webui_commit_hash}")
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
|
|
||||||
def restore_extension_config(config):
|
def restore_extension_config(config):
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
import sys
|
import sys
|
||||||
import contextlib
|
import contextlib
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from modules import errors
|
from modules import errors
|
||||||
|
|
||||||
@ -154,3 +156,19 @@ def test_for_nans(x, where):
|
|||||||
message += " Use --disable-nan-check commandline argument to disable this check."
|
message += " Use --disable-nan-check commandline argument to disable this check."
|
||||||
|
|
||||||
raise NansException(message)
|
raise NansException(message)
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
|
||||||
|
def first_time_calculation():
|
||||||
|
"""
|
||||||
|
just do any calculation with pytorch layers - the first time this is done it allocaltes about 700MB of memory and
|
||||||
|
spends about 2.7 seconds doing that, at least wih NVidia.
|
||||||
|
"""
|
||||||
|
|
||||||
|
x = torch.zeros((1, 1)).to(device, dtype)
|
||||||
|
linear = torch.nn.Linear(1, 1).to(device, dtype)
|
||||||
|
linear(x)
|
||||||
|
|
||||||
|
x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
|
||||||
|
conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
|
||||||
|
conv2d(x)
|
||||||
|
@ -1,7 +1,19 @@
|
|||||||
import sys
|
import sys
|
||||||
|
import textwrap
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
|
|
||||||
|
def report(message: str, *, exc_info: bool = False) -> None:
|
||||||
|
"""
|
||||||
|
Print an error message to stderr, with optional traceback.
|
||||||
|
"""
|
||||||
|
for line in message.splitlines():
|
||||||
|
print("***", line, file=sys.stderr)
|
||||||
|
if exc_info:
|
||||||
|
print(textwrap.indent(traceback.format_exc(), " "), file=sys.stderr)
|
||||||
|
print("---", file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
def print_error_explanation(message):
|
def print_error_explanation(message):
|
||||||
lines = message.strip().split("\n")
|
lines = message.strip().split("\n")
|
||||||
max_len = max([len(x) for x in lines])
|
max_len = max([len(x) for x in lines])
|
||||||
@ -12,9 +24,13 @@ def print_error_explanation(message):
|
|||||||
print('=' * max_len, file=sys.stderr)
|
print('=' * max_len, file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
def display(e: Exception, task):
|
def display(e: Exception, task, *, full_traceback=False):
|
||||||
print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
|
print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
te = traceback.TracebackException.from_exception(e)
|
||||||
|
if full_traceback:
|
||||||
|
# include frames leading up to the try-catch block
|
||||||
|
te.stack = traceback.StackSummary(traceback.extract_stack()[:-2] + te.stack)
|
||||||
|
print(*te.format(), sep="", file=sys.stderr)
|
||||||
|
|
||||||
message = str(e)
|
message = str(e)
|
||||||
if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
|
if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
|
||||||
|
@ -1,11 +1,8 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import threading
|
import threading
|
||||||
import traceback
|
|
||||||
|
|
||||||
import git
|
from modules import shared, errors
|
||||||
|
from modules.gitpython_hack import Repo
|
||||||
from modules import shared
|
|
||||||
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
|
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
|
||||||
|
|
||||||
extensions = []
|
extensions = []
|
||||||
@ -54,10 +51,9 @@ class Extension:
|
|||||||
repo = None
|
repo = None
|
||||||
try:
|
try:
|
||||||
if os.path.exists(os.path.join(self.path, ".git")):
|
if os.path.exists(os.path.join(self.path, ".git")):
|
||||||
repo = git.Repo(self.path)
|
repo = Repo(self.path)
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error reading github repository info from {self.path}:", file=sys.stderr)
|
errors.report(f"Error reading github repository info from {self.path}", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
if repo is None or repo.bare:
|
if repo is None or repo.bare:
|
||||||
self.remote = None
|
self.remote = None
|
||||||
@ -65,14 +61,15 @@ class Extension:
|
|||||||
try:
|
try:
|
||||||
self.status = 'unknown'
|
self.status = 'unknown'
|
||||||
self.remote = next(repo.remote().urls, None)
|
self.remote = next(repo.remote().urls, None)
|
||||||
self.commit_date = repo.head.commit.committed_date
|
commit = repo.head.commit
|
||||||
|
self.commit_date = commit.committed_date
|
||||||
if repo.active_branch:
|
if repo.active_branch:
|
||||||
self.branch = repo.active_branch.name
|
self.branch = repo.active_branch.name
|
||||||
self.commit_hash = repo.head.commit.hexsha
|
self.commit_hash = commit.hexsha
|
||||||
self.version = repo.git.describe("--always", "--tags") # compared to `self.commit_hash[:8]` this takes about 30% more time total but since we run it in parallel we don't care
|
self.version = self.commit_hash[:8]
|
||||||
|
|
||||||
except Exception as ex:
|
except Exception:
|
||||||
print(f"Failed reading extension data from Git repository ({self.name}): {ex}", file=sys.stderr)
|
errors.report(f"Failed reading extension data from Git repository ({self.name})", exc_info=True)
|
||||||
self.remote = None
|
self.remote = None
|
||||||
|
|
||||||
self.have_info_from_repo = True
|
self.have_info_from_repo = True
|
||||||
@ -93,7 +90,7 @@ class Extension:
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
def check_updates(self):
|
def check_updates(self):
|
||||||
repo = git.Repo(self.path)
|
repo = Repo(self.path)
|
||||||
for fetch in repo.remote().fetch(dry_run=True):
|
for fetch in repo.remote().fetch(dry_run=True):
|
||||||
if fetch.flags != fetch.HEAD_UPTODATE:
|
if fetch.flags != fetch.HEAD_UPTODATE:
|
||||||
self.can_update = True
|
self.can_update = True
|
||||||
@ -115,7 +112,7 @@ class Extension:
|
|||||||
self.status = "latest"
|
self.status = "latest"
|
||||||
|
|
||||||
def fetch_and_reset_hard(self, commit='origin'):
|
def fetch_and_reset_hard(self, commit='origin'):
|
||||||
repo = git.Repo(self.path)
|
repo = Repo(self.path)
|
||||||
# Fix: `error: Your local changes to the following files would be overwritten by merge`,
|
# Fix: `error: Your local changes to the following files would be overwritten by merge`,
|
||||||
# because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
|
# because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
|
||||||
repo.git.fetch(all=True)
|
repo.git.fetch(all=True)
|
||||||
|
@ -26,7 +26,7 @@ class ExtraNetworkParams:
|
|||||||
self.named = {}
|
self.named = {}
|
||||||
|
|
||||||
for item in self.items:
|
for item in self.items:
|
||||||
parts = item.split('=', 2)
|
parts = item.split('=', 2) if isinstance(item, str) else [item]
|
||||||
if len(parts) == 2:
|
if len(parts) == 2:
|
||||||
self.named[parts[0]] = parts[1]
|
self.named[parts[0]] = parts[1]
|
||||||
else:
|
else:
|
||||||
|
@ -35,7 +35,7 @@ def reset():
|
|||||||
|
|
||||||
|
|
||||||
def quote(text):
|
def quote(text):
|
||||||
if ',' not in str(text) and '\n' not in str(text):
|
if ',' not in str(text) and '\n' not in str(text) and ':' not in str(text):
|
||||||
return text
|
return text
|
||||||
|
|
||||||
return json.dumps(text, ensure_ascii=False)
|
return json.dumps(text, ensure_ascii=False)
|
||||||
@ -306,6 +306,18 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
|||||||
if "RNG" not in res:
|
if "RNG" not in res:
|
||||||
res["RNG"] = "GPU"
|
res["RNG"] = "GPU"
|
||||||
|
|
||||||
|
if "Schedule type" not in res:
|
||||||
|
res["Schedule type"] = "Automatic"
|
||||||
|
|
||||||
|
if "Schedule max sigma" not in res:
|
||||||
|
res["Schedule max sigma"] = 0
|
||||||
|
|
||||||
|
if "Schedule min sigma" not in res:
|
||||||
|
res["Schedule min sigma"] = 0
|
||||||
|
|
||||||
|
if "Schedule rho" not in res:
|
||||||
|
res["Schedule rho"] = 0
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
@ -318,6 +330,10 @@ infotext_to_setting_name_mapping = [
|
|||||||
('Conditional mask weight', 'inpainting_mask_weight'),
|
('Conditional mask weight', 'inpainting_mask_weight'),
|
||||||
('Model hash', 'sd_model_checkpoint'),
|
('Model hash', 'sd_model_checkpoint'),
|
||||||
('ENSD', 'eta_noise_seed_delta'),
|
('ENSD', 'eta_noise_seed_delta'),
|
||||||
|
('Schedule type', 'k_sched_type'),
|
||||||
|
('Schedule max sigma', 'sigma_max'),
|
||||||
|
('Schedule min sigma', 'sigma_min'),
|
||||||
|
('Schedule rho', 'rho'),
|
||||||
('Noise multiplier', 'initial_noise_multiplier'),
|
('Noise multiplier', 'initial_noise_multiplier'),
|
||||||
('Eta', 'eta_ancestral'),
|
('Eta', 'eta_ancestral'),
|
||||||
('Eta DDIM', 'eta_ddim'),
|
('Eta DDIM', 'eta_ddim'),
|
||||||
|
@ -1,12 +1,10 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
import facexlib
|
import facexlib
|
||||||
import gfpgan
|
import gfpgan
|
||||||
|
|
||||||
import modules.face_restoration
|
import modules.face_restoration
|
||||||
from modules import paths, shared, devices, modelloader
|
from modules import paths, shared, devices, modelloader, errors
|
||||||
|
|
||||||
model_dir = "GFPGAN"
|
model_dir = "GFPGAN"
|
||||||
user_path = None
|
user_path = None
|
||||||
@ -112,5 +110,4 @@ def setup_model(dirname):
|
|||||||
|
|
||||||
shared.face_restorers.append(FaceRestorerGFPGAN())
|
shared.face_restorers.append(FaceRestorerGFPGAN())
|
||||||
except Exception:
|
except Exception:
|
||||||
print("Error setting up GFPGAN:", file=sys.stderr)
|
errors.report("Error setting up GFPGAN", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
42
modules/gitpython_hack.py
Normal file
42
modules/gitpython_hack.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import io
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
import git
|
||||||
|
|
||||||
|
|
||||||
|
class Git(git.Git):
|
||||||
|
"""
|
||||||
|
Git subclassed to never use persistent processes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _get_persistent_cmd(self, attr_name, cmd_name, *args, **kwargs):
|
||||||
|
raise NotImplementedError(f"Refusing to use persistent process: {attr_name} ({cmd_name} {args} {kwargs})")
|
||||||
|
|
||||||
|
def get_object_header(self, ref: str | bytes) -> tuple[str, str, int]:
|
||||||
|
ret = subprocess.check_output(
|
||||||
|
[self.GIT_PYTHON_GIT_EXECUTABLE, "cat-file", "--batch-check"],
|
||||||
|
input=self._prepare_ref(ref),
|
||||||
|
cwd=self._working_dir,
|
||||||
|
timeout=2,
|
||||||
|
)
|
||||||
|
return self._parse_object_header(ret)
|
||||||
|
|
||||||
|
def stream_object_data(self, ref: str) -> tuple[str, str, int, "Git.CatFileContentStream"]:
|
||||||
|
# Not really streaming, per se; this buffers the entire object in memory.
|
||||||
|
# Shouldn't be a problem for our use case, since we're only using this for
|
||||||
|
# object headers (commit objects).
|
||||||
|
ret = subprocess.check_output(
|
||||||
|
[self.GIT_PYTHON_GIT_EXECUTABLE, "cat-file", "--batch"],
|
||||||
|
input=self._prepare_ref(ref),
|
||||||
|
cwd=self._working_dir,
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
bio = io.BytesIO(ret)
|
||||||
|
hexsha, typename, size = self._parse_object_header(bio.readline())
|
||||||
|
return (hexsha, typename, size, self.CatFileContentStream(size, bio))
|
||||||
|
|
||||||
|
|
||||||
|
class Repo(git.Repo):
|
||||||
|
GitCommandWrapperType = Git
|
@ -2,8 +2,6 @@ import datetime
|
|||||||
import glob
|
import glob
|
||||||
import html
|
import html
|
||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
import inspect
|
import inspect
|
||||||
|
|
||||||
import modules.textual_inversion.dataset
|
import modules.textual_inversion.dataset
|
||||||
@ -11,7 +9,7 @@ import torch
|
|||||||
import tqdm
|
import tqdm
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
from ldm.util import default
|
from ldm.util import default
|
||||||
from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint
|
from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors
|
||||||
from modules.textual_inversion import textual_inversion, logging
|
from modules.textual_inversion import textual_inversion, logging
|
||||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||||
from torch import einsum
|
from torch import einsum
|
||||||
@ -325,17 +323,14 @@ def load_hypernetwork(name):
|
|||||||
if path is None:
|
if path is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
hypernetwork = Hypernetwork()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
hypernetwork = Hypernetwork()
|
||||||
hypernetwork.load(path)
|
hypernetwork.load(path)
|
||||||
|
return hypernetwork
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error loading hypernetwork {path}", file=sys.stderr)
|
errors.report(f"Error loading hypernetwork {path}", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return hypernetwork
|
|
||||||
|
|
||||||
|
|
||||||
def load_hypernetworks(names, multipliers=None):
|
def load_hypernetworks(names, multipliers=None):
|
||||||
already_loaded = {}
|
already_loaded = {}
|
||||||
@ -770,7 +765,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
|||||||
</p>
|
</p>
|
||||||
"""
|
"""
|
||||||
except Exception:
|
except Exception:
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
errors.report("Exception in training hypernetwork", exc_info=True)
|
||||||
finally:
|
finally:
|
||||||
pbar.leave = False
|
pbar.leave = False
|
||||||
pbar.close()
|
pbar.close()
|
||||||
|
@ -1,6 +1,4 @@
|
|||||||
import datetime
|
import datetime
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
import pytz
|
import pytz
|
||||||
import io
|
import io
|
||||||
@ -21,6 +19,8 @@ from modules import sd_samplers, shared, script_callbacks, errors
|
|||||||
from modules.paths_internal import roboto_ttf_file
|
from modules.paths_internal import roboto_ttf_file
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
|
|
||||||
|
import modules.sd_vae as sd_vae
|
||||||
|
|
||||||
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
|
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
|
||||||
|
|
||||||
|
|
||||||
@ -336,8 +336,20 @@ def sanitize_filename_part(text, replace_spaces=True):
|
|||||||
|
|
||||||
|
|
||||||
class FilenameGenerator:
|
class FilenameGenerator:
|
||||||
|
def get_vae_filename(self): #get the name of the VAE file.
|
||||||
|
if sd_vae.loaded_vae_file is None:
|
||||||
|
return "NoneType"
|
||||||
|
file_name = os.path.basename(sd_vae.loaded_vae_file)
|
||||||
|
split_file_name = file_name.split('.')
|
||||||
|
if len(split_file_name) > 1 and split_file_name[0] == '':
|
||||||
|
return split_file_name[1] # if the first character of the filename is "." then [1] is obtained.
|
||||||
|
else:
|
||||||
|
return split_file_name[0]
|
||||||
|
|
||||||
replacements = {
|
replacements = {
|
||||||
'seed': lambda self: self.seed if self.seed is not None else '',
|
'seed': lambda self: self.seed if self.seed is not None else '',
|
||||||
|
'seed_first': lambda self: self.seed if self.p.batch_size == 1 else self.p.all_seeds[0],
|
||||||
|
'seed_last': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 else self.p.all_seeds[-1],
|
||||||
'steps': lambda self: self.p and self.p.steps,
|
'steps': lambda self: self.p and self.p.steps,
|
||||||
'cfg': lambda self: self.p and self.p.cfg_scale,
|
'cfg': lambda self: self.p and self.p.cfg_scale,
|
||||||
'width': lambda self: self.image.width,
|
'width': lambda self: self.image.width,
|
||||||
@ -354,19 +366,23 @@ class FilenameGenerator:
|
|||||||
'prompt_no_styles': lambda self: self.prompt_no_style(),
|
'prompt_no_styles': lambda self: self.prompt_no_style(),
|
||||||
'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
|
'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
|
||||||
'prompt_words': lambda self: self.prompt_words(),
|
'prompt_words': lambda self: self.prompt_words(),
|
||||||
'batch_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 else self.p.batch_index + 1,
|
'batch_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 or self.zip else self.p.batch_index + 1,
|
||||||
'generation_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.n_iter == 1 and self.p.batch_size == 1 else self.p.iteration * self.p.batch_size + self.p.batch_index + 1,
|
'batch_size': lambda self: self.p.batch_size,
|
||||||
|
'generation_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if (self.p.n_iter == 1 and self.p.batch_size == 1) or self.zip else self.p.iteration * self.p.batch_size + self.p.batch_index + 1,
|
||||||
'hasprompt': lambda self, *args: self.hasprompt(*args), # accepts formats:[hasprompt<prompt1|default><prompt2>..]
|
'hasprompt': lambda self, *args: self.hasprompt(*args), # accepts formats:[hasprompt<prompt1|default><prompt2>..]
|
||||||
'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
|
'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
|
||||||
'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
|
'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
|
||||||
|
'vae_filename': lambda self: self.get_vae_filename(),
|
||||||
|
|
||||||
}
|
}
|
||||||
default_time_format = '%Y%m%d%H%M%S'
|
default_time_format = '%Y%m%d%H%M%S'
|
||||||
|
|
||||||
def __init__(self, p, seed, prompt, image):
|
def __init__(self, p, seed, prompt, image, zip=False):
|
||||||
self.p = p
|
self.p = p
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
self.prompt = prompt
|
self.prompt = prompt
|
||||||
self.image = image
|
self.image = image
|
||||||
|
self.zip = zip
|
||||||
|
|
||||||
def hasprompt(self, *args):
|
def hasprompt(self, *args):
|
||||||
lower = self.prompt.lower()
|
lower = self.prompt.lower()
|
||||||
@ -446,8 +462,7 @@ class FilenameGenerator:
|
|||||||
replacement = fun(self, *pattern_args)
|
replacement = fun(self, *pattern_args)
|
||||||
except Exception:
|
except Exception:
|
||||||
replacement = None
|
replacement = None
|
||||||
print(f"Error adding [{pattern}] to filename", file=sys.stderr)
|
errors.report(f"Error adding [{pattern}] to filename", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
if replacement == NOTHING_AND_SKIP_PREVIOUS_TEXT:
|
if replacement == NOTHING_AND_SKIP_PREVIOUS_TEXT:
|
||||||
continue
|
continue
|
||||||
@ -488,14 +503,13 @@ def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_p
|
|||||||
|
|
||||||
image_format = Image.registered_extensions()[extension]
|
image_format = Image.registered_extensions()[extension]
|
||||||
|
|
||||||
existing_pnginfo = existing_pnginfo or {}
|
|
||||||
if opts.enable_pnginfo:
|
|
||||||
existing_pnginfo['parameters'] = geninfo
|
|
||||||
|
|
||||||
if extension.lower() == '.png':
|
if extension.lower() == '.png':
|
||||||
pnginfo_data = PngImagePlugin.PngInfo()
|
if opts.enable_pnginfo:
|
||||||
for k, v in (existing_pnginfo or {}).items():
|
pnginfo_data = PngImagePlugin.PngInfo()
|
||||||
pnginfo_data.add_text(k, str(v))
|
for k, v in (existing_pnginfo or {}).items():
|
||||||
|
pnginfo_data.add_text(k, str(v))
|
||||||
|
else:
|
||||||
|
pnginfo_data = None
|
||||||
|
|
||||||
image.save(filename, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
|
image.save(filename, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
|
||||||
|
|
||||||
@ -665,9 +679,10 @@ def read_info_from_image(image):
|
|||||||
items['exif comment'] = exif_comment
|
items['exif comment'] = exif_comment
|
||||||
geninfo = exif_comment
|
geninfo = exif_comment
|
||||||
|
|
||||||
for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
|
for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
|
||||||
'loop', 'background', 'timestamp', 'duration']:
|
'loop', 'background', 'timestamp', 'duration', 'progressive', 'progression',
|
||||||
items.pop(field, None)
|
'icc_profile', 'chromaticity']:
|
||||||
|
items.pop(field, None)
|
||||||
|
|
||||||
if items.get("Software", None) == "NovelAI":
|
if items.get("Software", None) == "NovelAI":
|
||||||
try:
|
try:
|
||||||
@ -678,8 +693,7 @@ def read_info_from_image(image):
|
|||||||
Negative prompt: {json_info["uc"]}
|
Negative prompt: {json_info["uc"]}
|
||||||
Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337"""
|
Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337"""
|
||||||
except Exception:
|
except Exception:
|
||||||
print("Error parsing NovelAI image generation parameters:", file=sys.stderr)
|
errors.report("Error parsing NovelAI image generation parameters", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
return geninfo, items
|
return geninfo, items
|
||||||
|
|
||||||
|
@ -92,7 +92,8 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
|
|||||||
elif mode == 2: # inpaint
|
elif mode == 2: # inpaint
|
||||||
image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
|
image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
|
||||||
alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
|
alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
|
||||||
mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
|
mask = mask.convert('L').point(lambda x: 255 if x > 128 else 0, mode='1')
|
||||||
|
mask = ImageChops.lighter(alpha_mask, mask).convert('L')
|
||||||
image = image.convert("RGB")
|
image = image.convert("RGB")
|
||||||
elif mode == 3: # inpaint sketch
|
elif mode == 3: # inpaint sketch
|
||||||
image = inpaint_color_sketch
|
image = inpaint_color_sketch
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import re
|
import re
|
||||||
@ -216,8 +215,7 @@ class InterrogateModels:
|
|||||||
res += f", {match}"
|
res += f", {match}"
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
print("Error interrogating", file=sys.stderr)
|
errors.report("Error interrogating", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
res += "<error>"
|
res += "<error>"
|
||||||
|
|
||||||
self.unload()
|
self.unload()
|
||||||
|
331
modules/launch_utils.py
Normal file
331
modules/launch_utils.py
Normal file
@ -0,0 +1,331 @@
|
|||||||
|
# this scripts installs necessary requirements and launches main program in webui.py
|
||||||
|
import subprocess
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import importlib.util
|
||||||
|
import platform
|
||||||
|
import json
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
from modules import cmd_args, errors
|
||||||
|
from modules.paths_internal import script_path, extensions_dir
|
||||||
|
|
||||||
|
args, _ = cmd_args.parser.parse_known_args()
|
||||||
|
|
||||||
|
python = sys.executable
|
||||||
|
git = os.environ.get('GIT', "git")
|
||||||
|
index_url = os.environ.get('INDEX_URL', "")
|
||||||
|
dir_repos = "repositories"
|
||||||
|
|
||||||
|
# Whether to default to printing command output
|
||||||
|
default_command_live = (os.environ.get('WEBUI_LAUNCH_LIVE_OUTPUT') == "1")
|
||||||
|
|
||||||
|
if 'GRADIO_ANALYTICS_ENABLED' not in os.environ:
|
||||||
|
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
|
||||||
|
|
||||||
|
|
||||||
|
def check_python_version():
|
||||||
|
is_windows = platform.system() == "Windows"
|
||||||
|
major = sys.version_info.major
|
||||||
|
minor = sys.version_info.minor
|
||||||
|
micro = sys.version_info.micro
|
||||||
|
|
||||||
|
if is_windows:
|
||||||
|
supported_minors = [10]
|
||||||
|
else:
|
||||||
|
supported_minors = [7, 8, 9, 10, 11]
|
||||||
|
|
||||||
|
if not (major == 3 and minor in supported_minors):
|
||||||
|
import modules.errors
|
||||||
|
|
||||||
|
modules.errors.print_error_explanation(f"""
|
||||||
|
INCOMPATIBLE PYTHON VERSION
|
||||||
|
|
||||||
|
This program is tested with 3.10.6 Python, but you have {major}.{minor}.{micro}.
|
||||||
|
If you encounter an error with "RuntimeError: Couldn't install torch." message,
|
||||||
|
or any other error regarding unsuccessful package (library) installation,
|
||||||
|
please downgrade (or upgrade) to the latest version of 3.10 Python
|
||||||
|
and delete current Python and "venv" folder in WebUI's directory.
|
||||||
|
|
||||||
|
You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3106/
|
||||||
|
|
||||||
|
{"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases" if is_windows else ""}
|
||||||
|
|
||||||
|
Use --skip-python-version-check to suppress this warning.
|
||||||
|
""")
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def commit_hash():
|
||||||
|
try:
|
||||||
|
return subprocess.check_output([git, "rev-parse", "HEAD"], shell=False, encoding='utf8').strip()
|
||||||
|
except Exception:
|
||||||
|
return "<none>"
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def git_tag():
|
||||||
|
try:
|
||||||
|
return subprocess.check_output([git, "describe", "--tags"], shell=False, encoding='utf8').strip()
|
||||||
|
except Exception:
|
||||||
|
return "<none>"
|
||||||
|
|
||||||
|
|
||||||
|
def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live) -> str:
|
||||||
|
if desc is not None:
|
||||||
|
print(desc)
|
||||||
|
|
||||||
|
run_kwargs = {
|
||||||
|
"args": command,
|
||||||
|
"shell": True,
|
||||||
|
"env": os.environ if custom_env is None else custom_env,
|
||||||
|
"encoding": 'utf8',
|
||||||
|
"errors": 'ignore',
|
||||||
|
}
|
||||||
|
|
||||||
|
if not live:
|
||||||
|
run_kwargs["stdout"] = run_kwargs["stderr"] = subprocess.PIPE
|
||||||
|
|
||||||
|
result = subprocess.run(**run_kwargs)
|
||||||
|
|
||||||
|
if result.returncode != 0:
|
||||||
|
error_bits = [
|
||||||
|
f"{errdesc or 'Error running command'}.",
|
||||||
|
f"Command: {command}",
|
||||||
|
f"Error code: {result.returncode}",
|
||||||
|
]
|
||||||
|
if result.stdout:
|
||||||
|
error_bits.append(f"stdout: {result.stdout}")
|
||||||
|
if result.stderr:
|
||||||
|
error_bits.append(f"stderr: {result.stderr}")
|
||||||
|
raise RuntimeError("\n".join(error_bits))
|
||||||
|
|
||||||
|
return (result.stdout or "")
|
||||||
|
|
||||||
|
|
||||||
|
def is_installed(package):
|
||||||
|
try:
|
||||||
|
spec = importlib.util.find_spec(package)
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return spec is not None
|
||||||
|
|
||||||
|
|
||||||
|
def repo_dir(name):
|
||||||
|
return os.path.join(script_path, dir_repos, name)
|
||||||
|
|
||||||
|
|
||||||
|
def run_pip(command, desc=None, live=default_command_live):
|
||||||
|
if args.skip_install:
|
||||||
|
return
|
||||||
|
|
||||||
|
index_url_line = f' --index-url {index_url}' if index_url != '' else ''
|
||||||
|
return run(f'"{python}" -m pip {command} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}", live=live)
|
||||||
|
|
||||||
|
|
||||||
|
def check_run_python(code: str) -> bool:
|
||||||
|
result = subprocess.run([python, "-c", code], capture_output=True, shell=False)
|
||||||
|
return result.returncode == 0
|
||||||
|
|
||||||
|
|
||||||
|
def git_clone(url, dir, name, commithash=None):
|
||||||
|
# TODO clone into temporary dir and move if successful
|
||||||
|
|
||||||
|
if os.path.exists(dir):
|
||||||
|
if commithash is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip()
|
||||||
|
if current_hash == commithash:
|
||||||
|
return
|
||||||
|
|
||||||
|
run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
|
||||||
|
run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
|
||||||
|
return
|
||||||
|
|
||||||
|
run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}")
|
||||||
|
|
||||||
|
if commithash is not None:
|
||||||
|
run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
|
||||||
|
|
||||||
|
|
||||||
|
def git_pull_recursive(dir):
|
||||||
|
for subdir, _, _ in os.walk(dir):
|
||||||
|
if os.path.exists(os.path.join(subdir, '.git')):
|
||||||
|
try:
|
||||||
|
output = subprocess.check_output([git, '-C', subdir, 'pull', '--autostash'])
|
||||||
|
print(f"Pulled changes for repository in '{subdir}':\n{output.decode('utf-8').strip()}\n")
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
print(f"Couldn't perform 'git pull' on repository in '{subdir}':\n{e.output.decode('utf-8').strip()}\n")
|
||||||
|
|
||||||
|
|
||||||
|
def version_check(commit):
|
||||||
|
try:
|
||||||
|
import requests
|
||||||
|
commits = requests.get('https://api.github.com/repos/AUTOMATIC1111/stable-diffusion-webui/branches/master').json()
|
||||||
|
if commit != "<none>" and commits['commit']['sha'] != commit:
|
||||||
|
print("--------------------------------------------------------")
|
||||||
|
print("| You are not up to date with the most recent release. |")
|
||||||
|
print("| Consider running `git pull` to update. |")
|
||||||
|
print("--------------------------------------------------------")
|
||||||
|
elif commits['commit']['sha'] == commit:
|
||||||
|
print("You are up to date with the most recent release.")
|
||||||
|
else:
|
||||||
|
print("Not a git clone, can't perform version check.")
|
||||||
|
except Exception as e:
|
||||||
|
print("version check failed", e)
|
||||||
|
|
||||||
|
|
||||||
|
def run_extension_installer(extension_dir):
|
||||||
|
path_installer = os.path.join(extension_dir, "install.py")
|
||||||
|
if not os.path.isfile(path_installer):
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
env = os.environ.copy()
|
||||||
|
env['PYTHONPATH'] = os.path.abspath(".")
|
||||||
|
|
||||||
|
print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env))
|
||||||
|
except Exception as e:
|
||||||
|
errors.report(str(e))
|
||||||
|
|
||||||
|
|
||||||
|
def list_extensions(settings_file):
|
||||||
|
settings = {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
if os.path.isfile(settings_file):
|
||||||
|
with open(settings_file, "r", encoding="utf8") as file:
|
||||||
|
settings = json.load(file)
|
||||||
|
except Exception:
|
||||||
|
errors.report("Could not load settings", exc_info=True)
|
||||||
|
|
||||||
|
disabled_extensions = set(settings.get('disabled_extensions', []))
|
||||||
|
disable_all_extensions = settings.get('disable_all_extensions', 'none')
|
||||||
|
|
||||||
|
if disable_all_extensions != 'none':
|
||||||
|
return []
|
||||||
|
|
||||||
|
return [x for x in os.listdir(extensions_dir) if x not in disabled_extensions]
|
||||||
|
|
||||||
|
|
||||||
|
def run_extensions_installers(settings_file):
|
||||||
|
if not os.path.isdir(extensions_dir):
|
||||||
|
return
|
||||||
|
|
||||||
|
for dirname_extension in list_extensions(settings_file):
|
||||||
|
run_extension_installer(os.path.join(extensions_dir, dirname_extension))
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_environment():
|
||||||
|
torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118")
|
||||||
|
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}")
|
||||||
|
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
|
||||||
|
|
||||||
|
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.20')
|
||||||
|
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "https://github.com/TencentARC/GFPGAN/archive/8d2447a2d918f8eba5a4a01463fd48e45126a379.zip")
|
||||||
|
clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
|
||||||
|
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
|
||||||
|
|
||||||
|
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
|
||||||
|
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
|
||||||
|
codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
|
||||||
|
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
|
||||||
|
|
||||||
|
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
|
||||||
|
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "c9fe758757e022f05ca5a53fa8fac28889e4f1cf")
|
||||||
|
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
|
||||||
|
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
|
||||||
|
|
||||||
|
if not args.skip_python_version_check:
|
||||||
|
check_python_version()
|
||||||
|
|
||||||
|
commit = commit_hash()
|
||||||
|
tag = git_tag()
|
||||||
|
|
||||||
|
print(f"Python {sys.version}")
|
||||||
|
print(f"Version: {tag}")
|
||||||
|
print(f"Commit hash: {commit}")
|
||||||
|
|
||||||
|
if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
|
||||||
|
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
|
||||||
|
|
||||||
|
if not args.skip_torch_cuda_test and not check_run_python("import torch; assert torch.cuda.is_available()"):
|
||||||
|
raise RuntimeError(
|
||||||
|
'Torch is not able to use GPU; '
|
||||||
|
'add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'
|
||||||
|
)
|
||||||
|
|
||||||
|
if not is_installed("gfpgan"):
|
||||||
|
run_pip(f"install {gfpgan_package}", "gfpgan")
|
||||||
|
|
||||||
|
if not is_installed("clip"):
|
||||||
|
run_pip(f"install {clip_package}", "clip")
|
||||||
|
|
||||||
|
if not is_installed("open_clip"):
|
||||||
|
run_pip(f"install {openclip_package}", "open_clip")
|
||||||
|
|
||||||
|
if (not is_installed("xformers") or args.reinstall_xformers) and args.xformers:
|
||||||
|
if platform.system() == "Windows":
|
||||||
|
if platform.python_version().startswith("3.10"):
|
||||||
|
run_pip(f"install -U -I --no-deps {xformers_package}", "xformers", live=True)
|
||||||
|
else:
|
||||||
|
print("Installation of xformers is not supported in this version of Python.")
|
||||||
|
print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness")
|
||||||
|
if not is_installed("xformers"):
|
||||||
|
exit(0)
|
||||||
|
elif platform.system() == "Linux":
|
||||||
|
run_pip(f"install -U -I --no-deps {xformers_package}", "xformers")
|
||||||
|
|
||||||
|
if not is_installed("ngrok") and args.ngrok:
|
||||||
|
run_pip("install ngrok", "ngrok")
|
||||||
|
|
||||||
|
os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
|
||||||
|
|
||||||
|
git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
|
||||||
|
git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
|
||||||
|
git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
|
||||||
|
git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
|
||||||
|
|
||||||
|
if not is_installed("lpips"):
|
||||||
|
run_pip(f"install -r \"{os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}\"", "requirements for CodeFormer")
|
||||||
|
|
||||||
|
if not os.path.isfile(requirements_file):
|
||||||
|
requirements_file = os.path.join(script_path, requirements_file)
|
||||||
|
run_pip(f"install -r \"{requirements_file}\"", "requirements")
|
||||||
|
|
||||||
|
run_extensions_installers(settings_file=args.ui_settings_file)
|
||||||
|
|
||||||
|
if args.update_check:
|
||||||
|
version_check(commit)
|
||||||
|
|
||||||
|
if args.update_all_extensions:
|
||||||
|
git_pull_recursive(extensions_dir)
|
||||||
|
|
||||||
|
if "--exit" in sys.argv:
|
||||||
|
print("Exiting because of --exit argument")
|
||||||
|
exit(0)
|
||||||
|
|
||||||
|
|
||||||
|
def configure_for_tests():
|
||||||
|
if "--api" not in sys.argv:
|
||||||
|
sys.argv.append("--api")
|
||||||
|
if "--ckpt" not in sys.argv:
|
||||||
|
sys.argv.append("--ckpt")
|
||||||
|
sys.argv.append(os.path.join(script_path, "test/test_files/empty.pt"))
|
||||||
|
if "--skip-torch-cuda-test" not in sys.argv:
|
||||||
|
sys.argv.append("--skip-torch-cuda-test")
|
||||||
|
if "--disable-nan-check" not in sys.argv:
|
||||||
|
sys.argv.append("--disable-nan-check")
|
||||||
|
|
||||||
|
os.environ['COMMANDLINE_ARGS'] = ""
|
||||||
|
|
||||||
|
|
||||||
|
def start():
|
||||||
|
print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {' '.join(sys.argv[1:])}")
|
||||||
|
import webui
|
||||||
|
if '--nowebui' in sys.argv:
|
||||||
|
webui.api_only()
|
||||||
|
else:
|
||||||
|
webui.webui()
|
@ -1,8 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
|
from modules import errors
|
||||||
|
|
||||||
localizations = {}
|
localizations = {}
|
||||||
|
|
||||||
@ -31,7 +30,6 @@ def localization_js(current_localization_name: str) -> str:
|
|||||||
with open(fn, "r", encoding="utf8") as file:
|
with open(fn, "r", encoding="utf8") as file:
|
||||||
data = json.load(file)
|
data = json.load(file)
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error loading localization from {fn}:", file=sys.stderr)
|
errors.report(f"Error loading localization from {fn}", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
return f"window.localization = {json.dumps(data)}"
|
return f"window.localization = {json.dumps(data)}"
|
||||||
|
@ -20,7 +20,6 @@ assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possibl
|
|||||||
|
|
||||||
path_dirs = [
|
path_dirs = [
|
||||||
(sd_path, 'ldm', 'Stable Diffusion', []),
|
(sd_path, 'ldm', 'Stable Diffusion', []),
|
||||||
(os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers', []),
|
|
||||||
(os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
|
(os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
|
||||||
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
|
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
|
||||||
(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
|
(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
@ -13,7 +14,7 @@ from skimage import exposure
|
|||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
import modules.sd_hijack
|
import modules.sd_hijack
|
||||||
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common
|
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet
|
||||||
from modules.sd_hijack import model_hijack
|
from modules.sd_hijack import model_hijack
|
||||||
from modules.shared import opts, cmd_opts, state
|
from modules.shared import opts, cmd_opts, state
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
@ -23,7 +24,6 @@ import modules.images as images
|
|||||||
import modules.styles
|
import modules.styles
|
||||||
import modules.sd_models as sd_models
|
import modules.sd_models as sd_models
|
||||||
import modules.sd_vae as sd_vae
|
import modules.sd_vae as sd_vae
|
||||||
import logging
|
|
||||||
from ldm.data.util import AddMiDaS
|
from ldm.data.util import AddMiDaS
|
||||||
from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
|
from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
|
||||||
|
|
||||||
@ -321,14 +321,13 @@ class StableDiffusionProcessing:
|
|||||||
have been used before. The second element is where the previously
|
have been used before. The second element is where the previously
|
||||||
computed result is stored.
|
computed result is stored.
|
||||||
"""
|
"""
|
||||||
|
if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info) == cache[0]:
|
||||||
if cache[0] is not None and (required_prompts, steps) == cache[0]:
|
|
||||||
return cache[1]
|
return cache[1]
|
||||||
|
|
||||||
with devices.autocast():
|
with devices.autocast():
|
||||||
cache[1] = function(shared.sd_model, required_prompts, steps)
|
cache[1] = function(shared.sd_model, required_prompts, steps)
|
||||||
|
|
||||||
cache[0] = (required_prompts, steps)
|
cache[0] = (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info)
|
||||||
return cache[1]
|
return cache[1]
|
||||||
|
|
||||||
def setup_conds(self):
|
def setup_conds(self):
|
||||||
@ -589,11 +588,15 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
|||||||
|
|
||||||
|
|
||||||
def process_images(p: StableDiffusionProcessing) -> Processed:
|
def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||||
|
if p.scripts is not None:
|
||||||
|
p.scripts.before_process(p)
|
||||||
|
|
||||||
stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}
|
stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
|
# if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
|
||||||
if sd_models.checkpoint_alisases.get(p.override_settings.get('sd_model_checkpoint')) is None:
|
override_checkpoint = p.override_settings.get('sd_model_checkpoint')
|
||||||
|
if override_checkpoint is not None and sd_models.checkpoint_alisases.get(override_checkpoint) is None:
|
||||||
p.override_settings.pop('sd_model_checkpoint', None)
|
p.override_settings.pop('sd_model_checkpoint', None)
|
||||||
sd_models.reload_model_weights()
|
sd_models.reload_model_weights()
|
||||||
|
|
||||||
@ -674,6 +677,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN":
|
if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN":
|
||||||
sd_vae_approx.model()
|
sd_vae_approx.model()
|
||||||
|
|
||||||
|
sd_unet.apply_unet()
|
||||||
|
|
||||||
if state.job_count == -1:
|
if state.job_count == -1:
|
||||||
state.job_count = p.n_iter
|
state.job_count = p.n_iter
|
||||||
|
|
||||||
|
@ -1,6 +1,4 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@ -9,7 +7,8 @@ from realesrgan import RealESRGANer
|
|||||||
|
|
||||||
from modules.upscaler import Upscaler, UpscalerData
|
from modules.upscaler import Upscaler, UpscalerData
|
||||||
from modules.shared import cmd_opts, opts
|
from modules.shared import cmd_opts, opts
|
||||||
from modules import modelloader
|
from modules import modelloader, errors
|
||||||
|
|
||||||
|
|
||||||
class UpscalerRealESRGAN(Upscaler):
|
class UpscalerRealESRGAN(Upscaler):
|
||||||
def __init__(self, path):
|
def __init__(self, path):
|
||||||
@ -36,8 +35,7 @@ class UpscalerRealESRGAN(Upscaler):
|
|||||||
self.scalers.append(scaler)
|
self.scalers.append(scaler)
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
print("Error importing Real-ESRGAN:", file=sys.stderr)
|
errors.report("Error importing Real-ESRGAN", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
self.enable = False
|
self.enable = False
|
||||||
self.scalers = []
|
self.scalers = []
|
||||||
|
|
||||||
@ -76,9 +74,8 @@ class UpscalerRealESRGAN(Upscaler):
|
|||||||
info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_download_path, progress=True)
|
info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_download_path, progress=True)
|
||||||
|
|
||||||
return info
|
return info
|
||||||
except Exception as e:
|
except Exception:
|
||||||
print(f"Error making Real-ESRGAN models list: {e}", file=sys.stderr)
|
errors.report("Error making Real-ESRGAN models list", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def load_models(self, _):
|
def load_models(self, _):
|
||||||
@ -135,5 +132,4 @@ def get_realesrgan_models(scaler):
|
|||||||
]
|
]
|
||||||
return models
|
return models
|
||||||
except Exception:
|
except Exception:
|
||||||
print("Error making Real-ESRGAN models list:", file=sys.stderr)
|
errors.report("Error making Real-ESRGAN models list", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
@ -2,8 +2,6 @@
|
|||||||
|
|
||||||
import pickle
|
import pickle
|
||||||
import collections
|
import collections
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import numpy
|
import numpy
|
||||||
@ -11,7 +9,10 @@ import _codecs
|
|||||||
import zipfile
|
import zipfile
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
|
||||||
# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
|
# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
|
||||||
|
from modules import errors
|
||||||
|
|
||||||
TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
|
TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
|
||||||
|
|
||||||
def encode(*args):
|
def encode(*args):
|
||||||
@ -136,17 +137,20 @@ def load_with_extra(filename, extra_handler=None, *args, **kwargs):
|
|||||||
check_pt(filename, extra_handler)
|
check_pt(filename, extra_handler)
|
||||||
|
|
||||||
except pickle.UnpicklingError:
|
except pickle.UnpicklingError:
|
||||||
print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
|
errors.report(
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
f"Error verifying pickled file from {filename}\n"
|
||||||
print("-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr)
|
"-----> !!!! The file is most likely corrupted !!!! <-----\n"
|
||||||
print("You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr)
|
"You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
|
errors.report(
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
f"Error verifying pickled file from {filename}\n"
|
||||||
print("\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
|
f"The file may be malicious, so the program is not going to read it.\n"
|
||||||
print("You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr)
|
f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return unsafe_torch_load(filename, *args, **kwargs)
|
return unsafe_torch_load(filename, *args, **kwargs)
|
||||||
@ -190,4 +194,3 @@ with safe.Extra(handler):
|
|||||||
unsafe_torch_load = torch.load
|
unsafe_torch_load = torch.load
|
||||||
torch.load = load
|
torch.load = load
|
||||||
global_extra_handler = None
|
global_extra_handler = None
|
||||||
|
|
||||||
|
@ -1,18 +1,15 @@
|
|||||||
import sys
|
|
||||||
import traceback
|
|
||||||
from collections import namedtuple
|
|
||||||
import inspect
|
import inspect
|
||||||
|
from collections import namedtuple
|
||||||
from typing import Optional, Dict, Any
|
from typing import Optional, Dict, Any
|
||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from gradio import Blocks
|
from gradio import Blocks
|
||||||
|
|
||||||
from modules import timer
|
from modules import errors, timer
|
||||||
|
|
||||||
|
|
||||||
def report_exception(c, job):
|
def report_exception(c, job):
|
||||||
print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
|
errors.report(f"Error executing callback {job} for {c.script}", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
|
|
||||||
class ImageSaveParams:
|
class ImageSaveParams:
|
||||||
@ -113,6 +110,7 @@ callback_map = dict(
|
|||||||
callbacks_before_ui=[],
|
callbacks_before_ui=[],
|
||||||
callbacks_on_reload=[],
|
callbacks_on_reload=[],
|
||||||
callbacks_list_optimizers=[],
|
callbacks_list_optimizers=[],
|
||||||
|
callbacks_list_unets=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -274,6 +272,18 @@ def list_optimizers_callback():
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def list_unets_callback():
|
||||||
|
res = []
|
||||||
|
|
||||||
|
for c in callback_map['callbacks_list_unets']:
|
||||||
|
try:
|
||||||
|
c.callback(res)
|
||||||
|
except Exception:
|
||||||
|
report_exception(c, 'list_unets')
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
def add_callback(callbacks, fun):
|
def add_callback(callbacks, fun):
|
||||||
stack = [x for x in inspect.stack() if x.filename != __file__]
|
stack = [x for x in inspect.stack() if x.filename != __file__]
|
||||||
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
|
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
|
||||||
@ -433,3 +443,10 @@ def on_list_optimizers(callback):
|
|||||||
to it."""
|
to it."""
|
||||||
|
|
||||||
add_callback(callback_map['callbacks_list_optimizers'], callback)
|
add_callback(callback_map['callbacks_list_optimizers'], callback)
|
||||||
|
|
||||||
|
|
||||||
|
def on_list_unets(callback):
|
||||||
|
"""register a function to be called when UI is making a list of alternative options for unet.
|
||||||
|
The function will be called with one argument, a list, and shall add objects of type modules.sd_unet.SdUnetOption to it."""
|
||||||
|
|
||||||
|
add_callback(callback_map['callbacks_list_unets'], callback)
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
import importlib.util
|
import importlib.util
|
||||||
|
|
||||||
|
from modules import errors
|
||||||
|
|
||||||
|
|
||||||
def load_module(path):
|
def load_module(path):
|
||||||
module_spec = importlib.util.spec_from_file_location(os.path.basename(path), path)
|
module_spec = importlib.util.spec_from_file_location(os.path.basename(path), path)
|
||||||
@ -27,5 +27,4 @@ def preload_extensions(extensions_dir, parser):
|
|||||||
module.preload(parser)
|
module.preload(parser)
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error running preload() for {preload_script}", file=sys.stderr)
|
errors.report(f"Error running preload() for {preload_script}", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
@ -1,12 +1,11 @@
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing, timer
|
from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer
|
||||||
|
|
||||||
AlwaysVisible = object()
|
AlwaysVisible = object()
|
||||||
|
|
||||||
@ -20,6 +19,9 @@ class Script:
|
|||||||
name = None
|
name = None
|
||||||
"""script's internal name derived from title"""
|
"""script's internal name derived from title"""
|
||||||
|
|
||||||
|
section = None
|
||||||
|
"""name of UI section that the script's controls will be placed into"""
|
||||||
|
|
||||||
filename = None
|
filename = None
|
||||||
args_from = None
|
args_from = None
|
||||||
args_to = None
|
args_to = None
|
||||||
@ -82,6 +84,15 @@ class Script:
|
|||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def before_process(self, p, *args):
|
||||||
|
"""
|
||||||
|
This function is called very early before processing begins for AlwaysVisible scripts.
|
||||||
|
You can modify the processing object (p) here, inject hooks, etc.
|
||||||
|
args contains all values returned by components from ui()
|
||||||
|
"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
def process(self, p, *args):
|
def process(self, p, *args):
|
||||||
"""
|
"""
|
||||||
This function is called before processing begins for AlwaysVisible scripts.
|
This function is called before processing begins for AlwaysVisible scripts.
|
||||||
@ -264,8 +275,7 @@ def load_scripts():
|
|||||||
register_scripts_from_module(script_module)
|
register_scripts_from_module(script_module)
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error loading script: {scriptfile.filename}", file=sys.stderr)
|
errors.report(f"Error loading script: {scriptfile.filename}", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
sys.path = syspath
|
sys.path = syspath
|
||||||
@ -281,11 +291,9 @@ def load_scripts():
|
|||||||
|
|
||||||
def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
|
def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
|
||||||
try:
|
try:
|
||||||
res = func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
return res
|
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error calling: {filename}/{funcname}", file=sys.stderr)
|
errors.report(f"Error calling: {filename}/{funcname}", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
return default
|
return default
|
||||||
|
|
||||||
@ -298,6 +306,7 @@ class ScriptRunner:
|
|||||||
self.titles = []
|
self.titles = []
|
||||||
self.infotext_fields = []
|
self.infotext_fields = []
|
||||||
self.paste_field_names = []
|
self.paste_field_names = []
|
||||||
|
self.inputs = [None]
|
||||||
|
|
||||||
def initialize_scripts(self, is_img2img):
|
def initialize_scripts(self, is_img2img):
|
||||||
from modules import scripts_auto_postprocessing
|
from modules import scripts_auto_postprocessing
|
||||||
@ -325,69 +334,73 @@ class ScriptRunner:
|
|||||||
self.scripts.append(script)
|
self.scripts.append(script)
|
||||||
self.selectable_scripts.append(script)
|
self.selectable_scripts.append(script)
|
||||||
|
|
||||||
def setup_ui(self):
|
def create_script_ui(self, script):
|
||||||
import modules.api.models as api_models
|
import modules.api.models as api_models
|
||||||
|
|
||||||
|
script.args_from = len(self.inputs)
|
||||||
|
script.args_to = len(self.inputs)
|
||||||
|
|
||||||
|
controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)
|
||||||
|
|
||||||
|
if controls is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower()
|
||||||
|
api_args = []
|
||||||
|
|
||||||
|
for control in controls:
|
||||||
|
control.custom_script_source = os.path.basename(script.filename)
|
||||||
|
|
||||||
|
arg_info = api_models.ScriptArg(label=control.label or "")
|
||||||
|
|
||||||
|
for field in ("value", "minimum", "maximum", "step", "choices"):
|
||||||
|
v = getattr(control, field, None)
|
||||||
|
if v is not None:
|
||||||
|
setattr(arg_info, field, v)
|
||||||
|
|
||||||
|
api_args.append(arg_info)
|
||||||
|
|
||||||
|
script.api_info = api_models.ScriptInfo(
|
||||||
|
name=script.name,
|
||||||
|
is_img2img=script.is_img2img,
|
||||||
|
is_alwayson=script.alwayson,
|
||||||
|
args=api_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
if script.infotext_fields is not None:
|
||||||
|
self.infotext_fields += script.infotext_fields
|
||||||
|
|
||||||
|
if script.paste_field_names is not None:
|
||||||
|
self.paste_field_names += script.paste_field_names
|
||||||
|
|
||||||
|
self.inputs += controls
|
||||||
|
script.args_to = len(self.inputs)
|
||||||
|
|
||||||
|
def setup_ui_for_section(self, section, scriptlist=None):
|
||||||
|
if scriptlist is None:
|
||||||
|
scriptlist = self.alwayson_scripts
|
||||||
|
|
||||||
|
for script in scriptlist:
|
||||||
|
if script.alwayson and script.section != section:
|
||||||
|
continue
|
||||||
|
|
||||||
|
with gr.Group(visible=script.alwayson) as group:
|
||||||
|
self.create_script_ui(script)
|
||||||
|
|
||||||
|
script.group = group
|
||||||
|
|
||||||
|
def prepare_ui(self):
|
||||||
|
self.inputs = [None]
|
||||||
|
|
||||||
|
def setup_ui(self):
|
||||||
self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
|
self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
|
||||||
|
|
||||||
inputs = [None]
|
self.setup_ui_for_section(None)
|
||||||
inputs_alwayson = [True]
|
|
||||||
|
|
||||||
def create_script_ui(script, inputs, inputs_alwayson):
|
|
||||||
script.args_from = len(inputs)
|
|
||||||
script.args_to = len(inputs)
|
|
||||||
|
|
||||||
controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)
|
|
||||||
|
|
||||||
if controls is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower()
|
|
||||||
api_args = []
|
|
||||||
|
|
||||||
for control in controls:
|
|
||||||
control.custom_script_source = os.path.basename(script.filename)
|
|
||||||
|
|
||||||
arg_info = api_models.ScriptArg(label=control.label or "")
|
|
||||||
|
|
||||||
for field in ("value", "minimum", "maximum", "step", "choices"):
|
|
||||||
v = getattr(control, field, None)
|
|
||||||
if v is not None:
|
|
||||||
setattr(arg_info, field, v)
|
|
||||||
|
|
||||||
api_args.append(arg_info)
|
|
||||||
|
|
||||||
script.api_info = api_models.ScriptInfo(
|
|
||||||
name=script.name,
|
|
||||||
is_img2img=script.is_img2img,
|
|
||||||
is_alwayson=script.alwayson,
|
|
||||||
args=api_args,
|
|
||||||
)
|
|
||||||
|
|
||||||
if script.infotext_fields is not None:
|
|
||||||
self.infotext_fields += script.infotext_fields
|
|
||||||
|
|
||||||
if script.paste_field_names is not None:
|
|
||||||
self.paste_field_names += script.paste_field_names
|
|
||||||
|
|
||||||
inputs += controls
|
|
||||||
inputs_alwayson += [script.alwayson for _ in controls]
|
|
||||||
script.args_to = len(inputs)
|
|
||||||
|
|
||||||
for script in self.alwayson_scripts:
|
|
||||||
with gr.Group() as group:
|
|
||||||
create_script_ui(script, inputs, inputs_alwayson)
|
|
||||||
|
|
||||||
script.group = group
|
|
||||||
|
|
||||||
dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index")
|
dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index")
|
||||||
inputs[0] = dropdown
|
self.inputs[0] = dropdown
|
||||||
|
|
||||||
for script in self.selectable_scripts:
|
self.setup_ui_for_section(None, self.selectable_scripts)
|
||||||
with gr.Group(visible=False) as group:
|
|
||||||
create_script_ui(script, inputs, inputs_alwayson)
|
|
||||||
|
|
||||||
script.group = group
|
|
||||||
|
|
||||||
def select_script(script_index):
|
def select_script(script_index):
|
||||||
selected_script = self.selectable_scripts[script_index - 1] if script_index>0 else None
|
selected_script = self.selectable_scripts[script_index - 1] if script_index>0 else None
|
||||||
@ -412,6 +425,7 @@ class ScriptRunner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.script_load_ctr = 0
|
self.script_load_ctr = 0
|
||||||
|
|
||||||
def onload_script_visibility(params):
|
def onload_script_visibility(params):
|
||||||
title = params.get('Script', None)
|
title = params.get('Script', None)
|
||||||
if title:
|
if title:
|
||||||
@ -422,10 +436,10 @@ class ScriptRunner:
|
|||||||
else:
|
else:
|
||||||
return gr.update(visible=False)
|
return gr.update(visible=False)
|
||||||
|
|
||||||
self.infotext_fields.append( (dropdown, lambda x: gr.update(value=x.get('Script', 'None'))) )
|
self.infotext_fields.append((dropdown, lambda x: gr.update(value=x.get('Script', 'None'))))
|
||||||
self.infotext_fields.extend( [(script.group, onload_script_visibility) for script in self.selectable_scripts] )
|
self.infotext_fields.extend([(script.group, onload_script_visibility) for script in self.selectable_scripts])
|
||||||
|
|
||||||
return inputs
|
return self.inputs
|
||||||
|
|
||||||
def run(self, p, *args):
|
def run(self, p, *args):
|
||||||
script_index = args[0]
|
script_index = args[0]
|
||||||
@ -445,14 +459,21 @@ class ScriptRunner:
|
|||||||
|
|
||||||
return processed
|
return processed
|
||||||
|
|
||||||
|
def before_process(self, p):
|
||||||
|
for script in self.alwayson_scripts:
|
||||||
|
try:
|
||||||
|
script_args = p.script_args[script.args_from:script.args_to]
|
||||||
|
script.before_process(p, *script_args)
|
||||||
|
except Exception:
|
||||||
|
errors.report(f"Error running before_process: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
def process(self, p):
|
def process(self, p):
|
||||||
for script in self.alwayson_scripts:
|
for script in self.alwayson_scripts:
|
||||||
try:
|
try:
|
||||||
script_args = p.script_args[script.args_from:script.args_to]
|
script_args = p.script_args[script.args_from:script.args_to]
|
||||||
script.process(p, *script_args)
|
script.process(p, *script_args)
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error running process: {script.filename}", file=sys.stderr)
|
errors.report(f"Error running process: {script.filename}", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
def before_process_batch(self, p, **kwargs):
|
def before_process_batch(self, p, **kwargs):
|
||||||
for script in self.alwayson_scripts:
|
for script in self.alwayson_scripts:
|
||||||
@ -460,8 +481,7 @@ class ScriptRunner:
|
|||||||
script_args = p.script_args[script.args_from:script.args_to]
|
script_args = p.script_args[script.args_from:script.args_to]
|
||||||
script.before_process_batch(p, *script_args, **kwargs)
|
script.before_process_batch(p, *script_args, **kwargs)
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error running before_process_batch: {script.filename}", file=sys.stderr)
|
errors.report(f"Error running before_process_batch: {script.filename}", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
def process_batch(self, p, **kwargs):
|
def process_batch(self, p, **kwargs):
|
||||||
for script in self.alwayson_scripts:
|
for script in self.alwayson_scripts:
|
||||||
@ -469,8 +489,7 @@ class ScriptRunner:
|
|||||||
script_args = p.script_args[script.args_from:script.args_to]
|
script_args = p.script_args[script.args_from:script.args_to]
|
||||||
script.process_batch(p, *script_args, **kwargs)
|
script.process_batch(p, *script_args, **kwargs)
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error running process_batch: {script.filename}", file=sys.stderr)
|
errors.report(f"Error running process_batch: {script.filename}", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
def postprocess(self, p, processed):
|
def postprocess(self, p, processed):
|
||||||
for script in self.alwayson_scripts:
|
for script in self.alwayson_scripts:
|
||||||
@ -478,8 +497,7 @@ class ScriptRunner:
|
|||||||
script_args = p.script_args[script.args_from:script.args_to]
|
script_args = p.script_args[script.args_from:script.args_to]
|
||||||
script.postprocess(p, processed, *script_args)
|
script.postprocess(p, processed, *script_args)
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error running postprocess: {script.filename}", file=sys.stderr)
|
errors.report(f"Error running postprocess: {script.filename}", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
def postprocess_batch(self, p, images, **kwargs):
|
def postprocess_batch(self, p, images, **kwargs):
|
||||||
for script in self.alwayson_scripts:
|
for script in self.alwayson_scripts:
|
||||||
@ -487,8 +505,7 @@ class ScriptRunner:
|
|||||||
script_args = p.script_args[script.args_from:script.args_to]
|
script_args = p.script_args[script.args_from:script.args_to]
|
||||||
script.postprocess_batch(p, *script_args, images=images, **kwargs)
|
script.postprocess_batch(p, *script_args, images=images, **kwargs)
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
|
errors.report(f"Error running postprocess_batch: {script.filename}", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
def postprocess_image(self, p, pp: PostprocessImageArgs):
|
def postprocess_image(self, p, pp: PostprocessImageArgs):
|
||||||
for script in self.alwayson_scripts:
|
for script in self.alwayson_scripts:
|
||||||
@ -496,24 +513,21 @@ class ScriptRunner:
|
|||||||
script_args = p.script_args[script.args_from:script.args_to]
|
script_args = p.script_args[script.args_from:script.args_to]
|
||||||
script.postprocess_image(p, pp, *script_args)
|
script.postprocess_image(p, pp, *script_args)
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
|
errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
def before_component(self, component, **kwargs):
|
def before_component(self, component, **kwargs):
|
||||||
for script in self.scripts:
|
for script in self.scripts:
|
||||||
try:
|
try:
|
||||||
script.before_component(component, **kwargs)
|
script.before_component(component, **kwargs)
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error running before_component: {script.filename}", file=sys.stderr)
|
errors.report(f"Error running before_component: {script.filename}", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
def after_component(self, component, **kwargs):
|
def after_component(self, component, **kwargs):
|
||||||
for script in self.scripts:
|
for script in self.scripts:
|
||||||
try:
|
try:
|
||||||
script.after_component(component, **kwargs)
|
script.after_component(component, **kwargs)
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error running after_component: {script.filename}", file=sys.stderr)
|
errors.report(f"Error running after_component: {script.filename}", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
def reload_sources(self, cache):
|
def reload_sources(self, cache):
|
||||||
for si, script in list(enumerate(self.scripts)):
|
for si, script in list(enumerate(self.scripts)):
|
||||||
|
@ -3,7 +3,7 @@ from torch.nn.functional import silu
|
|||||||
from types import MethodType
|
from types import MethodType
|
||||||
|
|
||||||
import modules.textual_inversion.textual_inversion
|
import modules.textual_inversion.textual_inversion
|
||||||
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors
|
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
|
||||||
from modules.hypernetworks import hypernetwork
|
from modules.hypernetworks import hypernetwork
|
||||||
from modules.shared import cmd_opts
|
from modules.shared import cmd_opts
|
||||||
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
|
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
|
||||||
@ -43,11 +43,16 @@ def list_optimizers():
|
|||||||
optimizers.extend(new_optimizers)
|
optimizers.extend(new_optimizers)
|
||||||
|
|
||||||
|
|
||||||
def apply_optimizations():
|
def apply_optimizations(option=None):
|
||||||
global current_optimizer
|
global current_optimizer
|
||||||
|
|
||||||
undo_optimizations()
|
undo_optimizations()
|
||||||
|
|
||||||
|
if len(optimizers) == 0:
|
||||||
|
# a script can access the model very early, and optimizations would not be filled by then
|
||||||
|
current_optimizer = None
|
||||||
|
return ''
|
||||||
|
|
||||||
ldm.modules.diffusionmodules.model.nonlinearity = silu
|
ldm.modules.diffusionmodules.model.nonlinearity = silu
|
||||||
ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
|
ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
|
||||||
|
|
||||||
@ -55,7 +60,7 @@ def apply_optimizations():
|
|||||||
current_optimizer.undo()
|
current_optimizer.undo()
|
||||||
current_optimizer = None
|
current_optimizer = None
|
||||||
|
|
||||||
selection = shared.opts.cross_attention_optimization
|
selection = option or shared.opts.cross_attention_optimization
|
||||||
if selection == "Automatic" and len(optimizers) > 0:
|
if selection == "Automatic" and len(optimizers) > 0:
|
||||||
matching_optimizer = next(iter([x for x in optimizers if x.cmd_opt and getattr(shared.cmd_opts, x.cmd_opt, False)]), optimizers[0])
|
matching_optimizer = next(iter([x for x in optimizers if x.cmd_opt and getattr(shared.cmd_opts, x.cmd_opt, False)]), optimizers[0])
|
||||||
else:
|
else:
|
||||||
@ -63,15 +68,19 @@ def apply_optimizations():
|
|||||||
|
|
||||||
if selection == "None":
|
if selection == "None":
|
||||||
matching_optimizer = None
|
matching_optimizer = None
|
||||||
|
elif selection == "Automatic" and shared.cmd_opts.disable_opt_split_attention:
|
||||||
|
matching_optimizer = None
|
||||||
elif matching_optimizer is None:
|
elif matching_optimizer is None:
|
||||||
matching_optimizer = optimizers[0]
|
matching_optimizer = optimizers[0]
|
||||||
|
|
||||||
if matching_optimizer is not None:
|
if matching_optimizer is not None:
|
||||||
print(f"Applying optimization: {matching_optimizer.name}")
|
print(f"Applying attention optimization: {matching_optimizer.name}... ", end='')
|
||||||
matching_optimizer.apply()
|
matching_optimizer.apply()
|
||||||
|
print("done.")
|
||||||
current_optimizer = matching_optimizer
|
current_optimizer = matching_optimizer
|
||||||
return current_optimizer.name
|
return current_optimizer.name
|
||||||
else:
|
else:
|
||||||
|
print("Disabling attention optimization")
|
||||||
return ''
|
return ''
|
||||||
|
|
||||||
|
|
||||||
@ -149,6 +158,13 @@ class StableDiffusionModelHijack:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
|
self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
|
||||||
|
|
||||||
|
def apply_optimizations(self, option=None):
|
||||||
|
try:
|
||||||
|
self.optimization_method = apply_optimizations(option)
|
||||||
|
except Exception as e:
|
||||||
|
errors.display(e, "applying cross attention optimization")
|
||||||
|
undo_optimizations()
|
||||||
|
|
||||||
def hijack(self, m):
|
def hijack(self, m):
|
||||||
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
|
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
|
||||||
model_embeddings = m.cond_stage_model.roberta.embeddings
|
model_embeddings = m.cond_stage_model.roberta.embeddings
|
||||||
@ -168,11 +184,7 @@ class StableDiffusionModelHijack:
|
|||||||
if m.cond_stage_key == "edit":
|
if m.cond_stage_key == "edit":
|
||||||
sd_hijack_unet.hijack_ddpm_edit()
|
sd_hijack_unet.hijack_ddpm_edit()
|
||||||
|
|
||||||
try:
|
self.apply_optimizations()
|
||||||
self.optimization_method = apply_optimizations()
|
|
||||||
except Exception as e:
|
|
||||||
errors.display(e, "applying cross attention optimization")
|
|
||||||
undo_optimizations()
|
|
||||||
|
|
||||||
self.clip = m.cond_stage_model
|
self.clip = m.cond_stage_model
|
||||||
|
|
||||||
@ -185,6 +197,11 @@ class StableDiffusionModelHijack:
|
|||||||
|
|
||||||
self.layers = flatten(m)
|
self.layers = flatten(m)
|
||||||
|
|
||||||
|
if not hasattr(ldm.modules.diffusionmodules.openaimodel, 'copy_of_UNetModel_forward_for_webui'):
|
||||||
|
ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui = ldm.modules.diffusionmodules.openaimodel.UNetModel.forward
|
||||||
|
|
||||||
|
ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward
|
||||||
|
|
||||||
def undo_hijack(self, m):
|
def undo_hijack(self, m):
|
||||||
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
|
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
|
||||||
m.cond_stage_model = m.cond_stage_model.wrapped
|
m.cond_stage_model = m.cond_stage_model.wrapped
|
||||||
@ -206,6 +223,8 @@ class StableDiffusionModelHijack:
|
|||||||
self.layers = None
|
self.layers = None
|
||||||
self.clip = None
|
self.clip = None
|
||||||
|
|
||||||
|
ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui
|
||||||
|
|
||||||
def apply_circular(self, enable):
|
def apply_circular(self, enable):
|
||||||
if self.circular_enabled == enable:
|
if self.circular_enabled == enable:
|
||||||
return
|
return
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
import math
|
import math
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
import psutil
|
import psutil
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -58,7 +57,7 @@ class SdOptimizationSdpNoMem(SdOptimization):
|
|||||||
name = "sdp-no-mem"
|
name = "sdp-no-mem"
|
||||||
label = "scaled dot product without memory efficient attention"
|
label = "scaled dot product without memory efficient attention"
|
||||||
cmd_opt = "opt_sdp_no_mem_attention"
|
cmd_opt = "opt_sdp_no_mem_attention"
|
||||||
priority = 90
|
priority = 80
|
||||||
|
|
||||||
def is_available(self):
|
def is_available(self):
|
||||||
return hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention)
|
return hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention)
|
||||||
@ -72,7 +71,7 @@ class SdOptimizationSdp(SdOptimizationSdpNoMem):
|
|||||||
name = "sdp"
|
name = "sdp"
|
||||||
label = "scaled dot product"
|
label = "scaled dot product"
|
||||||
cmd_opt = "opt_sdp_attention"
|
cmd_opt = "opt_sdp_attention"
|
||||||
priority = 80
|
priority = 70
|
||||||
|
|
||||||
def apply(self):
|
def apply(self):
|
||||||
ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
|
ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
|
||||||
@ -115,7 +114,7 @@ class SdOptimizationInvokeAI(SdOptimization):
|
|||||||
class SdOptimizationDoggettx(SdOptimization):
|
class SdOptimizationDoggettx(SdOptimization):
|
||||||
name = "Doggettx"
|
name = "Doggettx"
|
||||||
cmd_opt = "opt_split_attention"
|
cmd_opt = "opt_split_attention"
|
||||||
priority = 20
|
priority = 90
|
||||||
|
|
||||||
def apply(self):
|
def apply(self):
|
||||||
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
|
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
|
||||||
@ -139,8 +138,7 @@ if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
|
|||||||
import xformers.ops
|
import xformers.ops
|
||||||
shared.xformers_available = True
|
shared.xformers_available = True
|
||||||
except Exception:
|
except Exception:
|
||||||
print("Cannot import xformers", file=sys.stderr)
|
errors.report("Cannot import xformers", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
|
|
||||||
def get_available_vram():
|
def get_available_vram():
|
||||||
|
@ -14,7 +14,7 @@ import ldm.modules.midas as midas
|
|||||||
|
|
||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
|
|
||||||
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
|
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet
|
||||||
from modules.sd_hijack_inpainting import do_inpainting_hijack
|
from modules.sd_hijack_inpainting import do_inpainting_hijack
|
||||||
from modules.timer import Timer
|
from modules.timer import Timer
|
||||||
import tomesd
|
import tomesd
|
||||||
@ -164,6 +164,7 @@ def model_hash(filename):
|
|||||||
|
|
||||||
|
|
||||||
def select_checkpoint():
|
def select_checkpoint():
|
||||||
|
"""Raises `FileNotFoundError` if no checkpoints are found."""
|
||||||
model_checkpoint = shared.opts.sd_model_checkpoint
|
model_checkpoint = shared.opts.sd_model_checkpoint
|
||||||
|
|
||||||
checkpoint_info = checkpoint_alisases.get(model_checkpoint, None)
|
checkpoint_info = checkpoint_alisases.get(model_checkpoint, None)
|
||||||
@ -171,14 +172,14 @@ def select_checkpoint():
|
|||||||
return checkpoint_info
|
return checkpoint_info
|
||||||
|
|
||||||
if len(checkpoints_list) == 0:
|
if len(checkpoints_list) == 0:
|
||||||
print("No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr)
|
error_message = "No checkpoints found. When searching for checkpoints, looked at:"
|
||||||
if shared.cmd_opts.ckpt is not None:
|
if shared.cmd_opts.ckpt is not None:
|
||||||
print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr)
|
error_message += f"\n - file {os.path.abspath(shared.cmd_opts.ckpt)}"
|
||||||
print(f" - directory {model_path}", file=sys.stderr)
|
error_message += f"\n - directory {model_path}"
|
||||||
if shared.cmd_opts.ckpt_dir is not None:
|
if shared.cmd_opts.ckpt_dir is not None:
|
||||||
print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
|
error_message += f"\n - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}"
|
||||||
print("Can't run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations. The program will exit.", file=sys.stderr)
|
error_message += "Can't run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations."
|
||||||
exit(1)
|
raise FileNotFoundError(error_message)
|
||||||
|
|
||||||
checkpoint_info = next(iter(checkpoints_list.values()))
|
checkpoint_info = next(iter(checkpoints_list.values()))
|
||||||
if model_checkpoint is not None:
|
if model_checkpoint is not None:
|
||||||
@ -313,8 +314,6 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
|||||||
|
|
||||||
timer.record("apply half()")
|
timer.record("apply half()")
|
||||||
|
|
||||||
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
|
|
||||||
devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
|
|
||||||
devices.dtype_unet = model.model.diffusion_model.dtype
|
devices.dtype_unet = model.model.diffusion_model.dtype
|
||||||
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
|
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
|
||||||
|
|
||||||
@ -423,7 +422,7 @@ class SdModelData:
|
|||||||
try:
|
try:
|
||||||
load_model()
|
load_model()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
errors.display(e, "loading stable diffusion model")
|
errors.display(e, "loading stable diffusion model", full_traceback=True)
|
||||||
print("", file=sys.stderr)
|
print("", file=sys.stderr)
|
||||||
print("Stable diffusion model failed to load", file=sys.stderr)
|
print("Stable diffusion model failed to load", file=sys.stderr)
|
||||||
self.sd_model = None
|
self.sd_model = None
|
||||||
@ -508,6 +507,11 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
|||||||
|
|
||||||
timer.record("scripts callbacks")
|
timer.record("scripts callbacks")
|
||||||
|
|
||||||
|
with devices.autocast(), torch.no_grad():
|
||||||
|
sd_model.cond_stage_model_empty_prompt = sd_model.cond_stage_model([""])
|
||||||
|
|
||||||
|
timer.record("calculate empty prompt")
|
||||||
|
|
||||||
print(f"Model loaded in {timer.summary()}.")
|
print(f"Model loaded in {timer.summary()}.")
|
||||||
|
|
||||||
return sd_model
|
return sd_model
|
||||||
@ -527,6 +531,8 @@ def reload_model_weights(sd_model=None, info=None):
|
|||||||
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
|
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
sd_unet.apply_unet("None")
|
||||||
|
|
||||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||||
lowvram.send_everything_to_cpu()
|
lowvram.send_everything_to_cpu()
|
||||||
else:
|
else:
|
||||||
|
@ -19,7 +19,8 @@ samplers_k_diffusion = [
|
|||||||
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True, "uses_ensd": True}),
|
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True, "uses_ensd": True}),
|
||||||
('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {"uses_ensd": True, "second_order": True}),
|
('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {"uses_ensd": True, "second_order": True}),
|
||||||
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
|
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
|
||||||
('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True}),
|
('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True, "brownian_noise": True}),
|
||||||
|
('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {"brownian_noise": True}),
|
||||||
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}),
|
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}),
|
||||||
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}),
|
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}),
|
||||||
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
|
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
|
||||||
@ -27,7 +28,8 @@ samplers_k_diffusion = [
|
|||||||
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
|
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
|
||||||
('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}),
|
('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}),
|
||||||
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
|
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
|
||||||
('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True}),
|
('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
|
||||||
|
('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}),
|
||||||
]
|
]
|
||||||
|
|
||||||
samplers_data_k_diffusion = [
|
samplers_data_k_diffusion = [
|
||||||
@ -42,6 +44,14 @@ sampler_extra_params = {
|
|||||||
'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion}
|
||||||
|
k_diffusion_scheduler = {
|
||||||
|
'Automatic': None,
|
||||||
|
'karras': k_diffusion.sampling.get_sigmas_karras,
|
||||||
|
'exponential': k_diffusion.sampling.get_sigmas_exponential,
|
||||||
|
'polyexponential': k_diffusion.sampling.get_sigmas_polyexponential
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class CFGDenoiser(torch.nn.Module):
|
class CFGDenoiser(torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
@ -123,6 +133,16 @@ class CFGDenoiser(torch.nn.Module):
|
|||||||
x_in = x_in[:-batch_size]
|
x_in = x_in[:-batch_size]
|
||||||
sigma_in = sigma_in[:-batch_size]
|
sigma_in = sigma_in[:-batch_size]
|
||||||
|
|
||||||
|
# TODO add infotext entry
|
||||||
|
if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
|
||||||
|
empty = shared.sd_model.cond_stage_model_empty_prompt
|
||||||
|
num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
|
||||||
|
|
||||||
|
if num_repeats < 0:
|
||||||
|
tensor = torch.cat([tensor, empty.repeat((tensor.shape[0], -num_repeats, 1))], axis=1)
|
||||||
|
elif num_repeats > 0:
|
||||||
|
uncond = torch.cat([uncond, empty.repeat((uncond.shape[0], num_repeats, 1))], axis=1)
|
||||||
|
|
||||||
if tensor.shape[1] == uncond.shape[1] or skip_uncond:
|
if tensor.shape[1] == uncond.shape[1] or skip_uncond:
|
||||||
if is_edit_model:
|
if is_edit_model:
|
||||||
cond_in = torch.cat([tensor, uncond, uncond])
|
cond_in = torch.cat([tensor, uncond, uncond])
|
||||||
@ -228,7 +248,7 @@ class KDiffusionSampler:
|
|||||||
self.sampler_noises = None
|
self.sampler_noises = None
|
||||||
self.stop_at = None
|
self.stop_at = None
|
||||||
self.eta = None
|
self.eta = None
|
||||||
self.config = None
|
self.config = None # set by the function calling the constructor
|
||||||
self.last_latent = None
|
self.last_latent = None
|
||||||
self.s_min_uncond = None
|
self.s_min_uncond = None
|
||||||
|
|
||||||
@ -253,6 +273,13 @@ class KDiffusionSampler:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
return func()
|
return func()
|
||||||
|
except RecursionError:
|
||||||
|
print(
|
||||||
|
'Encountered RecursionError during sampling, returning last latent. '
|
||||||
|
'rho >5 with a polyexponential scheduler may cause this error. '
|
||||||
|
'You should try to use a smaller rho value instead.'
|
||||||
|
)
|
||||||
|
return self.last_latent
|
||||||
except sd_samplers_common.InterruptedException:
|
except sd_samplers_common.InterruptedException:
|
||||||
return self.last_latent
|
return self.last_latent
|
||||||
|
|
||||||
@ -292,6 +319,31 @@ class KDiffusionSampler:
|
|||||||
|
|
||||||
if p.sampler_noise_scheduler_override:
|
if p.sampler_noise_scheduler_override:
|
||||||
sigmas = p.sampler_noise_scheduler_override(steps)
|
sigmas = p.sampler_noise_scheduler_override(steps)
|
||||||
|
elif opts.k_sched_type != "Automatic":
|
||||||
|
m_sigma_min, m_sigma_max = (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
|
||||||
|
sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (m_sigma_min, m_sigma_max)
|
||||||
|
sigmas_kwargs = {
|
||||||
|
'sigma_min': sigma_min,
|
||||||
|
'sigma_max': sigma_max,
|
||||||
|
}
|
||||||
|
|
||||||
|
sigmas_func = k_diffusion_scheduler[opts.k_sched_type]
|
||||||
|
p.extra_generation_params["Schedule type"] = opts.k_sched_type
|
||||||
|
|
||||||
|
if opts.sigma_min != m_sigma_min and opts.sigma_min != 0:
|
||||||
|
sigmas_kwargs['sigma_min'] = opts.sigma_min
|
||||||
|
p.extra_generation_params["Schedule min sigma"] = opts.sigma_min
|
||||||
|
if opts.sigma_max != m_sigma_max and opts.sigma_max != 0:
|
||||||
|
sigmas_kwargs['sigma_max'] = opts.sigma_max
|
||||||
|
p.extra_generation_params["Schedule max sigma"] = opts.sigma_max
|
||||||
|
|
||||||
|
default_rho = 1. if opts.k_sched_type == "polyexponential" else 7.
|
||||||
|
|
||||||
|
if opts.k_sched_type != 'exponential' and opts.rho != 0 and opts.rho != default_rho:
|
||||||
|
sigmas_kwargs['rho'] = opts.rho
|
||||||
|
p.extra_generation_params["Schedule rho"] = opts.rho
|
||||||
|
|
||||||
|
sigmas = sigmas_func(n=steps, **sigmas_kwargs, device=shared.device)
|
||||||
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
|
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
|
||||||
sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
|
sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
|
||||||
|
|
||||||
@ -337,13 +389,13 @@ class KDiffusionSampler:
|
|||||||
if 'sigmas' in parameters:
|
if 'sigmas' in parameters:
|
||||||
extra_params_kwargs['sigmas'] = sigma_sched
|
extra_params_kwargs['sigmas'] = sigma_sched
|
||||||
|
|
||||||
if self.funcname == 'sample_dpmpp_sde':
|
if self.config.options.get('brownian_noise', False):
|
||||||
noise_sampler = self.create_noise_sampler(x, sigmas, p)
|
noise_sampler = self.create_noise_sampler(x, sigmas, p)
|
||||||
extra_params_kwargs['noise_sampler'] = noise_sampler
|
extra_params_kwargs['noise_sampler'] = noise_sampler
|
||||||
|
|
||||||
self.model_wrap_cfg.init_latent = x
|
self.model_wrap_cfg.init_latent = x
|
||||||
self.last_latent = x
|
self.last_latent = x
|
||||||
extra_args={
|
extra_args = {
|
||||||
'cond': conditioning,
|
'cond': conditioning,
|
||||||
'image_cond': image_conditioning,
|
'image_cond': image_conditioning,
|
||||||
'uncond': unconditional_conditioning,
|
'uncond': unconditional_conditioning,
|
||||||
@ -373,7 +425,7 @@ class KDiffusionSampler:
|
|||||||
else:
|
else:
|
||||||
extra_params_kwargs['sigmas'] = sigmas
|
extra_params_kwargs['sigmas'] = sigmas
|
||||||
|
|
||||||
if self.funcname == 'sample_dpmpp_sde':
|
if self.config.options.get('brownian_noise', False):
|
||||||
noise_sampler = self.create_noise_sampler(x, sigmas, p)
|
noise_sampler = self.create_noise_sampler(x, sigmas, p)
|
||||||
extra_params_kwargs['noise_sampler'] = noise_sampler
|
extra_params_kwargs['noise_sampler'] = noise_sampler
|
||||||
|
|
||||||
|
92
modules/sd_unet.py
Normal file
92
modules/sd_unet.py
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
import torch.nn
|
||||||
|
import ldm.modules.diffusionmodules.openaimodel
|
||||||
|
|
||||||
|
from modules import script_callbacks, shared, devices
|
||||||
|
|
||||||
|
unet_options = []
|
||||||
|
current_unet_option = None
|
||||||
|
current_unet = None
|
||||||
|
|
||||||
|
|
||||||
|
def list_unets():
|
||||||
|
new_unets = script_callbacks.list_unets_callback()
|
||||||
|
|
||||||
|
unet_options.clear()
|
||||||
|
unet_options.extend(new_unets)
|
||||||
|
|
||||||
|
|
||||||
|
def get_unet_option(option=None):
|
||||||
|
option = option or shared.opts.sd_unet
|
||||||
|
|
||||||
|
if option == "None":
|
||||||
|
return None
|
||||||
|
|
||||||
|
if option == "Automatic":
|
||||||
|
name = shared.sd_model.sd_checkpoint_info.model_name
|
||||||
|
|
||||||
|
options = [x for x in unet_options if x.model_name == name]
|
||||||
|
|
||||||
|
option = options[0].label if options else "None"
|
||||||
|
|
||||||
|
return next(iter([x for x in unet_options if x.label == option]), None)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_unet(option=None):
|
||||||
|
global current_unet_option
|
||||||
|
global current_unet
|
||||||
|
|
||||||
|
new_option = get_unet_option(option)
|
||||||
|
if new_option == current_unet_option:
|
||||||
|
return
|
||||||
|
|
||||||
|
if current_unet is not None:
|
||||||
|
print(f"Dectivating unet: {current_unet.option.label}")
|
||||||
|
current_unet.deactivate()
|
||||||
|
|
||||||
|
current_unet_option = new_option
|
||||||
|
if current_unet_option is None:
|
||||||
|
current_unet = None
|
||||||
|
|
||||||
|
if not (shared.cmd_opts.lowvram or shared.cmd_opts.medvram):
|
||||||
|
shared.sd_model.model.diffusion_model.to(devices.device)
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
shared.sd_model.model.diffusion_model.to(devices.cpu)
|
||||||
|
devices.torch_gc()
|
||||||
|
|
||||||
|
current_unet = current_unet_option.create_unet()
|
||||||
|
current_unet.option = current_unet_option
|
||||||
|
print(f"Activating unet: {current_unet.option.label}")
|
||||||
|
current_unet.activate()
|
||||||
|
|
||||||
|
|
||||||
|
class SdUnetOption:
|
||||||
|
model_name = None
|
||||||
|
"""name of related checkpoint - this option will be selected automatically for unet if the name of checkpoint matches this"""
|
||||||
|
|
||||||
|
label = None
|
||||||
|
"""name of the unet in UI"""
|
||||||
|
|
||||||
|
def create_unet(self):
|
||||||
|
"""returns SdUnet object to be used as a Unet instead of built-in unet when making pictures"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
class SdUnet(torch.nn.Module):
|
||||||
|
def forward(self, x, timesteps, context, *args, **kwargs):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def activate(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def deactivate(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
|
||||||
|
if current_unet is not None:
|
||||||
|
return current_unet.forward(x, timesteps, context, *args, **kwargs)
|
||||||
|
|
||||||
|
return ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui(self, x, timesteps, context, *args, **kwargs)
|
||||||
|
|
@ -6,6 +6,7 @@ import threading
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
import torch
|
||||||
import tqdm
|
import tqdm
|
||||||
|
|
||||||
import modules.interrogate
|
import modules.interrogate
|
||||||
@ -43,19 +44,6 @@ restricted_opts = {
|
|||||||
"outdir_init_images"
|
"outdir_init_images"
|
||||||
}
|
}
|
||||||
|
|
||||||
ui_reorder_categories = [
|
|
||||||
"inpaint",
|
|
||||||
"sampler",
|
|
||||||
"checkboxes",
|
|
||||||
"hires_fix",
|
|
||||||
"dimensions",
|
|
||||||
"cfg",
|
|
||||||
"seed",
|
|
||||||
"batch",
|
|
||||||
"override_settings",
|
|
||||||
"scripts",
|
|
||||||
]
|
|
||||||
|
|
||||||
# https://huggingface.co/datasets/freddyaboulton/gradio-theme-subdomains/resolve/main/subdomains.json
|
# https://huggingface.co/datasets/freddyaboulton/gradio-theme-subdomains/resolve/main/subdomains.json
|
||||||
gradio_hf_hub_themes = [
|
gradio_hf_hub_themes = [
|
||||||
"gradio/glass",
|
"gradio/glass",
|
||||||
@ -76,6 +64,9 @@ cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_op
|
|||||||
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
|
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
|
||||||
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer'])
|
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer'])
|
||||||
|
|
||||||
|
devices.dtype = torch.float32 if cmd_opts.no_half else torch.float16
|
||||||
|
devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16
|
||||||
|
|
||||||
device = devices.device
|
device = devices.device
|
||||||
weight_load_location = None if cmd_opts.lowram else "cpu"
|
weight_load_location = None if cmd_opts.lowram else "cpu"
|
||||||
|
|
||||||
@ -314,6 +305,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
|
|||||||
"grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
|
"grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
|
||||||
"grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"),
|
"grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"),
|
||||||
"grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"),
|
"grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"),
|
||||||
|
"grid_zip_filename_pattern": OptionInfo("", "Archive filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
|
||||||
"n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
|
"n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
|
||||||
|
|
||||||
"enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
|
"enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
|
||||||
@ -403,6 +395,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
|||||||
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||||
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
|
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
|
||||||
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
|
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
|
||||||
|
"sd_unet": OptionInfo("Automatic", "SD Unet", gr.Dropdown, lambda: {"choices": shared_items.sd_unet_items()}, refresh=shared_items.refresh_unet_list).info("choose Unet model: Automatic = use one with same filename as checkpoint; None = use Unet from checkpoint"),
|
||||||
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||||
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
|
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
|
||||||
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
|
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
|
||||||
@ -414,15 +407,16 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
|||||||
"comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"),
|
"comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"),
|
||||||
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP nrtwork; 1 ignores none, 2 ignores one layer"),
|
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP nrtwork; 1 ignores none, 2 ignores one layer"),
|
||||||
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
|
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
|
||||||
"randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different vidocard vendors"),
|
"randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('optimizations', "Optimizations"), {
|
options_templates.update(options_section(('optimizations', "Optimizations"), {
|
||||||
"cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
|
"cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
|
||||||
"s_min_uncond": OptionInfo(0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
|
"s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
|
||||||
"token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
|
"token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
|
||||||
"token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
|
"token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
|
||||||
"token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
|
"token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
|
||||||
|
"pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length").info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('compatibility', "Compatibility"), {
|
options_templates.update(options_section(('compatibility', "Compatibility"), {
|
||||||
@ -484,9 +478,10 @@ options_templates.update(options_section(('ui', "User interface"), {
|
|||||||
"quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_restart(),
|
"quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_restart(),
|
||||||
"ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
|
"ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
|
||||||
"hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
|
"hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
|
||||||
"ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order").needs_restart(),
|
"ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_restart(),
|
||||||
"hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires sampler selection").needs_restart(),
|
"hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires sampler selection").needs_restart(),
|
||||||
"hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_restart(),
|
"hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_restart(),
|
||||||
|
"disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_restart(),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('infotext', "Infotext"), {
|
options_templates.update(options_section(('infotext', "Infotext"), {
|
||||||
@ -515,6 +510,10 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
|
|||||||
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||||
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||||
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||||
|
'k_sched_type': OptionInfo("Automatic", "scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}).info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"),
|
||||||
|
'sigma_min': OptionInfo(0.0, "sigma min", gr.Number).info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"),
|
||||||
|
'sigma_max': OptionInfo(0.0, "sigma max", gr.Number).info("0 = default (~14.6); maximum noise strength for k-diffusion noise schedule"),
|
||||||
|
'rho': OptionInfo(0.0, "rho", gr.Number).info("0 = default (7 for karras, 1 for polyexponential); higher values result in a more steep noise schedule (decreases faster)"),
|
||||||
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}).info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"),
|
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}).info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"),
|
||||||
'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma").link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044"),
|
'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma").link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044"),
|
||||||
'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}),
|
'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}),
|
||||||
@ -630,6 +629,10 @@ class Options:
|
|||||||
if self.data.get('quicksettings') is not None and self.data.get('quicksettings_list') is None:
|
if self.data.get('quicksettings') is not None and self.data.get('quicksettings_list') is None:
|
||||||
self.data['quicksettings_list'] = [i.strip() for i in self.data.get('quicksettings').split(',')]
|
self.data['quicksettings_list'] = [i.strip() for i in self.data.get('quicksettings').split(',')]
|
||||||
|
|
||||||
|
# 1.4.0 ui_reorder
|
||||||
|
if isinstance(self.data.get('ui_reorder'), str) and self.data.get('ui_reorder') and "ui_reorder_list" not in self.data:
|
||||||
|
self.data['ui_reorder_list'] = [i.strip() for i in self.data.get('ui_reorder').split(',')]
|
||||||
|
|
||||||
bad_settings = 0
|
bad_settings = 0
|
||||||
for k, v in self.data.items():
|
for k, v in self.data.items():
|
||||||
info = self.data_labels.get(k, None)
|
info = self.data_labels.get(k, None)
|
||||||
|
@ -29,3 +29,41 @@ def cross_attention_optimizations():
|
|||||||
return ["Automatic"] + [x.title() for x in modules.sd_hijack.optimizers] + ["None"]
|
return ["Automatic"] + [x.title() for x in modules.sd_hijack.optimizers] + ["None"]
|
||||||
|
|
||||||
|
|
||||||
|
def sd_unet_items():
|
||||||
|
import modules.sd_unet
|
||||||
|
|
||||||
|
return ["Automatic"] + [x.label for x in modules.sd_unet.unet_options] + ["None"]
|
||||||
|
|
||||||
|
|
||||||
|
def refresh_unet_list():
|
||||||
|
import modules.sd_unet
|
||||||
|
|
||||||
|
modules.sd_unet.list_unets()
|
||||||
|
|
||||||
|
|
||||||
|
ui_reorder_categories_builtin_items = [
|
||||||
|
"inpaint",
|
||||||
|
"sampler",
|
||||||
|
"checkboxes",
|
||||||
|
"hires_fix",
|
||||||
|
"dimensions",
|
||||||
|
"cfg",
|
||||||
|
"seed",
|
||||||
|
"batch",
|
||||||
|
"override_settings",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def ui_reorder_categories():
|
||||||
|
from modules import scripts
|
||||||
|
|
||||||
|
yield from ui_reorder_categories_builtin_items
|
||||||
|
|
||||||
|
sections = {}
|
||||||
|
for script in scripts.scripts_txt2img.scripts + scripts.scripts_img2img.scripts:
|
||||||
|
if isinstance(script.section, str):
|
||||||
|
sections[script.section] = 1
|
||||||
|
|
||||||
|
yield from sections
|
||||||
|
|
||||||
|
yield "scripts"
|
||||||
|
@ -1,8 +1,10 @@
|
|||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
|
import warnings
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import zlib
|
import zlib
|
||||||
from PIL import Image, ImageDraw, ImageFont
|
from PIL import Image, ImageDraw
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
@ -129,14 +131,17 @@ def extract_image_data_embed(image):
|
|||||||
|
|
||||||
|
|
||||||
def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, textfont=None):
|
def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, textfont=None):
|
||||||
|
from modules.images import get_font
|
||||||
|
if textfont:
|
||||||
|
warnings.warn(
|
||||||
|
'passing in a textfont to caption_image_overlay is deprecated and does nothing',
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
from math import cos
|
from math import cos
|
||||||
|
|
||||||
image = srcimage.copy()
|
image = srcimage.copy()
|
||||||
fontsize = 32
|
fontsize = 32
|
||||||
if textfont is None:
|
|
||||||
from modules.images import get_font
|
|
||||||
textfont = get_font(fontsize)
|
|
||||||
|
|
||||||
factor = 1.5
|
factor = 1.5
|
||||||
gradient = Image.new('RGBA', (1, image.size[1]), color=(0, 0, 0, 0))
|
gradient = Image.new('RGBA', (1, image.size[1]), color=(0, 0, 0, 0))
|
||||||
for y in range(image.size[1]):
|
for y in range(image.size[1]):
|
||||||
@ -147,12 +152,12 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t
|
|||||||
|
|
||||||
draw = ImageDraw.Draw(image)
|
draw = ImageDraw.Draw(image)
|
||||||
|
|
||||||
font = ImageFont.truetype(textfont, fontsize)
|
font = get_font(fontsize)
|
||||||
padding = 10
|
padding = 10
|
||||||
|
|
||||||
_, _, w, h = draw.textbbox((0, 0), title, font=font)
|
_, _, w, h = draw.textbbox((0, 0), title, font=font)
|
||||||
fontsize = min(int(fontsize * (((image.size[0]*0.75)-(padding*4))/w)), 72)
|
fontsize = min(int(fontsize * (((image.size[0]*0.75)-(padding*4))/w)), 72)
|
||||||
font = ImageFont.truetype(textfont, fontsize)
|
font = get_font(fontsize)
|
||||||
_, _, w, h = draw.textbbox((0, 0), title, font=font)
|
_, _, w, h = draw.textbbox((0, 0), title, font=font)
|
||||||
draw.text((padding, padding), title, anchor='lt', font=font, fill=(255, 255, 255, 230))
|
draw.text((padding, padding), title, anchor='lt', font=font, fill=(255, 255, 255, 230))
|
||||||
|
|
||||||
@ -163,7 +168,7 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t
|
|||||||
_, _, w, h = draw.textbbox((0, 0), footerRight, font=font)
|
_, _, w, h = draw.textbbox((0, 0), footerRight, font=font)
|
||||||
fontsize_right = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
|
fontsize_right = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
|
||||||
|
|
||||||
font = ImageFont.truetype(textfont, min(fontsize_left, fontsize_mid, fontsize_right))
|
font = get_font(min(fontsize_left, fontsize_mid, fontsize_right))
|
||||||
|
|
||||||
draw.text((padding, image.size[1]-padding), footerLeft, anchor='ls', font=font, fill=(255, 255, 255, 230))
|
draw.text((padding, image.size[1]-padding), footerLeft, anchor='ls', font=font, fill=(255, 255, 255, 230))
|
||||||
draw.text((image.size[0]/2, image.size[1]-padding), footerMid, anchor='ms', font=font, fill=(255, 255, 255, 230))
|
draw.text((image.size[0]/2, image.size[1]-padding), footerMid, anchor='ms', font=font, fill=(255, 255, 255, 230))
|
||||||
|
@ -1,6 +1,4 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -14,7 +12,7 @@ import numpy as np
|
|||||||
from PIL import Image, PngImagePlugin
|
from PIL import Image, PngImagePlugin
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint
|
from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors
|
||||||
import modules.textual_inversion.dataset
|
import modules.textual_inversion.dataset
|
||||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||||
|
|
||||||
@ -120,16 +118,29 @@ class EmbeddingDatabase:
|
|||||||
self.embedding_dirs.clear()
|
self.embedding_dirs.clear()
|
||||||
|
|
||||||
def register_embedding(self, embedding, model):
|
def register_embedding(self, embedding, model):
|
||||||
self.word_embeddings[embedding.name] = embedding
|
return self.register_embedding_by_name(embedding, model, embedding.name)
|
||||||
|
|
||||||
ids = model.cond_stage_model.tokenize([embedding.name])[0]
|
|
||||||
|
|
||||||
|
def register_embedding_by_name(self, embedding, model, name):
|
||||||
|
ids = model.cond_stage_model.tokenize([name])[0]
|
||||||
first_id = ids[0]
|
first_id = ids[0]
|
||||||
if first_id not in self.ids_lookup:
|
if first_id not in self.ids_lookup:
|
||||||
self.ids_lookup[first_id] = []
|
self.ids_lookup[first_id] = []
|
||||||
|
if name in self.word_embeddings:
|
||||||
self.ids_lookup[first_id] = sorted(self.ids_lookup[first_id] + [(ids, embedding)], key=lambda x: len(x[0]), reverse=True)
|
# remove old one from the lookup list
|
||||||
|
lookup = [x for x in self.ids_lookup[first_id] if x[1].name!=name]
|
||||||
|
else:
|
||||||
|
lookup = self.ids_lookup[first_id]
|
||||||
|
if embedding is not None:
|
||||||
|
lookup += [(ids, embedding)]
|
||||||
|
self.ids_lookup[first_id] = sorted(lookup, key=lambda x: len(x[0]), reverse=True)
|
||||||
|
if embedding is None:
|
||||||
|
# unregister embedding with specified name
|
||||||
|
if name in self.word_embeddings:
|
||||||
|
del self.word_embeddings[name]
|
||||||
|
if len(self.ids_lookup[first_id])==0:
|
||||||
|
del self.ids_lookup[first_id]
|
||||||
|
return None
|
||||||
|
self.word_embeddings[name] = embedding
|
||||||
return embedding
|
return embedding
|
||||||
|
|
||||||
def get_expected_shape(self):
|
def get_expected_shape(self):
|
||||||
@ -207,8 +218,7 @@ class EmbeddingDatabase:
|
|||||||
|
|
||||||
self.load_from_file(fullfn, fn)
|
self.load_from_file(fullfn, fn)
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error loading embedding {fn}:", file=sys.stderr)
|
errors.report(f"Error loading embedding {fn}", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
def load_textual_inversion_embeddings(self, force_reload=False):
|
def load_textual_inversion_embeddings(self, force_reload=False):
|
||||||
@ -632,8 +642,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
|||||||
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
|
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
|
||||||
save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
|
save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
|
||||||
except Exception:
|
except Exception:
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
errors.report("Error training embedding", exc_info=True)
|
||||||
pass
|
|
||||||
finally:
|
finally:
|
||||||
pbar.leave = False
|
pbar.leave = False
|
||||||
pbar.close()
|
pbar.close()
|
||||||
|
394
modules/ui.py
394
modules/ui.py
@ -2,20 +2,21 @@ import json
|
|||||||
import mimetypes
|
import mimetypes
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import gradio.routes
|
|
||||||
import gradio.utils
|
import gradio.utils
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image, PngImagePlugin # noqa: F401
|
from PIL import Image, PngImagePlugin # noqa: F401
|
||||||
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
|
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
|
||||||
|
|
||||||
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, timer
|
from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer
|
||||||
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
|
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
|
||||||
from modules.paths import script_path, data_path
|
from modules.paths import script_path
|
||||||
|
from modules.ui_common import create_refresh_button
|
||||||
|
from modules.ui_gradio_extensions import reload_javascript
|
||||||
|
|
||||||
|
|
||||||
from modules.shared import opts, cmd_opts
|
from modules.shared import opts, cmd_opts
|
||||||
|
|
||||||
@ -35,6 +36,8 @@ import modules.hypernetworks.ui
|
|||||||
from modules.generation_parameters_copypaste import image_from_url_text
|
from modules.generation_parameters_copypaste import image_from_url_text
|
||||||
import modules.extras
|
import modules.extras
|
||||||
|
|
||||||
|
create_setting_component = ui_settings.create_setting_component
|
||||||
|
|
||||||
warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning)
|
warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning)
|
||||||
|
|
||||||
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
|
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
|
||||||
@ -231,9 +234,8 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
|
|||||||
res = all_seeds[index if 0 <= index < len(all_seeds) else 0]
|
res = all_seeds[index if 0 <= index < len(all_seeds) else 0]
|
||||||
|
|
||||||
except json.decoder.JSONDecodeError:
|
except json.decoder.JSONDecodeError:
|
||||||
if gen_info_string != '':
|
if gen_info_string:
|
||||||
print("Error parsing JSON generation info:", file=sys.stderr)
|
errors.report(f"Error parsing JSON generation info: {gen_info_string}")
|
||||||
print(gen_info_string, file=sys.stderr)
|
|
||||||
|
|
||||||
return [res, gr_show(False)]
|
return [res, gr_show(False)]
|
||||||
|
|
||||||
@ -272,12 +274,12 @@ def create_toprow(is_img2img):
|
|||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column(scale=80):
|
with gr.Column(scale=80):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)")
|
prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column(scale=80):
|
with gr.Column(scale=80):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)")
|
negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
|
||||||
|
|
||||||
button_interrogate = None
|
button_interrogate = None
|
||||||
button_deepbooru = None
|
button_deepbooru = None
|
||||||
@ -368,25 +370,6 @@ def apply_setting(key, value):
|
|||||||
return getattr(opts, key)
|
return getattr(opts, key)
|
||||||
|
|
||||||
|
|
||||||
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
|
|
||||||
def refresh():
|
|
||||||
refresh_method()
|
|
||||||
args = refreshed_args() if callable(refreshed_args) else refreshed_args
|
|
||||||
|
|
||||||
for k, v in args.items():
|
|
||||||
setattr(refresh_component, k, v)
|
|
||||||
|
|
||||||
return gr.update(**(args or {}))
|
|
||||||
|
|
||||||
refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
|
|
||||||
refresh_button.click(
|
|
||||||
fn=refresh,
|
|
||||||
inputs=[],
|
|
||||||
outputs=[refresh_component]
|
|
||||||
)
|
|
||||||
return refresh_button
|
|
||||||
|
|
||||||
|
|
||||||
def create_output_panel(tabname, outdir):
|
def create_output_panel(tabname, outdir):
|
||||||
return ui_common.create_output_panel(tabname, outdir)
|
return ui_common.create_output_panel(tabname, outdir)
|
||||||
|
|
||||||
@ -405,22 +388,12 @@ def create_sampler_and_steps_selection(choices, tabname):
|
|||||||
|
|
||||||
|
|
||||||
def ordered_ui_categories():
|
def ordered_ui_categories():
|
||||||
user_order = {x.strip(): i * 2 + 1 for i, x in enumerate(shared.opts.ui_reorder.split(","))}
|
user_order = {x.strip(): i * 2 + 1 for i, x in enumerate(shared.opts.ui_reorder_list)}
|
||||||
|
|
||||||
for _, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] * 2 + 0)):
|
for _, category in sorted(enumerate(shared_items.ui_reorder_categories()), key=lambda x: user_order.get(x[1], x[0] * 2 + 0)):
|
||||||
yield category
|
yield category
|
||||||
|
|
||||||
|
|
||||||
def get_value_for_setting(key):
|
|
||||||
value = getattr(opts, key)
|
|
||||||
|
|
||||||
info = opts.data_labels[key]
|
|
||||||
args = info.component_args() if callable(info.component_args) else info.component_args or {}
|
|
||||||
args = {k: v for k, v in args.items() if k not in {'precision'}}
|
|
||||||
|
|
||||||
return gr.update(value=value, **args)
|
|
||||||
|
|
||||||
|
|
||||||
def create_override_settings_dropdown(tabname, row):
|
def create_override_settings_dropdown(tabname, row):
|
||||||
dropdown = gr.Dropdown([], label="Override settings", visible=False, elem_id=f"{tabname}_override_settings", multiselect=True)
|
dropdown = gr.Dropdown([], label="Override settings", visible=False, elem_id=f"{tabname}_override_settings", multiselect=True)
|
||||||
|
|
||||||
@ -456,6 +429,8 @@ def create_ui():
|
|||||||
|
|
||||||
with gr.Row().style(equal_height=False):
|
with gr.Row().style(equal_height=False):
|
||||||
with gr.Column(variant='compact', elem_id="txt2img_settings"):
|
with gr.Column(variant='compact', elem_id="txt2img_settings"):
|
||||||
|
modules.scripts.scripts_txt2img.prepare_ui()
|
||||||
|
|
||||||
for category in ordered_ui_categories():
|
for category in ordered_ui_categories():
|
||||||
if category == "sampler":
|
if category == "sampler":
|
||||||
steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img")
|
steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img")
|
||||||
@ -463,8 +438,8 @@ def create_ui():
|
|||||||
elif category == "dimensions":
|
elif category == "dimensions":
|
||||||
with FormRow():
|
with FormRow():
|
||||||
with gr.Column(elem_id="txt2img_column_size", scale=4):
|
with gr.Column(elem_id="txt2img_column_size", scale=4):
|
||||||
width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512, elem_id="txt2img_width")
|
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width")
|
||||||
height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512, elem_id="txt2img_height")
|
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
|
||||||
|
|
||||||
with gr.Column(elem_id="txt2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
|
with gr.Column(elem_id="txt2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
|
||||||
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn", label="Switch dims")
|
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn", label="Switch dims")
|
||||||
@ -505,10 +480,10 @@ def create_ui():
|
|||||||
with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
|
with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
|
||||||
with gr.Column(scale=80):
|
with gr.Column(scale=80):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
hr_prompt = gr.Textbox(label="Prompt", elem_id="hires_prompt", show_label=False, lines=3, placeholder="Prompt for hires fix pass.\nLeave empty to use the same prompt as in first pass.")
|
hr_prompt = gr.Textbox(label="Hires prompt", elem_id="hires_prompt", show_label=False, lines=3, placeholder="Prompt for hires fix pass.\nLeave empty to use the same prompt as in first pass.", elem_classes=["prompt"])
|
||||||
with gr.Column(scale=80):
|
with gr.Column(scale=80):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
hr_negative_prompt = gr.Textbox(label="Negative prompt", elem_id="hires_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt for hires fix pass.\nLeave empty to use the same negative prompt as in first pass.")
|
hr_negative_prompt = gr.Textbox(label="Hires negative prompt", elem_id="hires_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt for hires fix pass.\nLeave empty to use the same negative prompt as in first pass.", elem_classes=["prompt"])
|
||||||
|
|
||||||
elif category == "batch":
|
elif category == "batch":
|
||||||
if not opts.dimensions_and_batch_together:
|
if not opts.dimensions_and_batch_together:
|
||||||
@ -524,6 +499,9 @@ def create_ui():
|
|||||||
with FormGroup(elem_id="txt2img_script_container"):
|
with FormGroup(elem_id="txt2img_script_container"):
|
||||||
custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
|
custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
|
||||||
|
|
||||||
|
else:
|
||||||
|
modules.scripts.scripts_txt2img.setup_ui_for_section(category)
|
||||||
|
|
||||||
hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]
|
hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]
|
||||||
|
|
||||||
for component in hr_resolution_preview_inputs:
|
for component in hr_resolution_preview_inputs:
|
||||||
@ -616,7 +594,8 @@ def create_ui():
|
|||||||
outputs=[
|
outputs=[
|
||||||
txt2img_prompt,
|
txt2img_prompt,
|
||||||
txt_prompt_img
|
txt_prompt_img
|
||||||
]
|
],
|
||||||
|
show_progress=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
enable_hr.change(
|
enable_hr.change(
|
||||||
@ -779,6 +758,8 @@ def create_ui():
|
|||||||
with FormRow():
|
with FormRow():
|
||||||
resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
|
resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
|
||||||
|
|
||||||
|
modules.scripts.scripts_img2img.prepare_ui()
|
||||||
|
|
||||||
for category in ordered_ui_categories():
|
for category in ordered_ui_categories():
|
||||||
if category == "sampler":
|
if category == "sampler":
|
||||||
steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img")
|
steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img")
|
||||||
@ -792,8 +773,8 @@ def create_ui():
|
|||||||
with gr.Tab(label="Resize to") as tab_scale_to:
|
with gr.Tab(label="Resize to") as tab_scale_to:
|
||||||
with FormRow():
|
with FormRow():
|
||||||
with gr.Column(elem_id="img2img_column_size", scale=4):
|
with gr.Column(elem_id="img2img_column_size", scale=4):
|
||||||
width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512, elem_id="img2img_width")
|
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
|
||||||
height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512, elem_id="img2img_height")
|
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
|
||||||
with gr.Column(elem_id="img2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
|
with gr.Column(elem_id="img2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
|
||||||
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
|
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
|
||||||
detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn")
|
detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn")
|
||||||
@ -888,6 +869,8 @@ def create_ui():
|
|||||||
inputs=[],
|
inputs=[],
|
||||||
outputs=[inpaint_controls, mask_alpha],
|
outputs=[inpaint_controls, mask_alpha],
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
modules.scripts.scripts_img2img.setup_ui_for_section(category)
|
||||||
|
|
||||||
img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
|
img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
|
||||||
|
|
||||||
@ -902,7 +885,8 @@ def create_ui():
|
|||||||
outputs=[
|
outputs=[
|
||||||
img2img_prompt,
|
img2img_prompt,
|
||||||
img2img_prompt_img
|
img2img_prompt_img
|
||||||
]
|
],
|
||||||
|
show_progress=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
img2img_args = dict(
|
img2img_args = dict(
|
||||||
@ -1183,8 +1167,8 @@ def create_ui():
|
|||||||
with gr.Tab(label="Preprocess images", id="preprocess_images"):
|
with gr.Tab(label="Preprocess images", id="preprocess_images"):
|
||||||
process_src = gr.Textbox(label='Source directory', elem_id="train_process_src")
|
process_src = gr.Textbox(label='Source directory', elem_id="train_process_src")
|
||||||
process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst")
|
process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst")
|
||||||
process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512, elem_id="train_process_width")
|
process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width")
|
||||||
process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512, elem_id="train_process_height")
|
process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_process_height")
|
||||||
preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action")
|
preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action")
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
@ -1276,8 +1260,8 @@ def create_ui():
|
|||||||
template_file = gr.Dropdown(label='Prompt template', value="style_filewords.txt", elem_id="train_template_file", choices=get_textual_inversion_template_names())
|
template_file = gr.Dropdown(label='Prompt template', value="style_filewords.txt", elem_id="train_template_file", choices=get_textual_inversion_template_names())
|
||||||
create_refresh_button(template_file, textual_inversion.list_textual_inversion_templates, lambda: {"choices": get_textual_inversion_template_names()}, "refrsh_train_template_file")
|
create_refresh_button(template_file, textual_inversion.list_textual_inversion_templates, lambda: {"choices": get_textual_inversion_template_names()}, "refrsh_train_template_file")
|
||||||
|
|
||||||
training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512, elem_id="train_training_width")
|
training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width")
|
||||||
training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512, elem_id="train_training_height")
|
training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height")
|
||||||
varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize")
|
varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize")
|
||||||
steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps")
|
steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps")
|
||||||
|
|
||||||
@ -1460,195 +1444,10 @@ def create_ui():
|
|||||||
outputs=[],
|
outputs=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_setting_component(key, is_quicksettings=False):
|
|
||||||
def fun():
|
|
||||||
return opts.data[key] if key in opts.data else opts.data_labels[key].default
|
|
||||||
|
|
||||||
info = opts.data_labels[key]
|
|
||||||
t = type(info.default)
|
|
||||||
|
|
||||||
args = info.component_args() if callable(info.component_args) else info.component_args
|
|
||||||
|
|
||||||
if info.component is not None:
|
|
||||||
comp = info.component
|
|
||||||
elif t == str:
|
|
||||||
comp = gr.Textbox
|
|
||||||
elif t == int:
|
|
||||||
comp = gr.Number
|
|
||||||
elif t == bool:
|
|
||||||
comp = gr.Checkbox
|
|
||||||
else:
|
|
||||||
raise Exception(f'bad options item type: {t} for key {key}')
|
|
||||||
|
|
||||||
elem_id = f"setting_{key}"
|
|
||||||
|
|
||||||
if info.refresh is not None:
|
|
||||||
if is_quicksettings:
|
|
||||||
res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
|
|
||||||
create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}")
|
|
||||||
else:
|
|
||||||
with FormRow():
|
|
||||||
res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
|
|
||||||
create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}")
|
|
||||||
else:
|
|
||||||
res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
loadsave = ui_loadsave.UiLoadsave(cmd_opts.ui_config_file)
|
loadsave = ui_loadsave.UiLoadsave(cmd_opts.ui_config_file)
|
||||||
|
|
||||||
components = []
|
settings = ui_settings.UiSettings()
|
||||||
component_dict = {}
|
settings.create_ui(loadsave, dummy_component)
|
||||||
shared.settings_components = component_dict
|
|
||||||
|
|
||||||
script_callbacks.ui_settings_callback()
|
|
||||||
opts.reorder()
|
|
||||||
|
|
||||||
def run_settings(*args):
|
|
||||||
changed = []
|
|
||||||
|
|
||||||
for key, value, comp in zip(opts.data_labels.keys(), args, components):
|
|
||||||
assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}"
|
|
||||||
|
|
||||||
for key, value, comp in zip(opts.data_labels.keys(), args, components):
|
|
||||||
if comp == dummy_component:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if opts.set(key, value):
|
|
||||||
changed.append(key)
|
|
||||||
|
|
||||||
try:
|
|
||||||
opts.save(shared.config_filename)
|
|
||||||
except RuntimeError:
|
|
||||||
return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.'
|
|
||||||
return opts.dumpjson(), f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}.'
|
|
||||||
|
|
||||||
def run_settings_single(value, key):
|
|
||||||
if not opts.same_type(value, opts.data_labels[key].default):
|
|
||||||
return gr.update(visible=True), opts.dumpjson()
|
|
||||||
|
|
||||||
if not opts.set(key, value):
|
|
||||||
return gr.update(value=getattr(opts, key)), opts.dumpjson()
|
|
||||||
|
|
||||||
opts.save(shared.config_filename)
|
|
||||||
|
|
||||||
return get_value_for_setting(key), opts.dumpjson()
|
|
||||||
|
|
||||||
with gr.Blocks(analytics_enabled=False) as settings_interface:
|
|
||||||
with gr.Row():
|
|
||||||
with gr.Column(scale=6):
|
|
||||||
settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit")
|
|
||||||
with gr.Column():
|
|
||||||
restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio")
|
|
||||||
|
|
||||||
result = gr.HTML(elem_id="settings_result")
|
|
||||||
|
|
||||||
quicksettings_names = opts.quicksettings_list
|
|
||||||
quicksettings_names = {x: i for i, x in enumerate(quicksettings_names) if x != 'quicksettings'}
|
|
||||||
|
|
||||||
quicksettings_list = []
|
|
||||||
|
|
||||||
previous_section = None
|
|
||||||
current_tab = None
|
|
||||||
current_row = None
|
|
||||||
with gr.Tabs(elem_id="settings"):
|
|
||||||
for i, (k, item) in enumerate(opts.data_labels.items()):
|
|
||||||
section_must_be_skipped = item.section[0] is None
|
|
||||||
|
|
||||||
if previous_section != item.section and not section_must_be_skipped:
|
|
||||||
elem_id, text = item.section
|
|
||||||
|
|
||||||
if current_tab is not None:
|
|
||||||
current_row.__exit__()
|
|
||||||
current_tab.__exit__()
|
|
||||||
|
|
||||||
gr.Group()
|
|
||||||
current_tab = gr.TabItem(elem_id=f"settings_{elem_id}", label=text)
|
|
||||||
current_tab.__enter__()
|
|
||||||
current_row = gr.Column(variant='compact')
|
|
||||||
current_row.__enter__()
|
|
||||||
|
|
||||||
previous_section = item.section
|
|
||||||
|
|
||||||
if k in quicksettings_names and not shared.cmd_opts.freeze_settings:
|
|
||||||
quicksettings_list.append((i, k, item))
|
|
||||||
components.append(dummy_component)
|
|
||||||
elif section_must_be_skipped:
|
|
||||||
components.append(dummy_component)
|
|
||||||
else:
|
|
||||||
component = create_setting_component(k)
|
|
||||||
component_dict[k] = component
|
|
||||||
components.append(component)
|
|
||||||
|
|
||||||
if current_tab is not None:
|
|
||||||
current_row.__exit__()
|
|
||||||
current_tab.__exit__()
|
|
||||||
|
|
||||||
with gr.TabItem("Defaults", id="defaults", elem_id="settings_tab_defaults"):
|
|
||||||
loadsave.create_ui()
|
|
||||||
|
|
||||||
with gr.TabItem("Actions", id="actions", elem_id="settings_tab_actions"):
|
|
||||||
request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
|
|
||||||
download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
|
|
||||||
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
|
|
||||||
with gr.Row():
|
|
||||||
unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model")
|
|
||||||
reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model")
|
|
||||||
|
|
||||||
with gr.TabItem("Licenses", id="licenses", elem_id="settings_tab_licenses"):
|
|
||||||
gr.HTML(shared.html("licenses.html"), elem_id="licenses")
|
|
||||||
|
|
||||||
gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
|
|
||||||
|
|
||||||
|
|
||||||
def unload_sd_weights():
|
|
||||||
modules.sd_models.unload_model_weights()
|
|
||||||
|
|
||||||
def reload_sd_weights():
|
|
||||||
modules.sd_models.reload_model_weights()
|
|
||||||
|
|
||||||
unload_sd_model.click(
|
|
||||||
fn=unload_sd_weights,
|
|
||||||
inputs=[],
|
|
||||||
outputs=[]
|
|
||||||
)
|
|
||||||
|
|
||||||
reload_sd_model.click(
|
|
||||||
fn=reload_sd_weights,
|
|
||||||
inputs=[],
|
|
||||||
outputs=[]
|
|
||||||
)
|
|
||||||
|
|
||||||
request_notifications.click(
|
|
||||||
fn=lambda: None,
|
|
||||||
inputs=[],
|
|
||||||
outputs=[],
|
|
||||||
_js='function(){}'
|
|
||||||
)
|
|
||||||
|
|
||||||
download_localization.click(
|
|
||||||
fn=lambda: None,
|
|
||||||
inputs=[],
|
|
||||||
outputs=[],
|
|
||||||
_js='download_localization'
|
|
||||||
)
|
|
||||||
|
|
||||||
def reload_scripts():
|
|
||||||
modules.scripts.reload_script_body_only()
|
|
||||||
reload_javascript() # need to refresh the html page
|
|
||||||
|
|
||||||
reload_script_bodies.click(
|
|
||||||
fn=reload_scripts,
|
|
||||||
inputs=[],
|
|
||||||
outputs=[]
|
|
||||||
)
|
|
||||||
|
|
||||||
restart_gradio.click(
|
|
||||||
fn=shared.state.request_restart,
|
|
||||||
_js='restart_reload',
|
|
||||||
inputs=[],
|
|
||||||
outputs=[],
|
|
||||||
)
|
|
||||||
|
|
||||||
interfaces = [
|
interfaces = [
|
||||||
(txt2img_interface, "txt2img", "txt2img"),
|
(txt2img_interface, "txt2img", "txt2img"),
|
||||||
@ -1660,7 +1459,7 @@ def create_ui():
|
|||||||
]
|
]
|
||||||
|
|
||||||
interfaces += script_callbacks.ui_tabs_callback()
|
interfaces += script_callbacks.ui_tabs_callback()
|
||||||
interfaces += [(settings_interface, "Settings", "settings")]
|
interfaces += [(settings.interface, "Settings", "settings")]
|
||||||
|
|
||||||
extensions_interface = ui_extensions.create_ui()
|
extensions_interface = ui_extensions.create_ui()
|
||||||
interfaces += [(extensions_interface, "Extensions", "extensions")]
|
interfaces += [(extensions_interface, "Extensions", "extensions")]
|
||||||
@ -1670,10 +1469,7 @@ def create_ui():
|
|||||||
shared.tab_names.append(label)
|
shared.tab_names.append(label)
|
||||||
|
|
||||||
with gr.Blocks(theme=shared.gradio_theme, analytics_enabled=False, title="Stable Diffusion") as demo:
|
with gr.Blocks(theme=shared.gradio_theme, analytics_enabled=False, title="Stable Diffusion") as demo:
|
||||||
with gr.Row(elem_id="quicksettings", variant="compact"):
|
settings.add_quicksettings()
|
||||||
for _i, k, _item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])):
|
|
||||||
component = create_setting_component(k, is_quicksettings=True)
|
|
||||||
component_dict[k] = component
|
|
||||||
|
|
||||||
parameters_copypaste.connect_paste_params_buttons()
|
parameters_copypaste.connect_paste_params_buttons()
|
||||||
|
|
||||||
@ -1704,55 +1500,17 @@ def create_ui():
|
|||||||
footer = footer.format(versions=versions_html())
|
footer = footer.format(versions=versions_html())
|
||||||
gr.HTML(footer, elem_id="footer")
|
gr.HTML(footer, elem_id="footer")
|
||||||
|
|
||||||
text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
|
settings.add_functionality(demo)
|
||||||
settings_submit.click(
|
|
||||||
fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]),
|
|
||||||
inputs=components,
|
|
||||||
outputs=[text_settings, result],
|
|
||||||
)
|
|
||||||
|
|
||||||
for _i, k, _item in quicksettings_list:
|
|
||||||
component = component_dict[k]
|
|
||||||
info = opts.data_labels[k]
|
|
||||||
|
|
||||||
change_handler = component.release if hasattr(component, 'release') else component.change
|
|
||||||
change_handler(
|
|
||||||
fn=lambda value, k=k: run_settings_single(value, key=k),
|
|
||||||
inputs=[component],
|
|
||||||
outputs=[component, text_settings],
|
|
||||||
show_progress=info.refresh is not None,
|
|
||||||
)
|
|
||||||
|
|
||||||
update_image_cfg_scale_visibility = lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit")
|
update_image_cfg_scale_visibility = lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit")
|
||||||
text_settings.change(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
|
settings.text_settings.change(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
|
||||||
demo.load(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
|
demo.load(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
|
||||||
|
|
||||||
button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
|
|
||||||
button_set_checkpoint.click(
|
|
||||||
fn=lambda value, _: run_settings_single(value, key='sd_model_checkpoint'),
|
|
||||||
_js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }",
|
|
||||||
inputs=[component_dict['sd_model_checkpoint'], dummy_component],
|
|
||||||
outputs=[component_dict['sd_model_checkpoint'], text_settings],
|
|
||||||
)
|
|
||||||
|
|
||||||
component_keys = [k for k in opts.data_labels.keys() if k in component_dict]
|
|
||||||
|
|
||||||
def get_settings_values():
|
|
||||||
return [get_value_for_setting(key) for key in component_keys]
|
|
||||||
|
|
||||||
demo.load(
|
|
||||||
fn=get_settings_values,
|
|
||||||
inputs=[],
|
|
||||||
outputs=[component_dict[k] for k in component_keys],
|
|
||||||
queue=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
def modelmerger(*args):
|
def modelmerger(*args):
|
||||||
try:
|
try:
|
||||||
results = modules.extras.run_modelmerger(*args)
|
results = modules.extras.run_modelmerger(*args)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("Error loading/saving model file:", file=sys.stderr)
|
errors.report("Error loading/saving model file", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
modules.sd_models.list_models() # to remove the potentially missing models from the list
|
modules.sd_models.list_models() # to remove the potentially missing models from the list
|
||||||
return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"]
|
return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"]
|
||||||
return results
|
return results
|
||||||
@ -1780,7 +1538,7 @@ def create_ui():
|
|||||||
primary_model_name,
|
primary_model_name,
|
||||||
secondary_model_name,
|
secondary_model_name,
|
||||||
tertiary_model_name,
|
tertiary_model_name,
|
||||||
component_dict['sd_model_checkpoint'],
|
settings.component_dict['sd_model_checkpoint'],
|
||||||
modelmerger_result,
|
modelmerger_result,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@ -1794,70 +1552,6 @@ def create_ui():
|
|||||||
return demo
|
return demo
|
||||||
|
|
||||||
|
|
||||||
def webpath(fn):
|
|
||||||
if fn.startswith(script_path):
|
|
||||||
web_path = os.path.relpath(fn, script_path).replace('\\', '/')
|
|
||||||
else:
|
|
||||||
web_path = os.path.abspath(fn)
|
|
||||||
|
|
||||||
return f'file={web_path}?{os.path.getmtime(fn)}'
|
|
||||||
|
|
||||||
|
|
||||||
def javascript_html():
|
|
||||||
# Ensure localization is in `window` before scripts
|
|
||||||
head = f'<script type="text/javascript">{localization.localization_js(shared.opts.localization)}</script>\n'
|
|
||||||
|
|
||||||
script_js = os.path.join(script_path, "script.js")
|
|
||||||
head += f'<script type="text/javascript" src="{webpath(script_js)}"></script>\n'
|
|
||||||
|
|
||||||
for script in modules.scripts.list_scripts("javascript", ".js"):
|
|
||||||
head += f'<script type="text/javascript" src="{webpath(script.path)}"></script>\n'
|
|
||||||
|
|
||||||
for script in modules.scripts.list_scripts("javascript", ".mjs"):
|
|
||||||
head += f'<script type="module" src="{webpath(script.path)}"></script>\n'
|
|
||||||
|
|
||||||
if cmd_opts.theme:
|
|
||||||
head += f'<script type="text/javascript">set_theme(\"{cmd_opts.theme}\");</script>\n'
|
|
||||||
|
|
||||||
return head
|
|
||||||
|
|
||||||
|
|
||||||
def css_html():
|
|
||||||
head = ""
|
|
||||||
|
|
||||||
def stylesheet(fn):
|
|
||||||
return f'<link rel="stylesheet" property="stylesheet" href="{webpath(fn)}">'
|
|
||||||
|
|
||||||
for cssfile in modules.scripts.list_files_with_name("style.css"):
|
|
||||||
if not os.path.isfile(cssfile):
|
|
||||||
continue
|
|
||||||
|
|
||||||
head += stylesheet(cssfile)
|
|
||||||
|
|
||||||
if os.path.exists(os.path.join(data_path, "user.css")):
|
|
||||||
head += stylesheet(os.path.join(data_path, "user.css"))
|
|
||||||
|
|
||||||
return head
|
|
||||||
|
|
||||||
|
|
||||||
def reload_javascript():
|
|
||||||
js = javascript_html()
|
|
||||||
css = css_html()
|
|
||||||
|
|
||||||
def template_response(*args, **kwargs):
|
|
||||||
res = shared.GradioTemplateResponseOriginal(*args, **kwargs)
|
|
||||||
res.body = res.body.replace(b'</head>', f'{js}</head>'.encode("utf8"))
|
|
||||||
res.body = res.body.replace(b'</body>', f'{css}</body>'.encode("utf8"))
|
|
||||||
res.init_headers()
|
|
||||||
return res
|
|
||||||
|
|
||||||
gradio.routes.templates.TemplateResponse = template_response
|
|
||||||
|
|
||||||
|
|
||||||
if not hasattr(shared, 'GradioTemplateResponseOriginal'):
|
|
||||||
shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse
|
|
||||||
|
|
||||||
|
|
||||||
def versions_html():
|
def versions_html():
|
||||||
import torch
|
import torch
|
||||||
import launch
|
import launch
|
||||||
|
@ -10,8 +10,11 @@ import subprocess as sp
|
|||||||
from modules import call_queue, shared
|
from modules import call_queue, shared
|
||||||
from modules.generation_parameters_copypaste import image_from_url_text
|
from modules.generation_parameters_copypaste import image_from_url_text
|
||||||
import modules.images
|
import modules.images
|
||||||
|
from modules.ui_components import ToolButton
|
||||||
|
|
||||||
|
|
||||||
folder_symbol = '\U0001f4c2' # 📂
|
folder_symbol = '\U0001f4c2' # 📂
|
||||||
|
refresh_symbol = '\U0001f504' # 🔄
|
||||||
|
|
||||||
|
|
||||||
def update_generation_info(generation_info, html_info, img_index):
|
def update_generation_info(generation_info, html_info, img_index):
|
||||||
@ -50,9 +53,10 @@ def save_files(js_data, images, do_make_zip, index):
|
|||||||
save_to_dirs = shared.opts.use_save_to_dirs_for_ui
|
save_to_dirs = shared.opts.use_save_to_dirs_for_ui
|
||||||
extension: str = shared.opts.samples_format
|
extension: str = shared.opts.samples_format
|
||||||
start_index = 0
|
start_index = 0
|
||||||
|
only_one = False
|
||||||
|
|
||||||
if index > -1 and shared.opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only
|
if index > -1 and shared.opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only
|
||||||
|
only_one = True
|
||||||
images = [images[index]]
|
images = [images[index]]
|
||||||
start_index = index
|
start_index = index
|
||||||
|
|
||||||
@ -70,6 +74,7 @@ def save_files(js_data, images, do_make_zip, index):
|
|||||||
is_grid = image_index < p.index_of_first_image
|
is_grid = image_index < p.index_of_first_image
|
||||||
i = 0 if is_grid else (image_index - p.index_of_first_image)
|
i = 0 if is_grid else (image_index - p.index_of_first_image)
|
||||||
|
|
||||||
|
p.batch_index = image_index-1
|
||||||
fullfn, txt_fullfn = modules.images.save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs)
|
fullfn, txt_fullfn = modules.images.save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs)
|
||||||
|
|
||||||
filename = os.path.relpath(fullfn, path)
|
filename = os.path.relpath(fullfn, path)
|
||||||
@ -83,7 +88,10 @@ def save_files(js_data, images, do_make_zip, index):
|
|||||||
|
|
||||||
# Make Zip
|
# Make Zip
|
||||||
if do_make_zip:
|
if do_make_zip:
|
||||||
zip_filepath = os.path.join(path, "images.zip")
|
zip_fileseed = p.all_seeds[index-1] if only_one else p.all_seeds[0]
|
||||||
|
namegen = modules.images.FilenameGenerator(p, zip_fileseed, p.all_prompts[0], image, True)
|
||||||
|
zip_filename = namegen.apply(shared.opts.grid_zip_filename_pattern or "[datetime]_[[model_name]]_[seed]-[seed_last]")
|
||||||
|
zip_filepath = os.path.join(path, f"{zip_filename}.zip")
|
||||||
|
|
||||||
from zipfile import ZipFile
|
from zipfile import ZipFile
|
||||||
with ZipFile(zip_filepath, "w") as zip_file:
|
with ZipFile(zip_filepath, "w") as zip_file:
|
||||||
@ -211,3 +219,23 @@ Requested path was: {f}
|
|||||||
))
|
))
|
||||||
|
|
||||||
return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
|
return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
|
||||||
|
|
||||||
|
|
||||||
|
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
|
||||||
|
def refresh():
|
||||||
|
refresh_method()
|
||||||
|
args = refreshed_args() if callable(refreshed_args) else refreshed_args
|
||||||
|
|
||||||
|
for k, v in args.items():
|
||||||
|
setattr(refresh_component, k, v)
|
||||||
|
|
||||||
|
return gr.update(**(args or {}))
|
||||||
|
|
||||||
|
refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
|
||||||
|
refresh_button.click(
|
||||||
|
fn=refresh,
|
||||||
|
inputs=[],
|
||||||
|
outputs=[refresh_component]
|
||||||
|
)
|
||||||
|
return refresh_button
|
||||||
|
|
||||||
|
@ -1,10 +1,8 @@
|
|||||||
import json
|
import json
|
||||||
import os.path
|
import os.path
|
||||||
import sys
|
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import traceback
|
|
||||||
|
|
||||||
import git
|
import git
|
||||||
|
|
||||||
@ -13,7 +11,7 @@ import html
|
|||||||
import shutil
|
import shutil
|
||||||
import errno
|
import errno
|
||||||
|
|
||||||
from modules import extensions, shared, paths, config_states
|
from modules import extensions, shared, paths, config_states, errors
|
||||||
from modules.paths_internal import config_states_dir
|
from modules.paths_internal import config_states_dir
|
||||||
from modules.call_queue import wrap_gradio_gpu_call
|
from modules.call_queue import wrap_gradio_gpu_call
|
||||||
|
|
||||||
@ -46,8 +44,7 @@ def apply_and_restart(disable_list, update_list, disable_all):
|
|||||||
try:
|
try:
|
||||||
ext.fetch_and_reset_hard()
|
ext.fetch_and_reset_hard()
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error getting updates for {ext.name}:", file=sys.stderr)
|
errors.report(f"Error getting updates for {ext.name}", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
shared.opts.disabled_extensions = disabled
|
shared.opts.disabled_extensions = disabled
|
||||||
shared.opts.disable_all_extensions = disable_all
|
shared.opts.disable_all_extensions = disable_all
|
||||||
@ -113,8 +110,7 @@ def check_updates(id_task, disable_list):
|
|||||||
if 'FETCH_HEAD' not in str(e):
|
if 'FETCH_HEAD' not in str(e):
|
||||||
raise
|
raise
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error checking updates for {ext.name}:", file=sys.stderr)
|
errors.report(f"Error checking updates for {ext.name}", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
shared.state.nextjob()
|
shared.state.nextjob()
|
||||||
|
|
||||||
@ -345,12 +341,12 @@ def install_extension_from_url(dirname, url, branch_name=None):
|
|||||||
shutil.rmtree(tmpdir, True)
|
shutil.rmtree(tmpdir, True)
|
||||||
if not branch_name:
|
if not branch_name:
|
||||||
# if no branch is specified, use the default branch
|
# if no branch is specified, use the default branch
|
||||||
with git.Repo.clone_from(url, tmpdir) as repo:
|
with git.Repo.clone_from(url, tmpdir, filter=['blob:none']) as repo:
|
||||||
repo.remote().fetch()
|
repo.remote().fetch()
|
||||||
for submodule in repo.submodules:
|
for submodule in repo.submodules:
|
||||||
submodule.update()
|
submodule.update()
|
||||||
else:
|
else:
|
||||||
with git.Repo.clone_from(url, tmpdir, branch=branch_name) as repo:
|
with git.Repo.clone_from(url, tmpdir, filter=['blob:none'], branch=branch_name) as repo:
|
||||||
repo.remote().fetch()
|
repo.remote().fetch()
|
||||||
for submodule in repo.submodules:
|
for submodule in repo.submodules:
|
||||||
submodule.update()
|
submodule.update()
|
||||||
@ -490,8 +486,14 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text="
|
|||||||
|
|
||||||
|
|
||||||
def preload_extensions_git_metadata():
|
def preload_extensions_git_metadata():
|
||||||
|
t0 = time.time()
|
||||||
for extension in extensions.extensions:
|
for extension in extensions.extensions:
|
||||||
extension.read_info_from_repo()
|
extension.read_info_from_repo()
|
||||||
|
print(
|
||||||
|
f"preload_extensions_git_metadata for "
|
||||||
|
f"{len(extensions.extensions)} extensions took "
|
||||||
|
f"{time.time() - t0:.2f}s"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_ui():
|
def create_ui():
|
||||||
|
69
modules/ui_gradio_extensions.py
Normal file
69
modules/ui_gradio_extensions.py
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
import os
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
from modules import localization, shared, scripts
|
||||||
|
from modules.paths import script_path, data_path
|
||||||
|
|
||||||
|
|
||||||
|
def webpath(fn):
|
||||||
|
if fn.startswith(script_path):
|
||||||
|
web_path = os.path.relpath(fn, script_path).replace('\\', '/')
|
||||||
|
else:
|
||||||
|
web_path = os.path.abspath(fn)
|
||||||
|
|
||||||
|
return f'file={web_path}?{os.path.getmtime(fn)}'
|
||||||
|
|
||||||
|
|
||||||
|
def javascript_html():
|
||||||
|
# Ensure localization is in `window` before scripts
|
||||||
|
head = f'<script type="text/javascript">{localization.localization_js(shared.opts.localization)}</script>\n'
|
||||||
|
|
||||||
|
script_js = os.path.join(script_path, "script.js")
|
||||||
|
head += f'<script type="text/javascript" src="{webpath(script_js)}"></script>\n'
|
||||||
|
|
||||||
|
for script in scripts.list_scripts("javascript", ".js"):
|
||||||
|
head += f'<script type="text/javascript" src="{webpath(script.path)}"></script>\n'
|
||||||
|
|
||||||
|
for script in scripts.list_scripts("javascript", ".mjs"):
|
||||||
|
head += f'<script type="module" src="{webpath(script.path)}"></script>\n'
|
||||||
|
|
||||||
|
if shared.cmd_opts.theme:
|
||||||
|
head += f'<script type="text/javascript">set_theme(\"{shared.cmd_opts.theme}\");</script>\n'
|
||||||
|
|
||||||
|
return head
|
||||||
|
|
||||||
|
|
||||||
|
def css_html():
|
||||||
|
head = ""
|
||||||
|
|
||||||
|
def stylesheet(fn):
|
||||||
|
return f'<link rel="stylesheet" property="stylesheet" href="{webpath(fn)}">'
|
||||||
|
|
||||||
|
for cssfile in scripts.list_files_with_name("style.css"):
|
||||||
|
if not os.path.isfile(cssfile):
|
||||||
|
continue
|
||||||
|
|
||||||
|
head += stylesheet(cssfile)
|
||||||
|
|
||||||
|
if os.path.exists(os.path.join(data_path, "user.css")):
|
||||||
|
head += stylesheet(os.path.join(data_path, "user.css"))
|
||||||
|
|
||||||
|
return head
|
||||||
|
|
||||||
|
|
||||||
|
def reload_javascript():
|
||||||
|
js = javascript_html()
|
||||||
|
css = css_html()
|
||||||
|
|
||||||
|
def template_response(*args, **kwargs):
|
||||||
|
res = shared.GradioTemplateResponseOriginal(*args, **kwargs)
|
||||||
|
res.body = res.body.replace(b'</head>', f'{js}</head>'.encode("utf8"))
|
||||||
|
res.body = res.body.replace(b'</body>', f'{css}</body>'.encode("utf8"))
|
||||||
|
res.init_headers()
|
||||||
|
return res
|
||||||
|
|
||||||
|
gr.routes.templates.TemplateResponse = template_response
|
||||||
|
|
||||||
|
|
||||||
|
if not hasattr(shared, 'GradioTemplateResponseOriginal'):
|
||||||
|
shared.GradioTemplateResponseOriginal = gr.routes.templates.TemplateResponse
|
263
modules/ui_settings.py
Normal file
263
modules/ui_settings.py
Normal file
@ -0,0 +1,263 @@
|
|||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
from modules import ui_common, shared, script_callbacks, scripts, sd_models
|
||||||
|
from modules.call_queue import wrap_gradio_call
|
||||||
|
from modules.shared import opts
|
||||||
|
from modules.ui_components import FormRow
|
||||||
|
from modules.ui_gradio_extensions import reload_javascript
|
||||||
|
|
||||||
|
|
||||||
|
def get_value_for_setting(key):
|
||||||
|
value = getattr(opts, key)
|
||||||
|
|
||||||
|
info = opts.data_labels[key]
|
||||||
|
args = info.component_args() if callable(info.component_args) else info.component_args or {}
|
||||||
|
args = {k: v for k, v in args.items() if k not in {'precision'}}
|
||||||
|
|
||||||
|
return gr.update(value=value, **args)
|
||||||
|
|
||||||
|
|
||||||
|
def create_setting_component(key, is_quicksettings=False):
|
||||||
|
def fun():
|
||||||
|
return opts.data[key] if key in opts.data else opts.data_labels[key].default
|
||||||
|
|
||||||
|
info = opts.data_labels[key]
|
||||||
|
t = type(info.default)
|
||||||
|
|
||||||
|
args = info.component_args() if callable(info.component_args) else info.component_args
|
||||||
|
|
||||||
|
if info.component is not None:
|
||||||
|
comp = info.component
|
||||||
|
elif t == str:
|
||||||
|
comp = gr.Textbox
|
||||||
|
elif t == int:
|
||||||
|
comp = gr.Number
|
||||||
|
elif t == bool:
|
||||||
|
comp = gr.Checkbox
|
||||||
|
else:
|
||||||
|
raise Exception(f'bad options item type: {t} for key {key}')
|
||||||
|
|
||||||
|
elem_id = f"setting_{key}"
|
||||||
|
|
||||||
|
if info.refresh is not None:
|
||||||
|
if is_quicksettings:
|
||||||
|
res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
|
||||||
|
ui_common.create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}")
|
||||||
|
else:
|
||||||
|
with FormRow():
|
||||||
|
res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
|
||||||
|
ui_common.create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}")
|
||||||
|
else:
|
||||||
|
res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
class UiSettings:
|
||||||
|
submit = None
|
||||||
|
result = None
|
||||||
|
interface = None
|
||||||
|
components = None
|
||||||
|
component_dict = None
|
||||||
|
dummy_component = None
|
||||||
|
quicksettings_list = None
|
||||||
|
quicksettings_names = None
|
||||||
|
text_settings = None
|
||||||
|
|
||||||
|
def run_settings(self, *args):
|
||||||
|
changed = []
|
||||||
|
|
||||||
|
for key, value, comp in zip(opts.data_labels.keys(), args, self.components):
|
||||||
|
assert comp == self.dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}"
|
||||||
|
|
||||||
|
for key, value, comp in zip(opts.data_labels.keys(), args, self.components):
|
||||||
|
if comp == self.dummy_component:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if opts.set(key, value):
|
||||||
|
changed.append(key)
|
||||||
|
|
||||||
|
try:
|
||||||
|
opts.save(shared.config_filename)
|
||||||
|
except RuntimeError:
|
||||||
|
return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.'
|
||||||
|
return opts.dumpjson(), f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}.'
|
||||||
|
|
||||||
|
def run_settings_single(self, value, key):
|
||||||
|
if not opts.same_type(value, opts.data_labels[key].default):
|
||||||
|
return gr.update(visible=True), opts.dumpjson()
|
||||||
|
|
||||||
|
if not opts.set(key, value):
|
||||||
|
return gr.update(value=getattr(opts, key)), opts.dumpjson()
|
||||||
|
|
||||||
|
opts.save(shared.config_filename)
|
||||||
|
|
||||||
|
return get_value_for_setting(key), opts.dumpjson()
|
||||||
|
|
||||||
|
def create_ui(self, loadsave, dummy_component):
|
||||||
|
self.components = []
|
||||||
|
self.component_dict = {}
|
||||||
|
self.dummy_component = dummy_component
|
||||||
|
|
||||||
|
shared.settings_components = self.component_dict
|
||||||
|
|
||||||
|
script_callbacks.ui_settings_callback()
|
||||||
|
opts.reorder()
|
||||||
|
|
||||||
|
with gr.Blocks(analytics_enabled=False) as settings_interface:
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column(scale=6):
|
||||||
|
self.submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit")
|
||||||
|
with gr.Column():
|
||||||
|
restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio")
|
||||||
|
|
||||||
|
self.result = gr.HTML(elem_id="settings_result")
|
||||||
|
|
||||||
|
self.quicksettings_names = opts.quicksettings_list
|
||||||
|
self.quicksettings_names = {x: i for i, x in enumerate(self.quicksettings_names) if x != 'quicksettings'}
|
||||||
|
|
||||||
|
self.quicksettings_list = []
|
||||||
|
|
||||||
|
previous_section = None
|
||||||
|
current_tab = None
|
||||||
|
current_row = None
|
||||||
|
with gr.Tabs(elem_id="settings"):
|
||||||
|
for i, (k, item) in enumerate(opts.data_labels.items()):
|
||||||
|
section_must_be_skipped = item.section[0] is None
|
||||||
|
|
||||||
|
if previous_section != item.section and not section_must_be_skipped:
|
||||||
|
elem_id, text = item.section
|
||||||
|
|
||||||
|
if current_tab is not None:
|
||||||
|
current_row.__exit__()
|
||||||
|
current_tab.__exit__()
|
||||||
|
|
||||||
|
gr.Group()
|
||||||
|
current_tab = gr.TabItem(elem_id=f"settings_{elem_id}", label=text)
|
||||||
|
current_tab.__enter__()
|
||||||
|
current_row = gr.Column(variant='compact')
|
||||||
|
current_row.__enter__()
|
||||||
|
|
||||||
|
previous_section = item.section
|
||||||
|
|
||||||
|
if k in self.quicksettings_names and not shared.cmd_opts.freeze_settings:
|
||||||
|
self.quicksettings_list.append((i, k, item))
|
||||||
|
self.components.append(dummy_component)
|
||||||
|
elif section_must_be_skipped:
|
||||||
|
self.components.append(dummy_component)
|
||||||
|
else:
|
||||||
|
component = create_setting_component(k)
|
||||||
|
self.component_dict[k] = component
|
||||||
|
self.components.append(component)
|
||||||
|
|
||||||
|
if current_tab is not None:
|
||||||
|
current_row.__exit__()
|
||||||
|
current_tab.__exit__()
|
||||||
|
|
||||||
|
with gr.TabItem("Defaults", id="defaults", elem_id="settings_tab_defaults"):
|
||||||
|
loadsave.create_ui()
|
||||||
|
|
||||||
|
with gr.TabItem("Actions", id="actions", elem_id="settings_tab_actions"):
|
||||||
|
request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
|
||||||
|
download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
|
||||||
|
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
|
||||||
|
with gr.Row():
|
||||||
|
unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model")
|
||||||
|
reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model")
|
||||||
|
|
||||||
|
with gr.TabItem("Licenses", id="licenses", elem_id="settings_tab_licenses"):
|
||||||
|
gr.HTML(shared.html("licenses.html"), elem_id="licenses")
|
||||||
|
|
||||||
|
gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
|
||||||
|
|
||||||
|
self.text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
|
||||||
|
|
||||||
|
unload_sd_model.click(
|
||||||
|
fn=sd_models.unload_model_weights,
|
||||||
|
inputs=[],
|
||||||
|
outputs=[]
|
||||||
|
)
|
||||||
|
|
||||||
|
reload_sd_model.click(
|
||||||
|
fn=sd_models.reload_model_weights,
|
||||||
|
inputs=[],
|
||||||
|
outputs=[]
|
||||||
|
)
|
||||||
|
|
||||||
|
request_notifications.click(
|
||||||
|
fn=lambda: None,
|
||||||
|
inputs=[],
|
||||||
|
outputs=[],
|
||||||
|
_js='function(){}'
|
||||||
|
)
|
||||||
|
|
||||||
|
download_localization.click(
|
||||||
|
fn=lambda: None,
|
||||||
|
inputs=[],
|
||||||
|
outputs=[],
|
||||||
|
_js='download_localization'
|
||||||
|
)
|
||||||
|
|
||||||
|
def reload_scripts():
|
||||||
|
scripts.reload_script_body_only()
|
||||||
|
reload_javascript() # need to refresh the html page
|
||||||
|
|
||||||
|
reload_script_bodies.click(
|
||||||
|
fn=reload_scripts,
|
||||||
|
inputs=[],
|
||||||
|
outputs=[]
|
||||||
|
)
|
||||||
|
|
||||||
|
restart_gradio.click(
|
||||||
|
fn=shared.state.request_restart,
|
||||||
|
_js='restart_reload',
|
||||||
|
inputs=[],
|
||||||
|
outputs=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
self.interface = settings_interface
|
||||||
|
|
||||||
|
def add_quicksettings(self):
|
||||||
|
with gr.Row(elem_id="quicksettings", variant="compact"):
|
||||||
|
for _i, k, _item in sorted(self.quicksettings_list, key=lambda x: self.quicksettings_names.get(x[1], x[0])):
|
||||||
|
component = create_setting_component(k, is_quicksettings=True)
|
||||||
|
self.component_dict[k] = component
|
||||||
|
|
||||||
|
def add_functionality(self, demo):
|
||||||
|
self.submit.click(
|
||||||
|
fn=wrap_gradio_call(lambda *args: self.run_settings(*args), extra_outputs=[gr.update()]),
|
||||||
|
inputs=self.components,
|
||||||
|
outputs=[self.text_settings, self.result],
|
||||||
|
)
|
||||||
|
|
||||||
|
for _i, k, _item in self.quicksettings_list:
|
||||||
|
component = self.component_dict[k]
|
||||||
|
info = opts.data_labels[k]
|
||||||
|
|
||||||
|
change_handler = component.release if hasattr(component, 'release') else component.change
|
||||||
|
change_handler(
|
||||||
|
fn=lambda value, k=k: self.run_settings_single(value, key=k),
|
||||||
|
inputs=[component],
|
||||||
|
outputs=[component, self.text_settings],
|
||||||
|
show_progress=info.refresh is not None,
|
||||||
|
)
|
||||||
|
|
||||||
|
button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
|
||||||
|
button_set_checkpoint.click(
|
||||||
|
fn=lambda value, _: self.run_settings_single(value, key='sd_model_checkpoint'),
|
||||||
|
_js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }",
|
||||||
|
inputs=[self.component_dict['sd_model_checkpoint'], self.dummy_component],
|
||||||
|
outputs=[self.component_dict['sd_model_checkpoint'], self.text_settings],
|
||||||
|
)
|
||||||
|
|
||||||
|
component_keys = [k for k in opts.data_labels.keys() if k in self.component_dict]
|
||||||
|
|
||||||
|
def get_settings_values():
|
||||||
|
return [get_value_for_setting(key) for key in component_keys]
|
||||||
|
|
||||||
|
demo.load(
|
||||||
|
fn=get_settings_values,
|
||||||
|
inputs=[],
|
||||||
|
outputs=[self.component_dict[k] for k in component_keys],
|
||||||
|
queue=False,
|
||||||
|
)
|
@ -3,7 +3,7 @@ import tempfile
|
|||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import gradio as gr
|
import gradio.components
|
||||||
|
|
||||||
from PIL import PngImagePlugin
|
from PIL import PngImagePlugin
|
||||||
|
|
||||||
@ -31,13 +31,16 @@ def check_tmp_file(gradio, filename):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def save_pil_to_file(pil_image, dir=None):
|
def save_pil_to_file(self, pil_image, dir=None):
|
||||||
already_saved_as = getattr(pil_image, 'already_saved_as', None)
|
already_saved_as = getattr(pil_image, 'already_saved_as', None)
|
||||||
if already_saved_as and os.path.isfile(already_saved_as):
|
if already_saved_as and os.path.isfile(already_saved_as):
|
||||||
register_tmp_file(shared.demo, already_saved_as)
|
register_tmp_file(shared.demo, already_saved_as)
|
||||||
|
filename = already_saved_as
|
||||||
|
|
||||||
file_obj = Savedfile(f'{already_saved_as}?{os.path.getmtime(already_saved_as)}')
|
if not shared.opts.save_images_add_number:
|
||||||
return file_obj
|
filename += f'?{os.path.getmtime(already_saved_as)}'
|
||||||
|
|
||||||
|
return filename
|
||||||
|
|
||||||
if shared.opts.temp_dir != "":
|
if shared.opts.temp_dir != "":
|
||||||
dir = shared.opts.temp_dir
|
dir = shared.opts.temp_dir
|
||||||
@ -51,11 +54,11 @@ def save_pil_to_file(pil_image, dir=None):
|
|||||||
|
|
||||||
file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
|
file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
|
||||||
pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None))
|
pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None))
|
||||||
return file_obj
|
return file_obj.name
|
||||||
|
|
||||||
|
|
||||||
# override save to file function so that it also writes PNG info
|
# override save to file function so that it also writes PNG info
|
||||||
gr.processing_utils.save_pil_to_file = save_pil_to_file
|
gradio.components.IOComponent.pil_to_temp_file = save_pil_to_file
|
||||||
|
|
||||||
|
|
||||||
def on_tmpdir_changed():
|
def on_tmpdir_changed():
|
||||||
|
@ -53,8 +53,8 @@ class Upscaler:
|
|||||||
|
|
||||||
def upscale(self, img: PIL.Image, scale, selected_model: str = None):
|
def upscale(self, img: PIL.Image, scale, selected_model: str = None):
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
dest_w = int(img.width * scale)
|
dest_w = round((img.width * scale - 4) / 8) * 8
|
||||||
dest_h = int(img.height * scale)
|
dest_h = round((img.height * scale - 4) / 8) * 8
|
||||||
|
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
shape = (img.width, img.height)
|
shape = (img.width, img.height)
|
||||||
|
@ -1,32 +1,32 @@
|
|||||||
astunparse
|
GitPython
|
||||||
blendmodes
|
|
||||||
accelerate
|
|
||||||
basicsr
|
|
||||||
gfpgan
|
|
||||||
gradio==3.31.0
|
|
||||||
numpy
|
|
||||||
omegaconf
|
|
||||||
opencv-contrib-python
|
|
||||||
requests
|
|
||||||
piexif
|
|
||||||
Pillow
|
Pillow
|
||||||
pytorch_lightning==1.7.7
|
accelerate
|
||||||
realesrgan
|
|
||||||
scikit-image>=0.19
|
basicsr
|
||||||
timm==0.4.12
|
blendmodes
|
||||||
transformers==4.25.1
|
|
||||||
torch
|
|
||||||
einops
|
|
||||||
jsonmerge
|
|
||||||
clean-fid
|
clean-fid
|
||||||
resize-right
|
einops
|
||||||
torchdiffeq
|
gfpgan
|
||||||
|
gradio==3.32.0
|
||||||
|
inflection
|
||||||
|
jsonmerge
|
||||||
kornia
|
kornia
|
||||||
lark
|
lark
|
||||||
inflection
|
numpy
|
||||||
GitPython
|
omegaconf
|
||||||
torchsde
|
|
||||||
safetensors
|
piexif
|
||||||
psutil
|
psutil
|
||||||
rich
|
pytorch_lightning
|
||||||
|
realesrgan
|
||||||
|
requests
|
||||||
|
resize-right
|
||||||
|
|
||||||
|
safetensors
|
||||||
|
scikit-image>=0.19
|
||||||
|
timm
|
||||||
tomesd
|
tomesd
|
||||||
|
torch
|
||||||
|
torchdiffeq
|
||||||
|
torchsde
|
||||||
|
transformers==4.25.1
|
||||||
|
@ -1,29 +1,30 @@
|
|||||||
blendmodes==2022
|
GitPython==3.1.30
|
||||||
transformers==4.25.1
|
Pillow==9.5.0
|
||||||
accelerate==0.18.0
|
accelerate==0.18.0
|
||||||
basicsr==1.4.2
|
basicsr==1.4.2
|
||||||
gfpgan==1.3.8
|
blendmodes==2022
|
||||||
gradio==3.31.0
|
|
||||||
numpy==1.23.5
|
|
||||||
Pillow==9.5.0
|
|
||||||
realesrgan==0.3.0
|
|
||||||
torch
|
|
||||||
omegaconf==2.2.3
|
|
||||||
pytorch_lightning==1.9.4
|
|
||||||
scikit-image==0.20.0
|
|
||||||
timm==0.6.7
|
|
||||||
piexif==1.1.3
|
|
||||||
einops==0.4.1
|
|
||||||
jsonmerge==1.8.0
|
|
||||||
clean-fid==0.1.35
|
clean-fid==0.1.35
|
||||||
resize-right==0.0.2
|
einops==0.4.1
|
||||||
torchdiffeq==0.2.3
|
fastapi==0.94.0
|
||||||
|
gfpgan==1.3.8
|
||||||
|
gradio==3.32.0
|
||||||
|
httpcore<=0.15
|
||||||
|
inflection==0.5.1
|
||||||
|
jsonmerge==1.8.0
|
||||||
kornia==0.6.7
|
kornia==0.6.7
|
||||||
lark==1.1.2
|
lark==1.1.2
|
||||||
inflection==0.5.1
|
numpy==1.23.5
|
||||||
GitPython==3.1.30
|
omegaconf==2.2.3
|
||||||
torchsde==0.2.5
|
piexif==1.1.3
|
||||||
|
psutil~=5.9.5
|
||||||
|
pytorch_lightning==1.9.4
|
||||||
|
realesrgan==0.3.0
|
||||||
|
resize-right==0.0.2
|
||||||
safetensors==0.3.1
|
safetensors==0.3.1
|
||||||
httpcore<=0.15
|
scikit-image==0.20.0
|
||||||
fastapi==0.94.0
|
timm==0.6.7
|
||||||
tomesd==0.1.2
|
tomesd==0.1.2
|
||||||
|
torch
|
||||||
|
torchdiffeq==0.2.3
|
||||||
|
torchsde==0.2.5
|
||||||
|
transformers==4.25.1
|
||||||
|
73
script.js
73
script.js
@ -10,44 +10,94 @@ function gradioApp() {
|
|||||||
return elem.shadowRoot ? elem.shadowRoot : elem;
|
return elem.shadowRoot ? elem.shadowRoot : elem;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the currently selected top-level UI tab button (e.g. the button that says "Extras").
|
||||||
|
*/
|
||||||
function get_uiCurrentTab() {
|
function get_uiCurrentTab() {
|
||||||
return gradioApp().querySelector('#tabs button.selected');
|
return gradioApp().querySelector('#tabs > .tab-nav > button.selected');
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the first currently visible top-level UI tab content (e.g. the div hosting the "txt2img" UI).
|
||||||
|
*/
|
||||||
function get_uiCurrentTabContent() {
|
function get_uiCurrentTabContent() {
|
||||||
return gradioApp().querySelector('.tabitem[id^=tab_]:not([style*="display: none"])');
|
return gradioApp().querySelector('#tabs > .tabitem[id^=tab_]:not([style*="display: none"])');
|
||||||
}
|
}
|
||||||
|
|
||||||
var uiUpdateCallbacks = [];
|
var uiUpdateCallbacks = [];
|
||||||
|
var uiAfterUpdateCallbacks = [];
|
||||||
var uiLoadedCallbacks = [];
|
var uiLoadedCallbacks = [];
|
||||||
var uiTabChangeCallbacks = [];
|
var uiTabChangeCallbacks = [];
|
||||||
var optionsChangedCallbacks = [];
|
var optionsChangedCallbacks = [];
|
||||||
|
var uiAfterUpdateTimeout = null;
|
||||||
var uiCurrentTab = null;
|
var uiCurrentTab = null;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Register callback to be called at each UI update.
|
||||||
|
* The callback receives an array of MutationRecords as an argument.
|
||||||
|
*/
|
||||||
function onUiUpdate(callback) {
|
function onUiUpdate(callback) {
|
||||||
uiUpdateCallbacks.push(callback);
|
uiUpdateCallbacks.push(callback);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Register callback to be called soon after UI updates.
|
||||||
|
* The callback receives no arguments.
|
||||||
|
*
|
||||||
|
* This is preferred over `onUiUpdate` if you don't need
|
||||||
|
* access to the MutationRecords, as your function will
|
||||||
|
* not be called quite as often.
|
||||||
|
*/
|
||||||
|
function onAfterUiUpdate(callback) {
|
||||||
|
uiAfterUpdateCallbacks.push(callback);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Register callback to be called when the UI is loaded.
|
||||||
|
* The callback receives no arguments.
|
||||||
|
*/
|
||||||
function onUiLoaded(callback) {
|
function onUiLoaded(callback) {
|
||||||
uiLoadedCallbacks.push(callback);
|
uiLoadedCallbacks.push(callback);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Register callback to be called when the UI tab is changed.
|
||||||
|
* The callback receives no arguments.
|
||||||
|
*/
|
||||||
function onUiTabChange(callback) {
|
function onUiTabChange(callback) {
|
||||||
uiTabChangeCallbacks.push(callback);
|
uiTabChangeCallbacks.push(callback);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Register callback to be called when the options are changed.
|
||||||
|
* The callback receives no arguments.
|
||||||
|
* @param callback
|
||||||
|
*/
|
||||||
function onOptionsChanged(callback) {
|
function onOptionsChanged(callback) {
|
||||||
optionsChangedCallbacks.push(callback);
|
optionsChangedCallbacks.push(callback);
|
||||||
}
|
}
|
||||||
|
|
||||||
function runCallback(x, m) {
|
function executeCallbacks(queue, arg) {
|
||||||
try {
|
for (const callback of queue) {
|
||||||
x(m);
|
try {
|
||||||
} catch (e) {
|
callback(arg);
|
||||||
(console.error || console.log).call(console, e.message, e);
|
} catch (e) {
|
||||||
|
console.error("error running callback", callback, ":", e);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
function executeCallbacks(queue, m) {
|
|
||||||
queue.forEach(function(x) {
|
/**
|
||||||
runCallback(x, m);
|
* Schedule the execution of the callbacks registered with onAfterUiUpdate.
|
||||||
});
|
* The callbacks are executed after a short while, unless another call to this function
|
||||||
|
* is made before that time. IOW, the callbacks are executed only once, even
|
||||||
|
* when there are multiple mutations observed.
|
||||||
|
*/
|
||||||
|
function scheduleAfterUiUpdateCallbacks() {
|
||||||
|
clearTimeout(uiAfterUpdateTimeout);
|
||||||
|
uiAfterUpdateTimeout = setTimeout(function() {
|
||||||
|
executeCallbacks(uiAfterUpdateCallbacks);
|
||||||
|
}, 200);
|
||||||
}
|
}
|
||||||
|
|
||||||
var executedOnLoaded = false;
|
var executedOnLoaded = false;
|
||||||
@ -60,6 +110,7 @@ document.addEventListener("DOMContentLoaded", function() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
executeCallbacks(uiUpdateCallbacks, m);
|
executeCallbacks(uiUpdateCallbacks, m);
|
||||||
|
scheduleAfterUiUpdateCallbacks();
|
||||||
const newTab = get_uiCurrentTab();
|
const newTab = get_uiCurrentTab();
|
||||||
if (newTab && (newTab !== uiCurrentTab)) {
|
if (newTab && (newTab !== uiCurrentTab)) {
|
||||||
uiCurrentTab = newTab;
|
uiCurrentTab = newTab;
|
||||||
|
@ -1,13 +1,11 @@
|
|||||||
import copy
|
import copy
|
||||||
import random
|
import random
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
import shlex
|
import shlex
|
||||||
|
|
||||||
import modules.scripts as scripts
|
import modules.scripts as scripts
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from modules import sd_samplers
|
from modules import sd_samplers, errors
|
||||||
from modules.processing import Processed, process_images
|
from modules.processing import Processed, process_images
|
||||||
from modules.shared import state
|
from modules.shared import state
|
||||||
|
|
||||||
@ -136,8 +134,7 @@ class Script(scripts.Script):
|
|||||||
try:
|
try:
|
||||||
args = cmdargs(line)
|
args = cmdargs(line)
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error parsing line {line} as commandline:", file=sys.stderr)
|
errors.report(f"Error parsing line {line} as commandline", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
args = {"prompt": line}
|
args = {"prompt": line}
|
||||||
else:
|
else:
|
||||||
args = {"prompt": line}
|
args = {"prompt": line}
|
||||||
|
@ -10,7 +10,7 @@ import numpy as np
|
|||||||
import modules.scripts as scripts
|
import modules.scripts as scripts
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from modules import images, sd_samplers, processing, sd_models, sd_vae
|
from modules import images, sd_samplers, processing, sd_models, sd_vae, sd_samplers_kdiffusion
|
||||||
from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
|
from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
|
||||||
from modules.shared import opts, state
|
from modules.shared import opts, state
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
@ -220,6 +220,10 @@ axis_options = [
|
|||||||
AxisOption("Sigma min", float, apply_field("s_tmin")),
|
AxisOption("Sigma min", float, apply_field("s_tmin")),
|
||||||
AxisOption("Sigma max", float, apply_field("s_tmax")),
|
AxisOption("Sigma max", float, apply_field("s_tmax")),
|
||||||
AxisOption("Sigma noise", float, apply_field("s_noise")),
|
AxisOption("Sigma noise", float, apply_field("s_noise")),
|
||||||
|
AxisOption("Schedule type", str, apply_override("k_sched_type"), choices=lambda: list(sd_samplers_kdiffusion.k_diffusion_scheduler)),
|
||||||
|
AxisOption("Schedule min sigma", float, apply_override("sigma_min")),
|
||||||
|
AxisOption("Schedule max sigma", float, apply_override("sigma_max")),
|
||||||
|
AxisOption("Schedule rho", float, apply_override("rho")),
|
||||||
AxisOption("Eta", float, apply_field("eta")),
|
AxisOption("Eta", float, apply_field("eta")),
|
||||||
AxisOption("Clip skip", int, apply_clip_skip),
|
AxisOption("Clip skip", int, apply_clip_skip),
|
||||||
AxisOption("Denoising", float, apply_field("denoising_strength")),
|
AxisOption("Denoising", float, apply_field("denoising_strength")),
|
||||||
|
18
style.css
18
style.css
@ -760,13 +760,22 @@ footer {
|
|||||||
.extra-network-cards .card .metadata-button, .extra-network-thumbs .card .metadata-button{
|
.extra-network-cards .card .metadata-button, .extra-network-thumbs .card .metadata-button{
|
||||||
display: none;
|
display: none;
|
||||||
position: absolute;
|
position: absolute;
|
||||||
right: 0;
|
|
||||||
color: white;
|
color: white;
|
||||||
|
right: 0;
|
||||||
|
}
|
||||||
|
.extra-network-cards .card .metadata-button {
|
||||||
text-shadow: 2px 2px 3px black;
|
text-shadow: 2px 2px 3px black;
|
||||||
padding: 0.25em;
|
padding: 0.25em;
|
||||||
font-size: 22pt;
|
font-size: 22pt;
|
||||||
width: 1.5em;
|
width: 1.5em;
|
||||||
}
|
}
|
||||||
|
.extra-network-thumbs .card .metadata-button {
|
||||||
|
text-shadow: 1px 1px 2px black;
|
||||||
|
padding: 0;
|
||||||
|
font-size: 16pt;
|
||||||
|
width: 1em;
|
||||||
|
top: -0.25em;
|
||||||
|
}
|
||||||
.extra-network-cards .card:hover .metadata-button, .extra-network-thumbs .card:hover .metadata-button{
|
.extra-network-cards .card:hover .metadata-button, .extra-network-thumbs .card:hover .metadata-button{
|
||||||
display: inline-block;
|
display: inline-block;
|
||||||
}
|
}
|
||||||
@ -791,6 +800,13 @@ footer {
|
|||||||
position: relative;
|
position: relative;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.extra-network-thumbs .card .preview{
|
||||||
|
position: absolute;
|
||||||
|
object-fit: cover;
|
||||||
|
width: 100%;
|
||||||
|
height:100%;
|
||||||
|
}
|
||||||
|
|
||||||
.extra-network-thumbs .card:hover .additional a {
|
.extra-network-thumbs .card:hover .additional a {
|
||||||
display: inline-block;
|
display: inline-block;
|
||||||
}
|
}
|
||||||
|
@ -36,7 +36,6 @@
|
|||||||
|
|
||||||
# Fixed git commits
|
# Fixed git commits
|
||||||
#export STABLE_DIFFUSION_COMMIT_HASH=""
|
#export STABLE_DIFFUSION_COMMIT_HASH=""
|
||||||
#export TAMING_TRANSFORMERS_COMMIT_HASH=""
|
|
||||||
#export CODEFORMER_COMMIT_HASH=""
|
#export CODEFORMER_COMMIT_HASH=""
|
||||||
#export BLIP_COMMIT_HASH=""
|
#export BLIP_COMMIT_HASH=""
|
||||||
|
|
||||||
|
44
webui.py
44
webui.py
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
@ -18,7 +20,7 @@ import logging
|
|||||||
|
|
||||||
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
||||||
|
|
||||||
from modules import paths, timer, import_hook, errors # noqa: F401
|
from modules import paths, timer, import_hook, errors, devices # noqa: F401
|
||||||
|
|
||||||
startup_timer = timer.startup_timer
|
startup_timer = timer.startup_timer
|
||||||
|
|
||||||
@ -56,6 +58,7 @@ import modules.sd_hijack
|
|||||||
import modules.sd_hijack_optimizations
|
import modules.sd_hijack_optimizations
|
||||||
import modules.sd_models
|
import modules.sd_models
|
||||||
import modules.sd_vae
|
import modules.sd_vae
|
||||||
|
import modules.sd_unet
|
||||||
import modules.txt2img
|
import modules.txt2img
|
||||||
import modules.script_callbacks
|
import modules.script_callbacks
|
||||||
import modules.textual_inversion.textual_inversion
|
import modules.textual_inversion.textual_inversion
|
||||||
@ -132,7 +135,7 @@ there are reports of issues with training tab on the latest version.
|
|||||||
Use --skip-version-check commandline argument to disable this check.
|
Use --skip-version-check commandline argument to disable this check.
|
||||||
""".strip())
|
""".strip())
|
||||||
|
|
||||||
expected_xformers_version = "0.0.17"
|
expected_xformers_version = "0.0.20"
|
||||||
if shared.xformers_available:
|
if shared.xformers_available:
|
||||||
import xformers
|
import xformers
|
||||||
|
|
||||||
@ -289,9 +292,25 @@ def initialize_rest(*, reload_script_modules=False):
|
|||||||
modules.sd_hijack.list_optimizers()
|
modules.sd_hijack.list_optimizers()
|
||||||
startup_timer.record("scripts list_optimizers")
|
startup_timer.record("scripts list_optimizers")
|
||||||
|
|
||||||
# load model in parallel to other startup stuff
|
modules.sd_unet.list_unets()
|
||||||
# (when reloading, this does nothing)
|
startup_timer.record("scripts list_unets")
|
||||||
Thread(target=lambda: shared.sd_model).start()
|
|
||||||
|
def load_model():
|
||||||
|
"""
|
||||||
|
Accesses shared.sd_model property to load model.
|
||||||
|
After it's available, if it has been loaded before this access by some extension,
|
||||||
|
its optimization may be None because the list of optimizaers has neet been filled
|
||||||
|
by that time, so we apply optimization again.
|
||||||
|
"""
|
||||||
|
|
||||||
|
shared.sd_model # noqa: B018
|
||||||
|
|
||||||
|
if modules.sd_hijack.current_optimizer is None:
|
||||||
|
modules.sd_hijack.apply_optimizations()
|
||||||
|
|
||||||
|
Thread(target=load_model).start()
|
||||||
|
|
||||||
|
Thread(target=devices.first_time_calculation).start()
|
||||||
|
|
||||||
shared.reload_hypernetworks()
|
shared.reload_hypernetworks()
|
||||||
startup_timer.record("reload hypernetworks")
|
startup_timer.record("reload hypernetworks")
|
||||||
@ -368,17 +387,6 @@ def webui():
|
|||||||
|
|
||||||
gradio_auth_creds = list(get_gradio_auth_creds()) or None
|
gradio_auth_creds = list(get_gradio_auth_creds()) or None
|
||||||
|
|
||||||
# this restores the missing /docs endpoint
|
|
||||||
if launch_api and not hasattr(FastAPI, 'original_setup'):
|
|
||||||
# TODO: replace this with `launch(app_kwargs=...)` if https://github.com/gradio-app/gradio/pull/4282 gets merged
|
|
||||||
def fastapi_setup(self):
|
|
||||||
self.docs_url = "/docs"
|
|
||||||
self.redoc_url = "/redoc"
|
|
||||||
self.original_setup()
|
|
||||||
|
|
||||||
FastAPI.original_setup = FastAPI.setup
|
|
||||||
FastAPI.setup = fastapi_setup
|
|
||||||
|
|
||||||
app, local_url, share_url = shared.demo.launch(
|
app, local_url, share_url = shared.demo.launch(
|
||||||
share=cmd_opts.share,
|
share=cmd_opts.share,
|
||||||
server_name=server_name,
|
server_name=server_name,
|
||||||
@ -391,6 +399,10 @@ def webui():
|
|||||||
inbrowser=cmd_opts.autolaunch,
|
inbrowser=cmd_opts.autolaunch,
|
||||||
prevent_thread_lock=True,
|
prevent_thread_lock=True,
|
||||||
allowed_paths=cmd_opts.gradio_allowed_path,
|
allowed_paths=cmd_opts.gradio_allowed_path,
|
||||||
|
app_kwargs={
|
||||||
|
"docs_url": "/docs",
|
||||||
|
"redoc_url": "/redoc",
|
||||||
|
},
|
||||||
)
|
)
|
||||||
if cmd_opts.add_stop_route:
|
if cmd_opts.add_stop_route:
|
||||||
app.add_route("/_stop", stop_route, methods=["POST"])
|
app.add_route("/_stop", stop_route, methods=["POST"])
|
||||||
|
9
webui.sh
9
webui.sh
@ -124,9 +124,12 @@ case "$gpu_info" in
|
|||||||
*)
|
*)
|
||||||
;;
|
;;
|
||||||
esac
|
esac
|
||||||
if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]]
|
if ! echo "$gpu_info" | grep -q "NVIDIA";
|
||||||
then
|
then
|
||||||
export TORCH_COMMAND="pip install torch==2.0.1+rocm5.4.2 torchvision==0.15.2+rocm5.4.2 --index-url https://download.pytorch.org/whl/rocm5.4.2"
|
if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]]
|
||||||
|
then
|
||||||
|
export TORCH_COMMAND="pip install torch==2.0.1+rocm5.4.2 torchvision==0.15.2+rocm5.4.2 --index-url https://download.pytorch.org/whl/rocm5.4.2"
|
||||||
|
fi
|
||||||
fi
|
fi
|
||||||
|
|
||||||
for preq in "${GIT}" "${python_cmd}"
|
for preq in "${GIT}" "${python_cmd}"
|
||||||
@ -190,7 +193,7 @@ fi
|
|||||||
# Try using TCMalloc on Linux
|
# Try using TCMalloc on Linux
|
||||||
prepare_tcmalloc() {
|
prepare_tcmalloc() {
|
||||||
if [[ "${OSTYPE}" == "linux"* ]] && [[ -z "${NO_TCMALLOC}" ]] && [[ -z "${LD_PRELOAD}" ]]; then
|
if [[ "${OSTYPE}" == "linux"* ]] && [[ -z "${NO_TCMALLOC}" ]] && [[ -z "${LD_PRELOAD}" ]]; then
|
||||||
TCMALLOC="$(ldconfig -p | grep -Po "libtcmalloc.so.\d" | head -n 1)"
|
TCMALLOC="$(PATH=/usr/sbin:$PATH ldconfig -p | grep -Po "libtcmalloc(_minimal|)\.so\.\d" | head -n 1)"
|
||||||
if [[ ! -z "${TCMALLOC}" ]]; then
|
if [[ ! -z "${TCMALLOC}" ]]; then
|
||||||
echo "Using TCMalloc: ${TCMALLOC}"
|
echo "Using TCMalloc: ${TCMALLOC}"
|
||||||
export LD_PRELOAD="${TCMALLOC}"
|
export LD_PRELOAD="${TCMALLOC}"
|
||||||
|
Loading…
Reference in New Issue
Block a user