Merge branch 'dev' into sync-req
This commit is contained in:
commit
177d4b6828
4
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
4
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
@ -43,8 +43,8 @@ body:
|
|||||||
- type: input
|
- type: input
|
||||||
id: commit
|
id: commit
|
||||||
attributes:
|
attributes:
|
||||||
label: Commit where the problem happens
|
label: Version or Commit where the problem happens
|
||||||
description: Which commit are you running ? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.)
|
description: "Which webui version or commit are you running ? (Do not write *Latest Version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Version: v1.2.3** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.)"
|
||||||
validations:
|
validations:
|
||||||
required: true
|
required: true
|
||||||
- type: dropdown
|
- type: dropdown
|
||||||
|
@ -1,9 +1,8 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
from basicsr.utils.download_util import load_file_from_url
|
from basicsr.utils.download_util import load_file_from_url
|
||||||
|
|
||||||
|
from modules.errors import print_error
|
||||||
from modules.upscaler import Upscaler, UpscalerData
|
from modules.upscaler import Upscaler, UpscalerData
|
||||||
from ldsr_model_arch import LDSR
|
from ldsr_model_arch import LDSR
|
||||||
from modules import shared, script_callbacks
|
from modules import shared, script_callbacks
|
||||||
@ -51,10 +50,8 @@ class UpscalerLDSR(Upscaler):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
return LDSR(model, yaml)
|
return LDSR(model, yaml)
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
print("Error importing LDSR:", file=sys.stderr)
|
print_error("Error importing LDSR", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def do_upscale(self, img, path):
|
def do_upscale(self, img, path):
|
||||||
|
@ -10,7 +10,7 @@ from contextlib import contextmanager
|
|||||||
from torch.optim.lr_scheduler import LambdaLR
|
from torch.optim.lr_scheduler import LambdaLR
|
||||||
|
|
||||||
from ldm.modules.ema import LitEma
|
from ldm.modules.ema import LitEma
|
||||||
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
|
from vqvae_quantize import VectorQuantizer2 as VectorQuantizer
|
||||||
from ldm.modules.diffusionmodules.model import Encoder, Decoder
|
from ldm.modules.diffusionmodules.model import Encoder, Decoder
|
||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
|
|
||||||
|
147
extensions-builtin/LDSR/vqvae_quantize.py
Normal file
147
extensions-builtin/LDSR/vqvae_quantize.py
Normal file
@ -0,0 +1,147 @@
|
|||||||
|
# Vendored from https://raw.githubusercontent.com/CompVis/taming-transformers/24268930bf1dce879235a7fddd0b2355b84d7ea6/taming/modules/vqvae/quantize.py,
|
||||||
|
# where the license is as follows:
|
||||||
|
#
|
||||||
|
# Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
# of this software and associated documentation files (the "Software"), to deal
|
||||||
|
# in the Software without restriction, including without limitation the rights
|
||||||
|
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
# copies of the Software, and to permit persons to whom the Software is
|
||||||
|
# furnished to do so, subject to the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be included in all
|
||||||
|
# copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||||
|
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||||
|
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||||
|
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
|
||||||
|
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
|
||||||
|
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
|
||||||
|
# OR OTHER DEALINGS IN THE SOFTWARE./
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import numpy as np
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
|
||||||
|
class VectorQuantizer2(nn.Module):
|
||||||
|
"""
|
||||||
|
Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
|
||||||
|
avoids costly matrix multiplications and allows for post-hoc remapping of indices.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# NOTE: due to a bug the beta term was applied to the wrong term. for
|
||||||
|
# backwards compatibility we use the buggy version by default, but you can
|
||||||
|
# specify legacy=False to fix it.
|
||||||
|
def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random",
|
||||||
|
sane_index_shape=False, legacy=True):
|
||||||
|
super().__init__()
|
||||||
|
self.n_e = n_e
|
||||||
|
self.e_dim = e_dim
|
||||||
|
self.beta = beta
|
||||||
|
self.legacy = legacy
|
||||||
|
|
||||||
|
self.embedding = nn.Embedding(self.n_e, self.e_dim)
|
||||||
|
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
|
||||||
|
|
||||||
|
self.remap = remap
|
||||||
|
if self.remap is not None:
|
||||||
|
self.register_buffer("used", torch.tensor(np.load(self.remap)))
|
||||||
|
self.re_embed = self.used.shape[0]
|
||||||
|
self.unknown_index = unknown_index # "random" or "extra" or integer
|
||||||
|
if self.unknown_index == "extra":
|
||||||
|
self.unknown_index = self.re_embed
|
||||||
|
self.re_embed = self.re_embed + 1
|
||||||
|
print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
|
||||||
|
f"Using {self.unknown_index} for unknown indices.")
|
||||||
|
else:
|
||||||
|
self.re_embed = n_e
|
||||||
|
|
||||||
|
self.sane_index_shape = sane_index_shape
|
||||||
|
|
||||||
|
def remap_to_used(self, inds):
|
||||||
|
ishape = inds.shape
|
||||||
|
assert len(ishape) > 1
|
||||||
|
inds = inds.reshape(ishape[0], -1)
|
||||||
|
used = self.used.to(inds)
|
||||||
|
match = (inds[:, :, None] == used[None, None, ...]).long()
|
||||||
|
new = match.argmax(-1)
|
||||||
|
unknown = match.sum(2) < 1
|
||||||
|
if self.unknown_index == "random":
|
||||||
|
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
|
||||||
|
else:
|
||||||
|
new[unknown] = self.unknown_index
|
||||||
|
return new.reshape(ishape)
|
||||||
|
|
||||||
|
def unmap_to_all(self, inds):
|
||||||
|
ishape = inds.shape
|
||||||
|
assert len(ishape) > 1
|
||||||
|
inds = inds.reshape(ishape[0], -1)
|
||||||
|
used = self.used.to(inds)
|
||||||
|
if self.re_embed > self.used.shape[0]: # extra token
|
||||||
|
inds[inds >= self.used.shape[0]] = 0 # simply set to zero
|
||||||
|
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
|
||||||
|
return back.reshape(ishape)
|
||||||
|
|
||||||
|
def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
|
||||||
|
assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
|
||||||
|
assert rescale_logits is False, "Only for interface compatible with Gumbel"
|
||||||
|
assert return_logits is False, "Only for interface compatible with Gumbel"
|
||||||
|
# reshape z -> (batch, height, width, channel) and flatten
|
||||||
|
z = rearrange(z, 'b c h w -> b h w c').contiguous()
|
||||||
|
z_flattened = z.view(-1, self.e_dim)
|
||||||
|
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
||||||
|
|
||||||
|
d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
|
||||||
|
torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \
|
||||||
|
torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
|
||||||
|
|
||||||
|
min_encoding_indices = torch.argmin(d, dim=1)
|
||||||
|
z_q = self.embedding(min_encoding_indices).view(z.shape)
|
||||||
|
perplexity = None
|
||||||
|
min_encodings = None
|
||||||
|
|
||||||
|
# compute loss for embedding
|
||||||
|
if not self.legacy:
|
||||||
|
loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + \
|
||||||
|
torch.mean((z_q - z.detach()) ** 2)
|
||||||
|
else:
|
||||||
|
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * \
|
||||||
|
torch.mean((z_q - z.detach()) ** 2)
|
||||||
|
|
||||||
|
# preserve gradients
|
||||||
|
z_q = z + (z_q - z).detach()
|
||||||
|
|
||||||
|
# reshape back to match original input shape
|
||||||
|
z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
|
||||||
|
|
||||||
|
if self.remap is not None:
|
||||||
|
min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
|
||||||
|
min_encoding_indices = self.remap_to_used(min_encoding_indices)
|
||||||
|
min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
|
||||||
|
|
||||||
|
if self.sane_index_shape:
|
||||||
|
min_encoding_indices = min_encoding_indices.reshape(
|
||||||
|
z_q.shape[0], z_q.shape[2], z_q.shape[3])
|
||||||
|
|
||||||
|
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
|
||||||
|
|
||||||
|
def get_codebook_entry(self, indices, shape):
|
||||||
|
# shape specifying (batch, height, width, channel)
|
||||||
|
if self.remap is not None:
|
||||||
|
indices = indices.reshape(shape[0], -1) # add batch axis
|
||||||
|
indices = self.unmap_to_all(indices)
|
||||||
|
indices = indices.reshape(-1) # flatten again
|
||||||
|
|
||||||
|
# get quantized latent vectors
|
||||||
|
z_q = self.embedding(indices)
|
||||||
|
|
||||||
|
if shape is not None:
|
||||||
|
z_q = z_q.view(shape)
|
||||||
|
# reshape back to match original input shape
|
||||||
|
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
||||||
|
|
||||||
|
return z_q
|
@ -1,6 +1,5 @@
|
|||||||
import os.path
|
import os.path
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
|
||||||
|
|
||||||
import PIL.Image
|
import PIL.Image
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -12,6 +11,8 @@ from basicsr.utils.download_util import load_file_from_url
|
|||||||
import modules.upscaler
|
import modules.upscaler
|
||||||
from modules import devices, modelloader, script_callbacks
|
from modules import devices, modelloader, script_callbacks
|
||||||
from scunet_model_arch import SCUNet as net
|
from scunet_model_arch import SCUNet as net
|
||||||
|
|
||||||
|
from modules.errors import print_error
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
|
|
||||||
|
|
||||||
@ -38,8 +39,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
|
|||||||
scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
|
scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
|
||||||
scalers.append(scaler_data)
|
scalers.append(scaler_data)
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error loading ScuNET model: {file}", file=sys.stderr)
|
print_error(f"Error loading ScuNET model: {file}", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
if add_model2:
|
if add_model2:
|
||||||
scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self)
|
scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self)
|
||||||
scalers.append(scaler_data2)
|
scalers.append(scaler_data2)
|
||||||
|
431
extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js
Normal file
431
extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js
Normal file
@ -0,0 +1,431 @@
|
|||||||
|
// Main
|
||||||
|
|
||||||
|
// Helper functions
|
||||||
|
// Get active tab
|
||||||
|
function getActiveTab(elements, all = false) {
|
||||||
|
const tabs = elements.img2imgTabs.querySelectorAll("button");
|
||||||
|
|
||||||
|
if (all) return tabs;
|
||||||
|
|
||||||
|
for (let tab of tabs) {
|
||||||
|
if (tab.classList.contains("selected")) {
|
||||||
|
return tab;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
onUiLoaded(async() => {
|
||||||
|
const hotkeysConfig = {
|
||||||
|
resetZoom: "KeyR",
|
||||||
|
fitToScreen: "KeyS",
|
||||||
|
moveKey: "KeyF",
|
||||||
|
overlap: "KeyO"
|
||||||
|
};
|
||||||
|
|
||||||
|
let isMoving = false;
|
||||||
|
let mouseX, mouseY;
|
||||||
|
|
||||||
|
const elementIDs = {
|
||||||
|
sketch: "#img2img_sketch",
|
||||||
|
inpaint: "#img2maskimg",
|
||||||
|
inpaintSketch: "#inpaint_sketch",
|
||||||
|
img2imgTabs: "#mode_img2img .tab-nav"
|
||||||
|
};
|
||||||
|
|
||||||
|
async function getElements() {
|
||||||
|
const elements = await Promise.all(
|
||||||
|
Object.values(elementIDs).map(id => document.querySelector(id))
|
||||||
|
);
|
||||||
|
return Object.fromEntries(
|
||||||
|
Object.keys(elementIDs).map((key, index) => [key, elements[index]])
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
const elements = await getElements();
|
||||||
|
|
||||||
|
function applyZoomAndPan(targetElement, elemId) {
|
||||||
|
targetElement.style.transformOrigin = "0 0";
|
||||||
|
let [zoomLevel, panX, panY] = [1, 0, 0];
|
||||||
|
let fullScreenMode = false;
|
||||||
|
|
||||||
|
// In the course of research, it was found that the tag img is very harmful when zooming and creates white canvases. This hack allows you to almost never think about this problem, it has no effect on webui.
|
||||||
|
function fixCanvas() {
|
||||||
|
const activeTab = getActiveTab(elements).textContent.trim();
|
||||||
|
|
||||||
|
if (activeTab !== "img2img") {
|
||||||
|
const img = targetElement.querySelector(`${elemId} img`);
|
||||||
|
|
||||||
|
if (img && img.style.display !== "none") {
|
||||||
|
img.style.display = "none";
|
||||||
|
img.style.visibility = "hidden";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset the zoom level and pan position of the target element to their initial values
|
||||||
|
function resetZoom() {
|
||||||
|
zoomLevel = 1;
|
||||||
|
panX = 0;
|
||||||
|
panY = 0;
|
||||||
|
|
||||||
|
fixCanvas();
|
||||||
|
targetElement.style.transform = `scale(${zoomLevel}) translate(${panX}px, ${panY}px)`;
|
||||||
|
|
||||||
|
const canvas = gradioApp().querySelector(
|
||||||
|
`${elemId} canvas[key="interface"]`
|
||||||
|
);
|
||||||
|
|
||||||
|
toggleOverlap("off");
|
||||||
|
fullScreenMode = false;
|
||||||
|
|
||||||
|
if (
|
||||||
|
canvas &&
|
||||||
|
parseFloat(canvas.style.width) > 865 &&
|
||||||
|
parseFloat(targetElement.style.width) > 865
|
||||||
|
) {
|
||||||
|
fitToElement();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
targetElement.style.width = "";
|
||||||
|
if (canvas) {
|
||||||
|
targetElement.style.height = canvas.style.height;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Toggle the zIndex of the target element between two values, allowing it to overlap or be overlapped by other elements
|
||||||
|
function toggleOverlap(forced = "") {
|
||||||
|
const zIndex1 = "0";
|
||||||
|
const zIndex2 = "998";
|
||||||
|
|
||||||
|
targetElement.style.zIndex =
|
||||||
|
targetElement.style.zIndex !== zIndex2 ? zIndex2 : zIndex1;
|
||||||
|
|
||||||
|
if (forced === "off") {
|
||||||
|
targetElement.style.zIndex = zIndex1;
|
||||||
|
} else if (forced === "on") {
|
||||||
|
targetElement.style.zIndex = zIndex2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Adjust the brush size based on the deltaY value from a mouse wheel event
|
||||||
|
function adjustBrushSize(
|
||||||
|
elemId,
|
||||||
|
deltaY,
|
||||||
|
withoutValue = false,
|
||||||
|
percentage = 5
|
||||||
|
) {
|
||||||
|
const input =
|
||||||
|
gradioApp().querySelector(
|
||||||
|
`${elemId} input[aria-label='Brush radius']`
|
||||||
|
) ||
|
||||||
|
gradioApp().querySelector(
|
||||||
|
`${elemId} button[aria-label="Use brush"]`
|
||||||
|
);
|
||||||
|
|
||||||
|
if (input) {
|
||||||
|
input.click();
|
||||||
|
if (!withoutValue) {
|
||||||
|
const maxValue =
|
||||||
|
parseFloat(input.getAttribute("max")) || 100;
|
||||||
|
const changeAmount = maxValue * (percentage / 100);
|
||||||
|
const newValue =
|
||||||
|
parseFloat(input.value) +
|
||||||
|
(deltaY > 0 ? -changeAmount : changeAmount);
|
||||||
|
input.value = Math.min(Math.max(newValue, 0), maxValue);
|
||||||
|
input.dispatchEvent(new Event("change"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset zoom when uploading a new image
|
||||||
|
const fileInput = gradioApp().querySelector(
|
||||||
|
`${elemId} input[type="file"][accept="image/*"].svelte-116rqfv`
|
||||||
|
);
|
||||||
|
fileInput.addEventListener("click", resetZoom);
|
||||||
|
|
||||||
|
// Update the zoom level and pan position of the target element based on the values of the zoomLevel, panX and panY variables
|
||||||
|
function updateZoom(newZoomLevel, mouseX, mouseY) {
|
||||||
|
newZoomLevel = Math.max(0.5, Math.min(newZoomLevel, 15));
|
||||||
|
panX += mouseX - (mouseX * newZoomLevel) / zoomLevel;
|
||||||
|
panY += mouseY - (mouseY * newZoomLevel) / zoomLevel;
|
||||||
|
|
||||||
|
targetElement.style.transformOrigin = "0 0";
|
||||||
|
targetElement.style.transform = `translate(${panX}px, ${panY}px) scale(${newZoomLevel})`;
|
||||||
|
|
||||||
|
toggleOverlap("on");
|
||||||
|
return newZoomLevel;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Change the zoom level based on user interaction
|
||||||
|
function changeZoomLevel(operation, e) {
|
||||||
|
if (e.shiftKey) {
|
||||||
|
e.preventDefault();
|
||||||
|
|
||||||
|
let zoomPosX, zoomPosY;
|
||||||
|
let delta = 0.2;
|
||||||
|
if (zoomLevel > 7) {
|
||||||
|
delta = 0.9;
|
||||||
|
} else if (zoomLevel > 2) {
|
||||||
|
delta = 0.6;
|
||||||
|
}
|
||||||
|
|
||||||
|
zoomPosX = e.clientX;
|
||||||
|
zoomPosY = e.clientY;
|
||||||
|
|
||||||
|
fullScreenMode = false;
|
||||||
|
zoomLevel = updateZoom(
|
||||||
|
zoomLevel + (operation === "+" ? delta : -delta),
|
||||||
|
zoomPosX - targetElement.getBoundingClientRect().left,
|
||||||
|
zoomPosY - targetElement.getBoundingClientRect().top
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This function fits the target element to the screen by calculating
|
||||||
|
* the required scale and offsets. It also updates the global variables
|
||||||
|
* zoomLevel, panX, and panY to reflect the new state.
|
||||||
|
*/
|
||||||
|
|
||||||
|
function fitToElement() {
|
||||||
|
//Reset Zoom
|
||||||
|
targetElement.style.transform = `translate(${0}px, ${0}px) scale(${1})`;
|
||||||
|
|
||||||
|
// Get element and screen dimensions
|
||||||
|
const elementWidth = targetElement.offsetWidth;
|
||||||
|
const elementHeight = targetElement.offsetHeight;
|
||||||
|
const parentElement = targetElement.parentElement;
|
||||||
|
const screenWidth = parentElement.clientWidth;
|
||||||
|
const screenHeight = parentElement.clientHeight;
|
||||||
|
|
||||||
|
// Get element's coordinates relative to the parent element
|
||||||
|
const elementRect = targetElement.getBoundingClientRect();
|
||||||
|
const parentRect = parentElement.getBoundingClientRect();
|
||||||
|
const elementX = elementRect.x - parentRect.x;
|
||||||
|
|
||||||
|
// Calculate scale and offsets
|
||||||
|
const scaleX = screenWidth / elementWidth;
|
||||||
|
const scaleY = screenHeight / elementHeight;
|
||||||
|
const scale = Math.min(scaleX, scaleY);
|
||||||
|
|
||||||
|
const transformOrigin =
|
||||||
|
window.getComputedStyle(targetElement).transformOrigin;
|
||||||
|
const [originX, originY] = transformOrigin.split(" ");
|
||||||
|
const originXValue = parseFloat(originX);
|
||||||
|
const originYValue = parseFloat(originY);
|
||||||
|
|
||||||
|
const offsetX =
|
||||||
|
(screenWidth - elementWidth * scale) / 2 -
|
||||||
|
originXValue * (1 - scale);
|
||||||
|
const offsetY =
|
||||||
|
(screenHeight - elementHeight * scale) / 2.5 -
|
||||||
|
originYValue * (1 - scale);
|
||||||
|
|
||||||
|
// Apply scale and offsets to the element
|
||||||
|
targetElement.style.transform = `translate(${offsetX}px, ${offsetY}px) scale(${scale})`;
|
||||||
|
|
||||||
|
// Update global variables
|
||||||
|
zoomLevel = scale;
|
||||||
|
panX = offsetX;
|
||||||
|
panY = offsetY;
|
||||||
|
|
||||||
|
fullScreenMode = false;
|
||||||
|
toggleOverlap("off");
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This function fits the target element to the screen by calculating
|
||||||
|
* the required scale and offsets. It also updates the global variables
|
||||||
|
* zoomLevel, panX, and panY to reflect the new state.
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Fullscreen mode
|
||||||
|
function fitToScreen() {
|
||||||
|
const canvas = gradioApp().querySelector(
|
||||||
|
`${elemId} canvas[key="interface"]`
|
||||||
|
);
|
||||||
|
|
||||||
|
if (!canvas) return;
|
||||||
|
|
||||||
|
if (canvas.offsetWidth > 862) {
|
||||||
|
targetElement.style.width = canvas.offsetWidth + "px";
|
||||||
|
}
|
||||||
|
|
||||||
|
if (fullScreenMode) {
|
||||||
|
resetZoom();
|
||||||
|
fullScreenMode = false;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
//Reset Zoom
|
||||||
|
targetElement.style.transform = `translate(${0}px, ${0}px) scale(${1})`;
|
||||||
|
|
||||||
|
// Get scrollbar width to right-align the image
|
||||||
|
const scrollbarWidth = window.innerWidth - document.documentElement.clientWidth;
|
||||||
|
|
||||||
|
// Get element and screen dimensions
|
||||||
|
const elementWidth = targetElement.offsetWidth;
|
||||||
|
const elementHeight = targetElement.offsetHeight;
|
||||||
|
const screenWidth = window.innerWidth - scrollbarWidth;
|
||||||
|
const screenHeight = window.innerHeight;
|
||||||
|
|
||||||
|
// Get element's coordinates relative to the page
|
||||||
|
const elementRect = targetElement.getBoundingClientRect();
|
||||||
|
const elementY = elementRect.y;
|
||||||
|
const elementX = elementRect.x;
|
||||||
|
|
||||||
|
// Calculate scale and offsets
|
||||||
|
const scaleX = screenWidth / elementWidth;
|
||||||
|
const scaleY = screenHeight / elementHeight;
|
||||||
|
const scale = Math.min(scaleX, scaleY);
|
||||||
|
|
||||||
|
// Get the current transformOrigin
|
||||||
|
const computedStyle = window.getComputedStyle(targetElement);
|
||||||
|
const transformOrigin = computedStyle.transformOrigin;
|
||||||
|
const [originX, originY] = transformOrigin.split(" ");
|
||||||
|
const originXValue = parseFloat(originX);
|
||||||
|
const originYValue = parseFloat(originY);
|
||||||
|
|
||||||
|
// Calculate offsets with respect to the transformOrigin
|
||||||
|
const offsetX =
|
||||||
|
(screenWidth - elementWidth * scale) / 2 -
|
||||||
|
elementX -
|
||||||
|
originXValue * (1 - scale);
|
||||||
|
const offsetY =
|
||||||
|
(screenHeight - elementHeight * scale) / 2 -
|
||||||
|
elementY -
|
||||||
|
originYValue * (1 - scale);
|
||||||
|
|
||||||
|
// Apply scale and offsets to the element
|
||||||
|
targetElement.style.transform = `translate(${offsetX}px, ${offsetY}px) scale(${scale})`;
|
||||||
|
|
||||||
|
// Update global variables
|
||||||
|
zoomLevel = scale;
|
||||||
|
panX = offsetX;
|
||||||
|
panY = offsetY;
|
||||||
|
|
||||||
|
fullScreenMode = true;
|
||||||
|
toggleOverlap("on");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle keydown events
|
||||||
|
function handleKeyDown(event) {
|
||||||
|
const hotkeyActions = {
|
||||||
|
[hotkeysConfig.resetZoom]: resetZoom,
|
||||||
|
[hotkeysConfig.overlap]: toggleOverlap,
|
||||||
|
[hotkeysConfig.fitToScreen]: fitToScreen
|
||||||
|
// [hotkeysConfig.moveKey] : moveCanvas,
|
||||||
|
};
|
||||||
|
|
||||||
|
const action = hotkeyActions[event.code];
|
||||||
|
if (action) {
|
||||||
|
event.preventDefault();
|
||||||
|
action(event);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get Mouse position
|
||||||
|
function getMousePosition(e) {
|
||||||
|
mouseX = e.offsetX;
|
||||||
|
mouseY = e.offsetY;
|
||||||
|
}
|
||||||
|
|
||||||
|
targetElement.addEventListener("mousemove", getMousePosition);
|
||||||
|
|
||||||
|
// Handle events only inside the targetElement
|
||||||
|
let isKeyDownHandlerAttached = false;
|
||||||
|
|
||||||
|
function handleMouseMove() {
|
||||||
|
if (!isKeyDownHandlerAttached) {
|
||||||
|
document.addEventListener("keydown", handleKeyDown);
|
||||||
|
isKeyDownHandlerAttached = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function handleMouseLeave() {
|
||||||
|
if (isKeyDownHandlerAttached) {
|
||||||
|
document.removeEventListener("keydown", handleKeyDown);
|
||||||
|
isKeyDownHandlerAttached = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add mouse event handlers
|
||||||
|
targetElement.addEventListener("mousemove", handleMouseMove);
|
||||||
|
targetElement.addEventListener("mouseleave", handleMouseLeave);
|
||||||
|
|
||||||
|
// Reset zoom when click on another tab
|
||||||
|
elements.img2imgTabs.addEventListener("click", resetZoom);
|
||||||
|
elements.img2imgTabs.addEventListener("click", () => {
|
||||||
|
// targetElement.style.width = "";
|
||||||
|
if (parseInt(targetElement.style.width) > 865) {
|
||||||
|
setTimeout(fitToElement, 0);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
targetElement.addEventListener("wheel", e => {
|
||||||
|
// change zoom level
|
||||||
|
const operation = e.deltaY > 0 ? "-" : "+";
|
||||||
|
changeZoomLevel(operation, e);
|
||||||
|
|
||||||
|
// Handle brush size adjustment with ctrl key pressed
|
||||||
|
if (e.ctrlKey || e.metaKey) {
|
||||||
|
e.preventDefault();
|
||||||
|
|
||||||
|
// Increase or decrease brush size based on scroll direction
|
||||||
|
adjustBrushSize(elemId, e.deltaY);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handle the move event for pan functionality. Updates the panX and panY variables and applies the new transform to the target element.
|
||||||
|
* @param {MouseEvent} e - The mouse event.
|
||||||
|
*/
|
||||||
|
function handleMoveKeyDown(e) {
|
||||||
|
if (e.code === hotkeysConfig.moveKey) {
|
||||||
|
if (!e.ctrlKey && !e.metaKey) {
|
||||||
|
isMoving = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function handleMoveKeyUp(e) {
|
||||||
|
if (e.code === hotkeysConfig.moveKey) {
|
||||||
|
isMoving = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
document.addEventListener("keydown", handleMoveKeyDown);
|
||||||
|
document.addEventListener("keyup", handleMoveKeyUp);
|
||||||
|
|
||||||
|
// Detect zoom level and update the pan speed.
|
||||||
|
function updatePanPosition(movementX, movementY) {
|
||||||
|
let panSpeed = 1.5;
|
||||||
|
|
||||||
|
if (zoomLevel > 8) {
|
||||||
|
panSpeed = 2.5;
|
||||||
|
}
|
||||||
|
|
||||||
|
panX = panX + movementX * panSpeed;
|
||||||
|
panY = panY + movementY * panSpeed;
|
||||||
|
|
||||||
|
targetElement.style.transform = `translate(${panX}px, ${panY}px) scale(${zoomLevel})`;
|
||||||
|
toggleOverlap("on");
|
||||||
|
}
|
||||||
|
|
||||||
|
function handleMoveByKey(e) {
|
||||||
|
if (isMoving) {
|
||||||
|
updatePanPosition(e.movementX, e.movementY);
|
||||||
|
targetElement.style.pointerEvents = "none";
|
||||||
|
} else {
|
||||||
|
targetElement.style.pointerEvents = "auto";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
gradioApp().addEventListener("mousemove", handleMoveByKey);
|
||||||
|
}
|
||||||
|
|
||||||
|
applyZoomAndPan(elements.sketch, elementIDs.sketch);
|
||||||
|
applyZoomAndPan(elements.inpaint, elementIDs.inpaint);
|
||||||
|
applyZoomAndPan(elements.inpaintSketch, elementIDs.inpaintSketch);
|
||||||
|
});
|
@ -1,7 +1,9 @@
|
|||||||
|
let gamepads = [];
|
||||||
|
|
||||||
window.addEventListener('gamepadconnected', (e) => {
|
window.addEventListener('gamepadconnected', (e) => {
|
||||||
const index = e.gamepad.index;
|
const index = e.gamepad.index;
|
||||||
let isWaiting = false;
|
let isWaiting = false;
|
||||||
setInterval(async() => {
|
gamepads[index] = setInterval(async() => {
|
||||||
if (!opts.js_modal_lightbox_gamepad || isWaiting) return;
|
if (!opts.js_modal_lightbox_gamepad || isWaiting) return;
|
||||||
const gamepad = navigator.getGamepads()[index];
|
const gamepad = navigator.getGamepads()[index];
|
||||||
const xValue = gamepad.axes[0];
|
const xValue = gamepad.axes[0];
|
||||||
@ -24,6 +26,10 @@ window.addEventListener('gamepadconnected', (e) => {
|
|||||||
}, 10);
|
}, 10);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
window.addEventListener('gamepaddisconnected', (e) => {
|
||||||
|
clearInterval(gamepads[e.gamepad.index]);
|
||||||
|
});
|
||||||
|
|
||||||
/*
|
/*
|
||||||
Primarily for vr controller type pointer devices.
|
Primarily for vr controller type pointer devices.
|
||||||
I use the wheel event because there's currently no way to do it properly with web xr.
|
I use the wheel event because there's currently no way to do it properly with web xr.
|
||||||
|
@ -16,6 +16,7 @@ from secrets import compare_digest
|
|||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing
|
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing
|
||||||
from modules.api import models
|
from modules.api import models
|
||||||
|
from modules.errors import print_error
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
||||||
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
|
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
|
||||||
@ -23,6 +24,7 @@ from modules.textual_inversion.preprocess import preprocess
|
|||||||
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
||||||
from PIL import PngImagePlugin,Image
|
from PIL import PngImagePlugin,Image
|
||||||
from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights
|
from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights
|
||||||
|
from modules.sd_vae import vae_dict
|
||||||
from modules.sd_models_config import find_checkpoint_config_near_filename
|
from modules.sd_models_config import find_checkpoint_config_near_filename
|
||||||
from modules.realesrgan_model import get_realesrgan_models
|
from modules.realesrgan_model import get_realesrgan_models
|
||||||
from modules import devices
|
from modules import devices
|
||||||
@ -108,7 +110,6 @@ def api_middleware(app: FastAPI):
|
|||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
console = Console()
|
console = Console()
|
||||||
except Exception:
|
except Exception:
|
||||||
import traceback
|
|
||||||
rich_available = False
|
rich_available = False
|
||||||
|
|
||||||
@app.middleware("http")
|
@app.middleware("http")
|
||||||
@ -139,11 +140,12 @@ def api_middleware(app: FastAPI):
|
|||||||
"errors": str(e),
|
"errors": str(e),
|
||||||
}
|
}
|
||||||
if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions
|
if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions
|
||||||
print(f"API error: {request.method}: {request.url} {err}")
|
message = f"API error: {request.method}: {request.url} {err}"
|
||||||
if rich_available:
|
if rich_available:
|
||||||
|
print(message)
|
||||||
console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200]))
|
console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200]))
|
||||||
else:
|
else:
|
||||||
traceback.print_exc()
|
print_error(message, exc_info=True)
|
||||||
return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err))
|
return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err))
|
||||||
|
|
||||||
@app.middleware("http")
|
@app.middleware("http")
|
||||||
@ -189,6 +191,7 @@ class Api:
|
|||||||
self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
|
self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
|
||||||
self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
|
self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
|
||||||
self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
|
self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
|
||||||
|
self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem])
|
||||||
self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
|
self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
|
||||||
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem])
|
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem])
|
||||||
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem])
|
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem])
|
||||||
@ -541,6 +544,9 @@ class Api:
|
|||||||
def get_sd_models(self):
|
def get_sd_models(self):
|
||||||
return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()]
|
return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()]
|
||||||
|
|
||||||
|
def get_sd_vaes(self):
|
||||||
|
return [{"model_name": x, "filename": vae_dict[x]} for x in vae_dict.keys()]
|
||||||
|
|
||||||
def get_hypernetworks(self):
|
def get_hypernetworks(self):
|
||||||
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
|
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
|
||||||
|
|
||||||
|
@ -249,6 +249,10 @@ class SDModelItem(BaseModel):
|
|||||||
filename: str = Field(title="Filename")
|
filename: str = Field(title="Filename")
|
||||||
config: Optional[str] = Field(title="Config file")
|
config: Optional[str] = Field(title="Config file")
|
||||||
|
|
||||||
|
class SDVaeItem(BaseModel):
|
||||||
|
model_name: str = Field(title="Model Name")
|
||||||
|
filename: str = Field(title="Filename")
|
||||||
|
|
||||||
class HypernetworkItem(BaseModel):
|
class HypernetworkItem(BaseModel):
|
||||||
name: str = Field(title="Name")
|
name: str = Field(title="Name")
|
||||||
path: Optional[str] = Field(title="Path")
|
path: Optional[str] = Field(title="Path")
|
||||||
|
@ -1,10 +1,9 @@
|
|||||||
import html
|
import html
|
||||||
import sys
|
|
||||||
import threading
|
import threading
|
||||||
import traceback
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from modules import shared, progress
|
from modules import shared, progress
|
||||||
|
from modules.errors import print_error
|
||||||
|
|
||||||
queue_lock = threading.Lock()
|
queue_lock = threading.Lock()
|
||||||
|
|
||||||
@ -56,16 +55,14 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
|
|||||||
try:
|
try:
|
||||||
res = list(func(*args, **kwargs))
|
res = list(func(*args, **kwargs))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# When printing out our debug argument list, do not print out more than a MB of text
|
# When printing out our debug argument list,
|
||||||
max_debug_str_len = 131072 # (1024*1024)/8
|
# do not print out more than a 100 KB of text
|
||||||
|
max_debug_str_len = 131072
|
||||||
print("Error completing request", file=sys.stderr)
|
message = "Error completing request"
|
||||||
argStr = f"Arguments: {args} {kwargs}"
|
arg_str = f"Arguments: {args} {kwargs}"[:max_debug_str_len]
|
||||||
print(argStr[:max_debug_str_len], file=sys.stderr)
|
if len(arg_str) > max_debug_str_len:
|
||||||
if len(argStr) > max_debug_str_len:
|
arg_str += f" (Argument list truncated at {max_debug_str_len}/{len(arg_str)} characters)"
|
||||||
print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
|
print_error(f"{message}\n{arg_str}", exc_info=True)
|
||||||
|
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
shared.state.job = ""
|
shared.state.job = ""
|
||||||
shared.state.job_count = 0
|
shared.state.job_count = 0
|
||||||
@ -108,4 +105,3 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
|
|||||||
return tuple(res)
|
return tuple(res)
|
||||||
|
|
||||||
return f
|
return f
|
||||||
|
|
||||||
|
@ -11,7 +11,7 @@ parser.add_argument("--skip-python-version-check", action='store_true', help="la
|
|||||||
parser.add_argument("--skip-torch-cuda-test", action='store_true', help="launch.py argument: do not check if CUDA is able to work properly")
|
parser.add_argument("--skip-torch-cuda-test", action='store_true', help="launch.py argument: do not check if CUDA is able to work properly")
|
||||||
parser.add_argument("--reinstall-xformers", action='store_true', help="launch.py argument: install the appropriate version of xformers even if you have some version already installed")
|
parser.add_argument("--reinstall-xformers", action='store_true', help="launch.py argument: install the appropriate version of xformers even if you have some version already installed")
|
||||||
parser.add_argument("--reinstall-torch", action='store_true', help="launch.py argument: install the appropriate version of torch even if you have some version already installed")
|
parser.add_argument("--reinstall-torch", action='store_true', help="launch.py argument: install the appropriate version of torch even if you have some version already installed")
|
||||||
parser.add_argument("--update-check", action='store_true', help="launch.py argument: chck for updates at startup")
|
parser.add_argument("--update-check", action='store_true', help="launch.py argument: check for updates at startup")
|
||||||
parser.add_argument("--test-server", action='store_true', help="launch.py argument: configure server for testing")
|
parser.add_argument("--test-server", action='store_true', help="launch.py argument: configure server for testing")
|
||||||
parser.add_argument("--skip-prepare-environment", action='store_true', help="launch.py argument: skip all environment preparation")
|
parser.add_argument("--skip-prepare-environment", action='store_true', help="launch.py argument: skip all environment preparation")
|
||||||
parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages")
|
parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages")
|
||||||
|
@ -1,6 +1,4 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import torch
|
import torch
|
||||||
@ -8,6 +6,7 @@ import torch
|
|||||||
import modules.face_restoration
|
import modules.face_restoration
|
||||||
import modules.shared
|
import modules.shared
|
||||||
from modules import shared, devices, modelloader
|
from modules import shared, devices, modelloader
|
||||||
|
from modules.errors import print_error
|
||||||
from modules.paths import models_path
|
from modules.paths import models_path
|
||||||
|
|
||||||
# codeformer people made a choice to include modified basicsr library to their project which makes
|
# codeformer people made a choice to include modified basicsr library to their project which makes
|
||||||
@ -105,8 +104,8 @@ def setup_model(dirname):
|
|||||||
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
|
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
|
||||||
del output
|
del output
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
except Exception as error:
|
except Exception:
|
||||||
print(f'\tFailed inference for CodeFormer: {error}', file=sys.stderr)
|
print_error('Failed inference for CodeFormer', exc_info=True)
|
||||||
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
|
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
|
||||||
|
|
||||||
restored_face = restored_face.astype('uint8')
|
restored_face = restored_face.astype('uint8')
|
||||||
@ -135,7 +134,6 @@ def setup_model(dirname):
|
|||||||
shared.face_restorers.append(codeformer)
|
shared.face_restorers.append(codeformer)
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
print("Error setting up CodeFormer:", file=sys.stderr)
|
print_error("Error setting up CodeFormer", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
# sys.path = stored_sys_path
|
# sys.path = stored_sys_path
|
||||||
|
@ -3,8 +3,6 @@ Supports saving and restoring webui and extensions from a known working set of c
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import tqdm
|
import tqdm
|
||||||
@ -14,6 +12,7 @@ from collections import OrderedDict
|
|||||||
import git
|
import git
|
||||||
|
|
||||||
from modules import shared, extensions
|
from modules import shared, extensions
|
||||||
|
from modules.errors import print_error
|
||||||
from modules.paths_internal import script_path, config_states_dir
|
from modules.paths_internal import script_path, config_states_dir
|
||||||
|
|
||||||
|
|
||||||
@ -53,8 +52,7 @@ def get_webui_config():
|
|||||||
if os.path.exists(os.path.join(script_path, ".git")):
|
if os.path.exists(os.path.join(script_path, ".git")):
|
||||||
webui_repo = git.Repo(script_path)
|
webui_repo = git.Repo(script_path)
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error reading webui git info from {script_path}:", file=sys.stderr)
|
print_error(f"Error reading webui git info from {script_path}", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
webui_remote = None
|
webui_remote = None
|
||||||
webui_commit_hash = None
|
webui_commit_hash = None
|
||||||
@ -134,8 +132,7 @@ def restore_webui_config(config):
|
|||||||
if os.path.exists(os.path.join(script_path, ".git")):
|
if os.path.exists(os.path.join(script_path, ".git")):
|
||||||
webui_repo = git.Repo(script_path)
|
webui_repo = git.Repo(script_path)
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error reading webui git info from {script_path}:", file=sys.stderr)
|
print_error(f"Error reading webui git info from {script_path}", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -143,8 +140,7 @@ def restore_webui_config(config):
|
|||||||
webui_repo.git.reset(webui_commit_hash, hard=True)
|
webui_repo.git.reset(webui_commit_hash, hard=True)
|
||||||
print(f"* Restored webui to commit {webui_commit_hash}.")
|
print(f"* Restored webui to commit {webui_commit_hash}.")
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error restoring webui to commit {webui_commit_hash}:", file=sys.stderr)
|
print_error(f"Error restoring webui to commit{webui_commit_hash}")
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
|
|
||||||
def restore_extension_config(config):
|
def restore_extension_config(config):
|
||||||
|
@ -1,7 +1,23 @@
|
|||||||
import sys
|
import sys
|
||||||
|
import textwrap
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
|
|
||||||
|
def print_error(
|
||||||
|
message: str,
|
||||||
|
*,
|
||||||
|
exc_info: bool = False,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Print an error message to stderr, with optional traceback.
|
||||||
|
"""
|
||||||
|
for line in message.splitlines():
|
||||||
|
print("***", line, file=sys.stderr)
|
||||||
|
if exc_info:
|
||||||
|
print(textwrap.indent(traceback.format_exc(), " "), file=sys.stderr)
|
||||||
|
print("---")
|
||||||
|
|
||||||
|
|
||||||
def print_error_explanation(message):
|
def print_error_explanation(message):
|
||||||
lines = message.strip().split("\n")
|
lines = message.strip().split("\n")
|
||||||
max_len = max([len(x) for x in lines])
|
max_len = max([len(x) for x in lines])
|
||||||
|
@ -1,11 +1,9 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import threading
|
import threading
|
||||||
import traceback
|
|
||||||
|
|
||||||
import git
|
|
||||||
|
|
||||||
from modules import shared
|
from modules import shared
|
||||||
|
from modules.errors import print_error
|
||||||
|
from modules.gitpython_hack import Repo
|
||||||
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
|
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
|
||||||
|
|
||||||
extensions = []
|
extensions = []
|
||||||
@ -54,10 +52,9 @@ class Extension:
|
|||||||
repo = None
|
repo = None
|
||||||
try:
|
try:
|
||||||
if os.path.exists(os.path.join(self.path, ".git")):
|
if os.path.exists(os.path.join(self.path, ".git")):
|
||||||
repo = git.Repo(self.path)
|
repo = Repo(self.path)
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error reading github repository info from {self.path}:", file=sys.stderr)
|
print_error(f"Error reading github repository info from {self.path}", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
if repo is None or repo.bare:
|
if repo is None or repo.bare:
|
||||||
self.remote = None
|
self.remote = None
|
||||||
@ -72,8 +69,8 @@ class Extension:
|
|||||||
self.commit_hash = commit.hexsha
|
self.commit_hash = commit.hexsha
|
||||||
self.version = self.commit_hash[:8]
|
self.version = self.commit_hash[:8]
|
||||||
|
|
||||||
except Exception as ex:
|
except Exception:
|
||||||
print(f"Failed reading extension data from Git repository ({self.name}): {ex}", file=sys.stderr)
|
print_error(f"Failed reading extension data from Git repository ({self.name})", exc_info=True)
|
||||||
self.remote = None
|
self.remote = None
|
||||||
|
|
||||||
self.have_info_from_repo = True
|
self.have_info_from_repo = True
|
||||||
@ -94,7 +91,7 @@ class Extension:
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
def check_updates(self):
|
def check_updates(self):
|
||||||
repo = git.Repo(self.path)
|
repo = Repo(self.path)
|
||||||
for fetch in repo.remote().fetch(dry_run=True):
|
for fetch in repo.remote().fetch(dry_run=True):
|
||||||
if fetch.flags != fetch.HEAD_UPTODATE:
|
if fetch.flags != fetch.HEAD_UPTODATE:
|
||||||
self.can_update = True
|
self.can_update = True
|
||||||
@ -116,7 +113,7 @@ class Extension:
|
|||||||
self.status = "latest"
|
self.status = "latest"
|
||||||
|
|
||||||
def fetch_and_reset_hard(self, commit='origin'):
|
def fetch_and_reset_hard(self, commit='origin'):
|
||||||
repo = git.Repo(self.path)
|
repo = Repo(self.path)
|
||||||
# Fix: `error: Your local changes to the following files would be overwritten by merge`,
|
# Fix: `error: Your local changes to the following files would be overwritten by merge`,
|
||||||
# because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
|
# because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
|
||||||
repo.git.fetch(all=True)
|
repo.git.fetch(all=True)
|
||||||
|
@ -1,12 +1,11 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
import facexlib
|
import facexlib
|
||||||
import gfpgan
|
import gfpgan
|
||||||
|
|
||||||
import modules.face_restoration
|
import modules.face_restoration
|
||||||
from modules import paths, shared, devices, modelloader
|
from modules import paths, shared, devices, modelloader
|
||||||
|
from modules.errors import print_error
|
||||||
|
|
||||||
model_dir = "GFPGAN"
|
model_dir = "GFPGAN"
|
||||||
user_path = None
|
user_path = None
|
||||||
@ -112,5 +111,4 @@ def setup_model(dirname):
|
|||||||
|
|
||||||
shared.face_restorers.append(FaceRestorerGFPGAN())
|
shared.face_restorers.append(FaceRestorerGFPGAN())
|
||||||
except Exception:
|
except Exception:
|
||||||
print("Error setting up GFPGAN:", file=sys.stderr)
|
print_error("Error setting up GFPGAN", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
42
modules/gitpython_hack.py
Normal file
42
modules/gitpython_hack.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import io
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
import git
|
||||||
|
|
||||||
|
|
||||||
|
class Git(git.Git):
|
||||||
|
"""
|
||||||
|
Git subclassed to never use persistent processes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _get_persistent_cmd(self, attr_name, cmd_name, *args, **kwargs):
|
||||||
|
raise NotImplementedError(f"Refusing to use persistent process: {attr_name} ({cmd_name} {args} {kwargs})")
|
||||||
|
|
||||||
|
def get_object_header(self, ref: str | bytes) -> tuple[str, str, int]:
|
||||||
|
ret = subprocess.check_output(
|
||||||
|
[self.GIT_PYTHON_GIT_EXECUTABLE, "cat-file", "--batch-check"],
|
||||||
|
input=self._prepare_ref(ref),
|
||||||
|
cwd=self._working_dir,
|
||||||
|
timeout=2,
|
||||||
|
)
|
||||||
|
return self._parse_object_header(ret)
|
||||||
|
|
||||||
|
def stream_object_data(self, ref: str) -> tuple[str, str, int, "Git.CatFileContentStream"]:
|
||||||
|
# Not really streaming, per se; this buffers the entire object in memory.
|
||||||
|
# Shouldn't be a problem for our use case, since we're only using this for
|
||||||
|
# object headers (commit objects).
|
||||||
|
ret = subprocess.check_output(
|
||||||
|
[self.GIT_PYTHON_GIT_EXECUTABLE, "cat-file", "--batch"],
|
||||||
|
input=self._prepare_ref(ref),
|
||||||
|
cwd=self._working_dir,
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
bio = io.BytesIO(ret)
|
||||||
|
hexsha, typename, size = self._parse_object_header(bio.readline())
|
||||||
|
return (hexsha, typename, size, self.CatFileContentStream(size, bio))
|
||||||
|
|
||||||
|
|
||||||
|
class Repo(git.Repo):
|
||||||
|
GitCommandWrapperType = Git
|
@ -2,8 +2,6 @@ import datetime
|
|||||||
import glob
|
import glob
|
||||||
import html
|
import html
|
||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
import inspect
|
import inspect
|
||||||
|
|
||||||
import modules.textual_inversion.dataset
|
import modules.textual_inversion.dataset
|
||||||
@ -12,6 +10,7 @@ import tqdm
|
|||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
from ldm.util import default
|
from ldm.util import default
|
||||||
from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint
|
from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint
|
||||||
|
from modules.errors import print_error
|
||||||
from modules.textual_inversion import textual_inversion, logging
|
from modules.textual_inversion import textual_inversion, logging
|
||||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||||
from torch import einsum
|
from torch import einsum
|
||||||
@ -325,17 +324,14 @@ def load_hypernetwork(name):
|
|||||||
if path is None:
|
if path is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
hypernetwork = Hypernetwork()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
hypernetwork = Hypernetwork()
|
||||||
hypernetwork.load(path)
|
hypernetwork.load(path)
|
||||||
|
return hypernetwork
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error loading hypernetwork {path}", file=sys.stderr)
|
print_error(f"Error loading hypernetwork {path}", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return hypernetwork
|
|
||||||
|
|
||||||
|
|
||||||
def load_hypernetworks(names, multipliers=None):
|
def load_hypernetworks(names, multipliers=None):
|
||||||
already_loaded = {}
|
already_loaded = {}
|
||||||
@ -770,7 +766,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
|||||||
</p>
|
</p>
|
||||||
"""
|
"""
|
||||||
except Exception:
|
except Exception:
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print_error("Exception in training hypernetwork", exc_info=True)
|
||||||
finally:
|
finally:
|
||||||
pbar.leave = False
|
pbar.leave = False
|
||||||
pbar.close()
|
pbar.close()
|
||||||
|
@ -1,6 +1,4 @@
|
|||||||
import datetime
|
import datetime
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
import pytz
|
import pytz
|
||||||
import io
|
import io
|
||||||
@ -18,6 +16,7 @@ import json
|
|||||||
import hashlib
|
import hashlib
|
||||||
|
|
||||||
from modules import sd_samplers, shared, script_callbacks, errors
|
from modules import sd_samplers, shared, script_callbacks, errors
|
||||||
|
from modules.errors import print_error
|
||||||
from modules.paths_internal import roboto_ttf_file
|
from modules.paths_internal import roboto_ttf_file
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
|
|
||||||
@ -464,8 +463,7 @@ class FilenameGenerator:
|
|||||||
replacement = fun(self, *pattern_args)
|
replacement = fun(self, *pattern_args)
|
||||||
except Exception:
|
except Exception:
|
||||||
replacement = None
|
replacement = None
|
||||||
print(f"Error adding [{pattern}] to filename", file=sys.stderr)
|
print_error(f"Error adding [{pattern}] to filename", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
if replacement == NOTHING_AND_SKIP_PREVIOUS_TEXT:
|
if replacement == NOTHING_AND_SKIP_PREVIOUS_TEXT:
|
||||||
continue
|
continue
|
||||||
@ -511,9 +509,12 @@ def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_p
|
|||||||
existing_pnginfo['parameters'] = geninfo
|
existing_pnginfo['parameters'] = geninfo
|
||||||
|
|
||||||
if extension.lower() == '.png':
|
if extension.lower() == '.png':
|
||||||
pnginfo_data = PngImagePlugin.PngInfo()
|
if opts.enable_pnginfo:
|
||||||
for k, v in (existing_pnginfo or {}).items():
|
pnginfo_data = PngImagePlugin.PngInfo()
|
||||||
pnginfo_data.add_text(k, str(v))
|
for k, v in (existing_pnginfo or {}).items():
|
||||||
|
pnginfo_data.add_text(k, str(v))
|
||||||
|
else:
|
||||||
|
pnginfo_data = None
|
||||||
|
|
||||||
image.save(filename, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
|
image.save(filename, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
|
||||||
|
|
||||||
@ -697,8 +698,7 @@ def read_info_from_image(image):
|
|||||||
Negative prompt: {json_info["uc"]}
|
Negative prompt: {json_info["uc"]}
|
||||||
Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337"""
|
Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337"""
|
||||||
except Exception:
|
except Exception:
|
||||||
print("Error parsing NovelAI image generation parameters:", file=sys.stderr)
|
print_error("Error parsing NovelAI image generation parameters", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
return geninfo, items
|
return geninfo, items
|
||||||
|
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import re
|
import re
|
||||||
@ -12,6 +11,7 @@ from torchvision import transforms
|
|||||||
from torchvision.transforms.functional import InterpolationMode
|
from torchvision.transforms.functional import InterpolationMode
|
||||||
|
|
||||||
from modules import devices, paths, shared, lowvram, modelloader, errors
|
from modules import devices, paths, shared, lowvram, modelloader, errors
|
||||||
|
from modules.errors import print_error
|
||||||
|
|
||||||
blip_image_eval_size = 384
|
blip_image_eval_size = 384
|
||||||
clip_model_name = 'ViT-L/14'
|
clip_model_name = 'ViT-L/14'
|
||||||
@ -216,8 +216,7 @@ class InterrogateModels:
|
|||||||
res += f", {match}"
|
res += f", {match}"
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
print("Error interrogating", file=sys.stderr)
|
print_error("Error interrogating", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
res += "<error>"
|
res += "<error>"
|
||||||
|
|
||||||
self.unload()
|
self.unload()
|
||||||
|
@ -8,6 +8,7 @@ import json
|
|||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
|
|
||||||
from modules import cmd_args
|
from modules import cmd_args
|
||||||
|
from modules.errors import print_error
|
||||||
from modules.paths_internal import script_path, extensions_dir
|
from modules.paths_internal import script_path, extensions_dir
|
||||||
|
|
||||||
args, _ = cmd_args.parser.parse_known_args()
|
args, _ = cmd_args.parser.parse_known_args()
|
||||||
@ -188,7 +189,7 @@ def run_extension_installer(extension_dir):
|
|||||||
|
|
||||||
print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env))
|
print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e, file=sys.stderr)
|
print_error(str(e))
|
||||||
|
|
||||||
|
|
||||||
def list_extensions(settings_file):
|
def list_extensions(settings_file):
|
||||||
@ -198,8 +199,8 @@ def list_extensions(settings_file):
|
|||||||
if os.path.isfile(settings_file):
|
if os.path.isfile(settings_file):
|
||||||
with open(settings_file, "r", encoding="utf8") as file:
|
with open(settings_file, "r", encoding="utf8") as file:
|
||||||
settings = json.load(file)
|
settings = json.load(file)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
print(e, file=sys.stderr)
|
print_error("Could not load settings", exc_info=True)
|
||||||
|
|
||||||
disabled_extensions = set(settings.get('disabled_extensions', []))
|
disabled_extensions = set(settings.get('disabled_extensions', []))
|
||||||
disable_all_extensions = settings.get('disable_all_extensions', 'none')
|
disable_all_extensions = settings.get('disable_all_extensions', 'none')
|
||||||
@ -229,13 +230,11 @@ def prepare_environment():
|
|||||||
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
|
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
|
||||||
|
|
||||||
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
|
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
|
||||||
taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git")
|
|
||||||
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
|
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
|
||||||
codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
|
codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
|
||||||
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
|
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
|
||||||
|
|
||||||
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
|
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
|
||||||
taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
|
|
||||||
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "c9fe758757e022f05ca5a53fa8fac28889e4f1cf")
|
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "c9fe758757e022f05ca5a53fa8fac28889e4f1cf")
|
||||||
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
|
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
|
||||||
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
|
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
|
||||||
@ -286,7 +285,6 @@ def prepare_environment():
|
|||||||
os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
|
os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
|
||||||
|
|
||||||
git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
|
git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
|
||||||
git_clone(taming_transformers_repo, repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash)
|
|
||||||
git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
|
git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
|
||||||
git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
|
git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
|
||||||
git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
|
git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
|
||||||
|
@ -1,8 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
|
from modules.errors import print_error
|
||||||
|
|
||||||
localizations = {}
|
localizations = {}
|
||||||
|
|
||||||
@ -31,7 +30,6 @@ def localization_js(current_localization_name: str) -> str:
|
|||||||
with open(fn, "r", encoding="utf8") as file:
|
with open(fn, "r", encoding="utf8") as file:
|
||||||
data = json.load(file)
|
data = json.load(file)
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error loading localization from {fn}:", file=sys.stderr)
|
print_error(f"Error loading localization from {fn}", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
return f"window.localization = {json.dumps(data)}"
|
return f"window.localization = {json.dumps(data)}"
|
||||||
|
@ -20,7 +20,6 @@ assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possibl
|
|||||||
|
|
||||||
path_dirs = [
|
path_dirs = [
|
||||||
(sd_path, 'ldm', 'Stable Diffusion', []),
|
(sd_path, 'ldm', 'Stable Diffusion', []),
|
||||||
(os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers', []),
|
|
||||||
(os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
|
(os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
|
||||||
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
|
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
|
||||||
(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
|
(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
@ -23,7 +24,6 @@ import modules.images as images
|
|||||||
import modules.styles
|
import modules.styles
|
||||||
import modules.sd_models as sd_models
|
import modules.sd_models as sd_models
|
||||||
import modules.sd_vae as sd_vae
|
import modules.sd_vae as sd_vae
|
||||||
import logging
|
|
||||||
from ldm.data.util import AddMiDaS
|
from ldm.data.util import AddMiDaS
|
||||||
from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
|
from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
|
||||||
|
|
||||||
@ -321,14 +321,13 @@ class StableDiffusionProcessing:
|
|||||||
have been used before. The second element is where the previously
|
have been used before. The second element is where the previously
|
||||||
computed result is stored.
|
computed result is stored.
|
||||||
"""
|
"""
|
||||||
|
if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info) == cache[0]:
|
||||||
if cache[0] is not None and (required_prompts, steps) == cache[0]:
|
|
||||||
return cache[1]
|
return cache[1]
|
||||||
|
|
||||||
with devices.autocast():
|
with devices.autocast():
|
||||||
cache[1] = function(shared.sd_model, required_prompts, steps)
|
cache[1] = function(shared.sd_model, required_prompts, steps)
|
||||||
|
|
||||||
cache[0] = (required_prompts, steps)
|
cache[0] = (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info)
|
||||||
return cache[1]
|
return cache[1]
|
||||||
|
|
||||||
def setup_conds(self):
|
def setup_conds(self):
|
||||||
|
@ -1,12 +1,11 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from basicsr.utils.download_util import load_file_from_url
|
from basicsr.utils.download_util import load_file_from_url
|
||||||
from realesrgan import RealESRGANer
|
from realesrgan import RealESRGANer
|
||||||
|
|
||||||
|
from modules.errors import print_error
|
||||||
from modules.upscaler import Upscaler, UpscalerData
|
from modules.upscaler import Upscaler, UpscalerData
|
||||||
from modules.shared import cmd_opts, opts
|
from modules.shared import cmd_opts, opts
|
||||||
from modules import modelloader
|
from modules import modelloader
|
||||||
@ -36,8 +35,7 @@ class UpscalerRealESRGAN(Upscaler):
|
|||||||
self.scalers.append(scaler)
|
self.scalers.append(scaler)
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
print("Error importing Real-ESRGAN:", file=sys.stderr)
|
print_error("Error importing Real-ESRGAN", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
self.enable = False
|
self.enable = False
|
||||||
self.scalers = []
|
self.scalers = []
|
||||||
|
|
||||||
@ -76,9 +74,8 @@ class UpscalerRealESRGAN(Upscaler):
|
|||||||
info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_download_path, progress=True)
|
info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_download_path, progress=True)
|
||||||
|
|
||||||
return info
|
return info
|
||||||
except Exception as e:
|
except Exception:
|
||||||
print(f"Error making Real-ESRGAN models list: {e}", file=sys.stderr)
|
print_error("Error making Real-ESRGAN models list", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def load_models(self, _):
|
def load_models(self, _):
|
||||||
@ -135,5 +132,4 @@ def get_realesrgan_models(scaler):
|
|||||||
]
|
]
|
||||||
return models
|
return models
|
||||||
except Exception:
|
except Exception:
|
||||||
print("Error making Real-ESRGAN models list:", file=sys.stderr)
|
print_error("Error making Real-ESRGAN models list", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
@ -2,8 +2,6 @@
|
|||||||
|
|
||||||
import pickle
|
import pickle
|
||||||
import collections
|
import collections
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import numpy
|
import numpy
|
||||||
@ -11,6 +9,8 @@ import _codecs
|
|||||||
import zipfile
|
import zipfile
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
from modules.errors import print_error
|
||||||
|
|
||||||
# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
|
# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
|
||||||
TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
|
TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
|
||||||
|
|
||||||
@ -136,17 +136,20 @@ def load_with_extra(filename, extra_handler=None, *args, **kwargs):
|
|||||||
check_pt(filename, extra_handler)
|
check_pt(filename, extra_handler)
|
||||||
|
|
||||||
except pickle.UnpicklingError:
|
except pickle.UnpicklingError:
|
||||||
print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
|
print_error(
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
f"Error verifying pickled file from {filename}\n"
|
||||||
print("-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr)
|
"-----> !!!! The file is most likely corrupted !!!! <-----\n"
|
||||||
print("You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr)
|
"You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
|
print_error(
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
f"Error verifying pickled file from {filename}\n"
|
||||||
print("\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
|
f"The file may be malicious, so the program is not going to read it.\n"
|
||||||
print("You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr)
|
f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return unsafe_torch_load(filename, *args, **kwargs)
|
return unsafe_torch_load(filename, *args, **kwargs)
|
||||||
@ -190,4 +193,3 @@ with safe.Extra(handler):
|
|||||||
unsafe_torch_load = torch.load
|
unsafe_torch_load = torch.load
|
||||||
torch.load = load
|
torch.load = load
|
||||||
global_extra_handler = None
|
global_extra_handler = None
|
||||||
|
|
||||||
|
@ -1,16 +1,15 @@
|
|||||||
import sys
|
|
||||||
import traceback
|
|
||||||
from collections import namedtuple
|
|
||||||
import inspect
|
import inspect
|
||||||
|
from collections import namedtuple
|
||||||
from typing import Optional, Dict, Any
|
from typing import Optional, Dict, Any
|
||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from gradio import Blocks
|
from gradio import Blocks
|
||||||
|
|
||||||
|
from modules.errors import print_error
|
||||||
|
|
||||||
|
|
||||||
def report_exception(c, job):
|
def report_exception(c, job):
|
||||||
print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
|
print_error(f"Error executing callback {job} for {c.script}", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
|
|
||||||
class ImageSaveParams:
|
class ImageSaveParams:
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
import importlib.util
|
import importlib.util
|
||||||
|
|
||||||
|
from modules.errors import print_error
|
||||||
|
|
||||||
|
|
||||||
def load_module(path):
|
def load_module(path):
|
||||||
module_spec = importlib.util.spec_from_file_location(os.path.basename(path), path)
|
module_spec = importlib.util.spec_from_file_location(os.path.basename(path), path)
|
||||||
@ -27,5 +27,4 @@ def preload_extensions(extensions_dir, parser):
|
|||||||
module.preload(parser)
|
module.preload(parser)
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error running preload() for {preload_script}", file=sys.stderr)
|
print_error(f"Error running preload() for {preload_script}", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
@ -1,12 +1,12 @@
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing
|
from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing
|
||||||
|
from modules.errors import print_error
|
||||||
|
|
||||||
AlwaysVisible = object()
|
AlwaysVisible = object()
|
||||||
|
|
||||||
@ -264,8 +264,7 @@ def load_scripts():
|
|||||||
register_scripts_from_module(script_module)
|
register_scripts_from_module(script_module)
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error loading script: {scriptfile.filename}", file=sys.stderr)
|
print_error(f"Error loading script: {scriptfile.filename}", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
sys.path = syspath
|
sys.path = syspath
|
||||||
@ -280,11 +279,9 @@ def load_scripts():
|
|||||||
|
|
||||||
def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
|
def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
|
||||||
try:
|
try:
|
||||||
res = func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
return res
|
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error calling: {filename}/{funcname}", file=sys.stderr)
|
print_error(f"Error calling: {filename}/{funcname}", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
return default
|
return default
|
||||||
|
|
||||||
@ -450,8 +447,7 @@ class ScriptRunner:
|
|||||||
script_args = p.script_args[script.args_from:script.args_to]
|
script_args = p.script_args[script.args_from:script.args_to]
|
||||||
script.process(p, *script_args)
|
script.process(p, *script_args)
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error running process: {script.filename}", file=sys.stderr)
|
print_error(f"Error running process: {script.filename}", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
def before_process_batch(self, p, **kwargs):
|
def before_process_batch(self, p, **kwargs):
|
||||||
for script in self.alwayson_scripts:
|
for script in self.alwayson_scripts:
|
||||||
@ -459,8 +455,7 @@ class ScriptRunner:
|
|||||||
script_args = p.script_args[script.args_from:script.args_to]
|
script_args = p.script_args[script.args_from:script.args_to]
|
||||||
script.before_process_batch(p, *script_args, **kwargs)
|
script.before_process_batch(p, *script_args, **kwargs)
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error running before_process_batch: {script.filename}", file=sys.stderr)
|
print_error(f"Error running before_process_batch: {script.filename}", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
def process_batch(self, p, **kwargs):
|
def process_batch(self, p, **kwargs):
|
||||||
for script in self.alwayson_scripts:
|
for script in self.alwayson_scripts:
|
||||||
@ -468,8 +463,7 @@ class ScriptRunner:
|
|||||||
script_args = p.script_args[script.args_from:script.args_to]
|
script_args = p.script_args[script.args_from:script.args_to]
|
||||||
script.process_batch(p, *script_args, **kwargs)
|
script.process_batch(p, *script_args, **kwargs)
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error running process_batch: {script.filename}", file=sys.stderr)
|
print_error(f"Error running process_batch: {script.filename}", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
def postprocess(self, p, processed):
|
def postprocess(self, p, processed):
|
||||||
for script in self.alwayson_scripts:
|
for script in self.alwayson_scripts:
|
||||||
@ -477,8 +471,7 @@ class ScriptRunner:
|
|||||||
script_args = p.script_args[script.args_from:script.args_to]
|
script_args = p.script_args[script.args_from:script.args_to]
|
||||||
script.postprocess(p, processed, *script_args)
|
script.postprocess(p, processed, *script_args)
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error running postprocess: {script.filename}", file=sys.stderr)
|
print_error(f"Error running postprocess: {script.filename}", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
def postprocess_batch(self, p, images, **kwargs):
|
def postprocess_batch(self, p, images, **kwargs):
|
||||||
for script in self.alwayson_scripts:
|
for script in self.alwayson_scripts:
|
||||||
@ -486,8 +479,7 @@ class ScriptRunner:
|
|||||||
script_args = p.script_args[script.args_from:script.args_to]
|
script_args = p.script_args[script.args_from:script.args_to]
|
||||||
script.postprocess_batch(p, *script_args, images=images, **kwargs)
|
script.postprocess_batch(p, *script_args, images=images, **kwargs)
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
|
print_error(f"Error running postprocess_batch: {script.filename}", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
def postprocess_image(self, p, pp: PostprocessImageArgs):
|
def postprocess_image(self, p, pp: PostprocessImageArgs):
|
||||||
for script in self.alwayson_scripts:
|
for script in self.alwayson_scripts:
|
||||||
@ -495,24 +487,21 @@ class ScriptRunner:
|
|||||||
script_args = p.script_args[script.args_from:script.args_to]
|
script_args = p.script_args[script.args_from:script.args_to]
|
||||||
script.postprocess_image(p, pp, *script_args)
|
script.postprocess_image(p, pp, *script_args)
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
|
print_error(f"Error running postprocess_image: {script.filename}", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
def before_component(self, component, **kwargs):
|
def before_component(self, component, **kwargs):
|
||||||
for script in self.scripts:
|
for script in self.scripts:
|
||||||
try:
|
try:
|
||||||
script.before_component(component, **kwargs)
|
script.before_component(component, **kwargs)
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error running before_component: {script.filename}", file=sys.stderr)
|
print_error(f"Error running before_component: {script.filename}", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
def after_component(self, component, **kwargs):
|
def after_component(self, component, **kwargs):
|
||||||
for script in self.scripts:
|
for script in self.scripts:
|
||||||
try:
|
try:
|
||||||
script.after_component(component, **kwargs)
|
script.after_component(component, **kwargs)
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error running after_component: {script.filename}", file=sys.stderr)
|
print_error(f"Error running after_component: {script.filename}", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
def reload_sources(self, cache):
|
def reload_sources(self, cache):
|
||||||
for si, script in list(enumerate(self.scripts)):
|
for si, script in list(enumerate(self.scripts)):
|
||||||
|
@ -1,7 +1,5 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import math
|
import math
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
import psutil
|
import psutil
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -11,6 +9,7 @@ from ldm.util import default
|
|||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
|
||||||
from modules import shared, errors, devices, sub_quadratic_attention
|
from modules import shared, errors, devices, sub_quadratic_attention
|
||||||
|
from modules.errors import print_error
|
||||||
from modules.hypernetworks import hypernetwork
|
from modules.hypernetworks import hypernetwork
|
||||||
|
|
||||||
import ldm.modules.attention
|
import ldm.modules.attention
|
||||||
@ -140,8 +139,7 @@ if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
|
|||||||
import xformers.ops
|
import xformers.ops
|
||||||
shared.xformers_available = True
|
shared.xformers_available = True
|
||||||
except Exception:
|
except Exception:
|
||||||
print("Cannot import xformers", file=sys.stderr)
|
print_error("Cannot import xformers", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
|
|
||||||
def get_available_vram():
|
def get_available_vram():
|
||||||
|
@ -416,12 +416,12 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
|||||||
"comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"),
|
"comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"),
|
||||||
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP nrtwork; 1 ignores none, 2 ignores one layer"),
|
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP nrtwork; 1 ignores none, 2 ignores one layer"),
|
||||||
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
|
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
|
||||||
"randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different vidocard vendors"),
|
"randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('optimizations', "Optimizations"), {
|
options_templates.update(options_section(('optimizations', "Optimizations"), {
|
||||||
"cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
|
"cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
|
||||||
"s_min_uncond": OptionInfo(0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
|
"s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
|
||||||
"token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
|
"token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
|
||||||
"token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
|
"token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
|
||||||
"token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
|
"token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
|
||||||
|
@ -1,6 +1,4 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -16,6 +14,7 @@ from torch.utils.tensorboard import SummaryWriter
|
|||||||
|
|
||||||
from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint
|
from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint
|
||||||
import modules.textual_inversion.dataset
|
import modules.textual_inversion.dataset
|
||||||
|
from modules.errors import print_error
|
||||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||||
|
|
||||||
from modules.textual_inversion.image_embedding import embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay
|
from modules.textual_inversion.image_embedding import embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay
|
||||||
@ -120,16 +119,29 @@ class EmbeddingDatabase:
|
|||||||
self.embedding_dirs.clear()
|
self.embedding_dirs.clear()
|
||||||
|
|
||||||
def register_embedding(self, embedding, model):
|
def register_embedding(self, embedding, model):
|
||||||
self.word_embeddings[embedding.name] = embedding
|
return self.register_embedding_by_name(embedding, model, embedding.name)
|
||||||
|
|
||||||
ids = model.cond_stage_model.tokenize([embedding.name])[0]
|
|
||||||
|
|
||||||
|
def register_embedding_by_name(self, embedding, model, name):
|
||||||
|
ids = model.cond_stage_model.tokenize([name])[0]
|
||||||
first_id = ids[0]
|
first_id = ids[0]
|
||||||
if first_id not in self.ids_lookup:
|
if first_id not in self.ids_lookup:
|
||||||
self.ids_lookup[first_id] = []
|
self.ids_lookup[first_id] = []
|
||||||
|
if name in self.word_embeddings:
|
||||||
self.ids_lookup[first_id] = sorted(self.ids_lookup[first_id] + [(ids, embedding)], key=lambda x: len(x[0]), reverse=True)
|
# remove old one from the lookup list
|
||||||
|
lookup = [x for x in self.ids_lookup[first_id] if x[1].name!=name]
|
||||||
|
else:
|
||||||
|
lookup = self.ids_lookup[first_id]
|
||||||
|
if embedding is not None:
|
||||||
|
lookup += [(ids, embedding)]
|
||||||
|
self.ids_lookup[first_id] = sorted(lookup, key=lambda x: len(x[0]), reverse=True)
|
||||||
|
if embedding is None:
|
||||||
|
# unregister embedding with specified name
|
||||||
|
if name in self.word_embeddings:
|
||||||
|
del self.word_embeddings[name]
|
||||||
|
if len(self.ids_lookup[first_id])==0:
|
||||||
|
del self.ids_lookup[first_id]
|
||||||
|
return None
|
||||||
|
self.word_embeddings[name] = embedding
|
||||||
return embedding
|
return embedding
|
||||||
|
|
||||||
def get_expected_shape(self):
|
def get_expected_shape(self):
|
||||||
@ -207,8 +219,7 @@ class EmbeddingDatabase:
|
|||||||
|
|
||||||
self.load_from_file(fullfn, fn)
|
self.load_from_file(fullfn, fn)
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error loading embedding {fn}:", file=sys.stderr)
|
print_error(f"Error loading embedding {fn}", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
def load_textual_inversion_embeddings(self, force_reload=False):
|
def load_textual_inversion_embeddings(self, force_reload=False):
|
||||||
@ -632,8 +643,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
|||||||
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
|
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
|
||||||
save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
|
save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
|
||||||
except Exception:
|
except Exception:
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print_error("Error training embedding", exc_info=True)
|
||||||
pass
|
|
||||||
finally:
|
finally:
|
||||||
pbar.leave = False
|
pbar.leave = False
|
||||||
pbar.close()
|
pbar.close()
|
||||||
|
@ -2,7 +2,6 @@ import json
|
|||||||
import mimetypes
|
import mimetypes
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
@ -14,6 +13,7 @@ from PIL import Image, PngImagePlugin # noqa: F401
|
|||||||
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
|
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
|
||||||
|
|
||||||
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave
|
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave
|
||||||
|
from modules.errors import print_error
|
||||||
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
|
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
|
||||||
from modules.paths import script_path, data_path
|
from modules.paths import script_path, data_path
|
||||||
|
|
||||||
@ -231,9 +231,8 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
|
|||||||
res = all_seeds[index if 0 <= index < len(all_seeds) else 0]
|
res = all_seeds[index if 0 <= index < len(all_seeds) else 0]
|
||||||
|
|
||||||
except json.decoder.JSONDecodeError:
|
except json.decoder.JSONDecodeError:
|
||||||
if gen_info_string != '':
|
if gen_info_string:
|
||||||
print("Error parsing JSON generation info:", file=sys.stderr)
|
print_error(f"Error parsing JSON generation info: {gen_info_string}")
|
||||||
print(gen_info_string, file=sys.stderr)
|
|
||||||
|
|
||||||
return [res, gr_show(False)]
|
return [res, gr_show(False)]
|
||||||
|
|
||||||
@ -505,10 +504,10 @@ def create_ui():
|
|||||||
with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
|
with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
|
||||||
with gr.Column(scale=80):
|
with gr.Column(scale=80):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
hr_prompt = gr.Textbox(label="Prompt", elem_id="hires_prompt", show_label=False, lines=3, placeholder="Prompt for hires fix pass.\nLeave empty to use the same prompt as in first pass.", elem_classes=["prompt"])
|
hr_prompt = gr.Textbox(label="Hires prompt", elem_id="hires_prompt", show_label=False, lines=3, placeholder="Prompt for hires fix pass.\nLeave empty to use the same prompt as in first pass.", elem_classes=["prompt"])
|
||||||
with gr.Column(scale=80):
|
with gr.Column(scale=80):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
hr_negative_prompt = gr.Textbox(label="Negative prompt", elem_id="hires_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt for hires fix pass.\nLeave empty to use the same negative prompt as in first pass.", elem_classes=["prompt"])
|
hr_negative_prompt = gr.Textbox(label="Hires negative prompt", elem_id="hires_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt for hires fix pass.\nLeave empty to use the same negative prompt as in first pass.", elem_classes=["prompt"])
|
||||||
|
|
||||||
elif category == "batch":
|
elif category == "batch":
|
||||||
if not opts.dimensions_and_batch_together:
|
if not opts.dimensions_and_batch_together:
|
||||||
@ -1753,8 +1752,7 @@ def create_ui():
|
|||||||
try:
|
try:
|
||||||
results = modules.extras.run_modelmerger(*args)
|
results = modules.extras.run_modelmerger(*args)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("Error loading/saving model file:", file=sys.stderr)
|
print_error("Error loading/saving model file", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
modules.sd_models.list_models() # to remove the potentially missing models from the list
|
modules.sd_models.list_models() # to remove the potentially missing models from the list
|
||||||
return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"]
|
return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"]
|
||||||
return results
|
return results
|
||||||
|
@ -1,10 +1,8 @@
|
|||||||
import json
|
import json
|
||||||
import os.path
|
import os.path
|
||||||
import sys
|
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import traceback
|
|
||||||
|
|
||||||
import git
|
import git
|
||||||
|
|
||||||
@ -14,6 +12,7 @@ import shutil
|
|||||||
import errno
|
import errno
|
||||||
|
|
||||||
from modules import extensions, shared, paths, config_states
|
from modules import extensions, shared, paths, config_states
|
||||||
|
from modules.errors import print_error
|
||||||
from modules.paths_internal import config_states_dir
|
from modules.paths_internal import config_states_dir
|
||||||
from modules.call_queue import wrap_gradio_gpu_call
|
from modules.call_queue import wrap_gradio_gpu_call
|
||||||
|
|
||||||
@ -46,8 +45,7 @@ def apply_and_restart(disable_list, update_list, disable_all):
|
|||||||
try:
|
try:
|
||||||
ext.fetch_and_reset_hard()
|
ext.fetch_and_reset_hard()
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error getting updates for {ext.name}:", file=sys.stderr)
|
print_error(f"Error getting updates for {ext.name}", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
shared.opts.disabled_extensions = disabled
|
shared.opts.disabled_extensions = disabled
|
||||||
shared.opts.disable_all_extensions = disable_all
|
shared.opts.disable_all_extensions = disable_all
|
||||||
@ -113,8 +111,7 @@ def check_updates(id_task, disable_list):
|
|||||||
if 'FETCH_HEAD' not in str(e):
|
if 'FETCH_HEAD' not in str(e):
|
||||||
raise
|
raise
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error checking updates for {ext.name}:", file=sys.stderr)
|
print_error(f"Error checking updates for {ext.name}", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
shared.state.nextjob()
|
shared.state.nextjob()
|
||||||
|
|
||||||
@ -490,8 +487,14 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text="
|
|||||||
|
|
||||||
|
|
||||||
def preload_extensions_git_metadata():
|
def preload_extensions_git_metadata():
|
||||||
|
t0 = time.time()
|
||||||
for extension in extensions.extensions:
|
for extension in extensions.extensions:
|
||||||
extension.read_info_from_repo()
|
extension.read_info_from_repo()
|
||||||
|
print(
|
||||||
|
f"preload_extensions_git_metadata for "
|
||||||
|
f"{len(extensions.extensions)} extensions took "
|
||||||
|
f"{time.time() - t0:.2f}s"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_ui():
|
def create_ui():
|
||||||
|
@ -53,8 +53,8 @@ class Upscaler:
|
|||||||
|
|
||||||
def upscale(self, img: PIL.Image, scale, selected_model: str = None):
|
def upscale(self, img: PIL.Image, scale, selected_model: str = None):
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
dest_w = int(img.width * scale)
|
dest_w = round((img.width * scale - 4) / 8) * 8
|
||||||
dest_h = int(img.height * scale)
|
dest_h = round((img.height * scale - 4) / 8) * 8
|
||||||
|
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
shape = (img.width, img.height)
|
shape = (img.width, img.height)
|
||||||
|
@ -1,13 +1,12 @@
|
|||||||
import copy
|
import copy
|
||||||
import random
|
import random
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
import shlex
|
import shlex
|
||||||
|
|
||||||
import modules.scripts as scripts
|
import modules.scripts as scripts
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from modules import sd_samplers
|
from modules import sd_samplers
|
||||||
|
from modules.errors import print_error
|
||||||
from modules.processing import Processed, process_images
|
from modules.processing import Processed, process_images
|
||||||
from modules.shared import state
|
from modules.shared import state
|
||||||
|
|
||||||
@ -136,8 +135,7 @@ class Script(scripts.Script):
|
|||||||
try:
|
try:
|
||||||
args = cmdargs(line)
|
args = cmdargs(line)
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error parsing line {line} as commandline:", file=sys.stderr)
|
print_error(f"Error parsing line {line} as commandline", exc_info=True)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
args = {"prompt": line}
|
args = {"prompt": line}
|
||||||
else:
|
else:
|
||||||
args = {"prompt": line}
|
args = {"prompt": line}
|
||||||
|
@ -36,7 +36,6 @@
|
|||||||
|
|
||||||
# Fixed git commits
|
# Fixed git commits
|
||||||
#export STABLE_DIFFUSION_COMMIT_HASH=""
|
#export STABLE_DIFFUSION_COMMIT_HASH=""
|
||||||
#export TAMING_TRANSFORMERS_COMMIT_HASH=""
|
|
||||||
#export CODEFORMER_COMMIT_HASH=""
|
#export CODEFORMER_COMMIT_HASH=""
|
||||||
#export BLIP_COMMIT_HASH=""
|
#export BLIP_COMMIT_HASH=""
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user