Merge branch 'AUTOMATIC1111:dev' into dev

This commit is contained in:
Beinsezii 2023-06-27 15:29:47 -07:00 committed by GitHub
commit 9d8af4bd6a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
38 changed files with 435 additions and 302 deletions

View File

@ -18,7 +18,7 @@ jobs:
# not to have GHA download an (at the time of writing) 4 GB cache # not to have GHA download an (at the time of writing) 4 GB cache
# of PyTorch and other dependencies. # of PyTorch and other dependencies.
- name: Install Ruff - name: Install Ruff
run: pip install ruff==0.0.265 run: pip install ruff==0.0.272
- name: Run Ruff - name: Run Ruff
run: ruff . run: ruff .
lint-js: lint-js:

View File

@ -50,7 +50,7 @@ jobs:
python -m pytest -vv --junitxml=test/results.xml --cov . --cov-report=xml --verify-base-url test python -m pytest -vv --junitxml=test/results.xml --cov . --cov-report=xml --verify-base-url test
- name: Kill test server - name: Kill test server
if: always() if: always()
run: curl -vv -XPOST http://127.0.0.1:7860/_stop && sleep 10 run: curl -vv -XPOST http://127.0.0.1:7860/sdapi/v1/server-stop && sleep 10
- name: Show coverage - name: Show coverage
run: | run: |
python -m coverage combine .coverage* python -m coverage combine .coverage*

View File

@ -1,7 +1,6 @@
import os import os
from basicsr.utils.download_util import load_file_from_url from modules.modelloader import load_file_from_url
from modules.upscaler import Upscaler, UpscalerData from modules.upscaler import Upscaler, UpscalerData
from ldsr_model_arch import LDSR from ldsr_model_arch import LDSR
from modules import shared, script_callbacks, errors from modules import shared, script_callbacks, errors
@ -43,20 +42,17 @@ class UpscalerLDSR(Upscaler):
if local_safetensors_path is not None and os.path.exists(local_safetensors_path): if local_safetensors_path is not None and os.path.exists(local_safetensors_path):
model = local_safetensors_path model = local_safetensors_path
else: else:
model = local_ckpt_path if local_ckpt_path is not None else load_file_from_url(url=self.model_url, model_dir=self.model_download_path, file_name="model.ckpt", progress=True) model = local_ckpt_path or load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name="model.ckpt")
yaml = local_yaml_path if local_yaml_path is not None else load_file_from_url(url=self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml", progress=True) yaml = local_yaml_path or load_file_from_url(self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml")
try: return LDSR(model, yaml)
return LDSR(model, yaml)
except Exception:
errors.report("Error importing LDSR", exc_info=True)
return None
def do_upscale(self, img, path): def do_upscale(self, img, path):
ldsr = self.load_model(path) try:
if ldsr is None: ldsr = self.load_model(path)
print("NO LDSR!") except Exception:
errors.report(f"Failed loading LDSR model {path}", exc_info=True)
return img return img
ddim_steps = shared.opts.ldsr_steps ddim_steps = shared.opts.ldsr_steps
return ldsr.super_resolution(img, ddim_steps, self.scale) return ldsr.super_resolution(img, ddim_steps, self.scale)

View File

@ -1,4 +1,3 @@
import os.path
import sys import sys
import PIL.Image import PIL.Image
@ -6,12 +5,11 @@ import numpy as np
import torch import torch
from tqdm import tqdm from tqdm import tqdm
from basicsr.utils.download_util import load_file_from_url
import modules.upscaler import modules.upscaler
from modules import devices, modelloader, script_callbacks, errors from modules import devices, modelloader, script_callbacks, errors
from scunet_model_arch import SCUNet as net from scunet_model_arch import SCUNet
from modules.modelloader import load_file_from_url
from modules.shared import opts from modules.shared import opts
@ -28,7 +26,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
scalers = [] scalers = []
add_model2 = True add_model2 = True
for file in model_paths: for file in model_paths:
if "http" in file: if file.startswith("http"):
name = self.model_name name = self.model_name
else: else:
name = modelloader.friendly_name(file) name = modelloader.friendly_name(file)
@ -89,9 +87,10 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
torch.cuda.empty_cache() torch.cuda.empty_cache()
model = self.load_model(selected_file) try:
if model is None: model = self.load_model(selected_file)
print(f"ScuNET: Unable to load model from {selected_file}", file=sys.stderr) except Exception as e:
print(f"ScuNET: Unable to load model from {selected_file}: {e}", file=sys.stderr)
return img return img
device = devices.get_device_for('scunet') device = devices.get_device_for('scunet')
@ -119,15 +118,12 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
def load_model(self, path: str): def load_model(self, path: str):
device = devices.get_device_for('scunet') device = devices.get_device_for('scunet')
if "http" in path: if path.startswith("http"):
filename = load_file_from_url(url=self.model_url, model_dir=self.model_download_path, file_name="%s.pth" % self.name, progress=True) # TODO: this doesn't use `path` at all?
filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
else: else:
filename = path filename = path
if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None: model = SCUNet(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr)
return None
model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
model.load_state_dict(torch.load(filename), strict=True) model.load_state_dict(torch.load(filename), strict=True)
model.eval() model.eval()
for _, v in model.named_parameters(): for _, v in model.named_parameters():

View File

@ -1,17 +1,17 @@
import os import sys
import numpy as np import numpy as np
import torch import torch
from PIL import Image from PIL import Image
from basicsr.utils.download_util import load_file_from_url
from tqdm import tqdm from tqdm import tqdm
from modules import modelloader, devices, script_callbacks, shared from modules import modelloader, devices, script_callbacks, shared
from modules.shared import opts, state from modules.shared import opts, state
from swinir_model_arch import SwinIR as net from swinir_model_arch import SwinIR
from swinir_model_arch_v2 import Swin2SR as net2 from swinir_model_arch_v2 import Swin2SR
from modules.upscaler import Upscaler, UpscalerData from modules.upscaler import Upscaler, UpscalerData
SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"
device_swinir = devices.get_device_for('swinir') device_swinir = devices.get_device_for('swinir')
@ -19,16 +19,14 @@ device_swinir = devices.get_device_for('swinir')
class UpscalerSwinIR(Upscaler): class UpscalerSwinIR(Upscaler):
def __init__(self, dirname): def __init__(self, dirname):
self.name = "SwinIR" self.name = "SwinIR"
self.model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0" \ self.model_url = SWINIR_MODEL_URL
"/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \
"-L_x4_GAN.pth "
self.model_name = "SwinIR 4x" self.model_name = "SwinIR 4x"
self.user_path = dirname self.user_path = dirname
super().__init__() super().__init__()
scalers = [] scalers = []
model_files = self.find_models(ext_filter=[".pt", ".pth"]) model_files = self.find_models(ext_filter=[".pt", ".pth"])
for model in model_files: for model in model_files:
if "http" in model: if model.startswith("http"):
name = self.model_name name = self.model_name
else: else:
name = modelloader.friendly_name(model) name = modelloader.friendly_name(model)
@ -37,8 +35,10 @@ class UpscalerSwinIR(Upscaler):
self.scalers = scalers self.scalers = scalers
def do_upscale(self, img, model_file): def do_upscale(self, img, model_file):
model = self.load_model(model_file) try:
if model is None: model = self.load_model(model_file)
except Exception as e:
print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr)
return img return img
model = model.to(device_swinir, dtype=devices.dtype) model = model.to(device_swinir, dtype=devices.dtype)
img = upscale(img, model) img = upscale(img, model)
@ -49,30 +49,31 @@ class UpscalerSwinIR(Upscaler):
return img return img
def load_model(self, path, scale=4): def load_model(self, path, scale=4):
if "http" in path: if path.startswith("http"):
dl_name = "%s%s" % (self.model_name.replace(" ", "_"), ".pth") filename = modelloader.load_file_from_url(
filename = load_file_from_url(url=path, model_dir=self.model_download_path, file_name=dl_name, progress=True) url=path,
model_dir=self.model_download_path,
file_name=f"{self.model_name.replace(' ', '_')}.pth",
)
else: else:
filename = path filename = path
if filename is None or not os.path.exists(filename):
return None
if filename.endswith(".v2.pth"): if filename.endswith(".v2.pth"):
model = net2( model = Swin2SR(
upscale=scale, upscale=scale,
in_chans=3, in_chans=3,
img_size=64, img_size=64,
window_size=8, window_size=8,
img_range=1.0, img_range=1.0,
depths=[6, 6, 6, 6, 6, 6], depths=[6, 6, 6, 6, 6, 6],
embed_dim=180, embed_dim=180,
num_heads=[6, 6, 6, 6, 6, 6], num_heads=[6, 6, 6, 6, 6, 6],
mlp_ratio=2, mlp_ratio=2,
upsampler="nearest+conv", upsampler="nearest+conv",
resi_connection="1conv", resi_connection="1conv",
) )
params = None params = None
else: else:
model = net( model = SwinIR(
upscale=scale, upscale=scale,
in_chans=3, in_chans=3,
img_size=64, img_size=64,

View File

@ -4,12 +4,12 @@ onUiLoaded(async() => {
inpaint: "#img2maskimg", inpaint: "#img2maskimg",
inpaintSketch: "#inpaint_sketch", inpaintSketch: "#inpaint_sketch",
rangeGroup: "#img2img_column_size", rangeGroup: "#img2img_column_size",
sketch: "#img2img_sketch", sketch: "#img2img_sketch"
}; };
const tabNameToElementId = { const tabNameToElementId = {
"Inpaint sketch": elementIDs.inpaintSketch, "Inpaint sketch": elementIDs.inpaintSketch,
"Inpaint": elementIDs.inpaint, "Inpaint": elementIDs.inpaint,
"Sketch": elementIDs.sketch, "Sketch": elementIDs.sketch
}; };
// Helper functions // Helper functions
@ -42,43 +42,110 @@ onUiLoaded(async() => {
} }
} }
// Check is hotkey valid // Function for defining the "Ctrl", "Shift" and "Alt" keys
function isSingleLetter(value) { function isModifierKey(event, key) {
switch (key) {
case "Ctrl":
return event.ctrlKey;
case "Shift":
return event.shiftKey;
case "Alt":
return event.altKey;
default:
return false;
}
}
// Check if hotkey is valid
function isValidHotkey(value) {
const specialKeys = ["Ctrl", "Alt", "Shift", "Disable"];
return ( return (
typeof value === "string" && value.length === 1 && /[a-z]/i.test(value) (typeof value === "string" &&
value.length === 1 &&
/[a-z]/i.test(value)) ||
specialKeys.includes(value)
); );
} }
// Create hotkeyConfig from opts // Normalize hotkey
function createHotkeyConfig(defaultHotkeysConfig, hotkeysConfigOpts) { function normalizeHotkey(hotkey) {
const result = {}; return hotkey.length === 1 ? "Key" + hotkey.toUpperCase() : hotkey;
const usedKeys = new Set(); }
// Format hotkey for display
function formatHotkeyForDisplay(hotkey) {
return hotkey.startsWith("Key") ? hotkey.slice(3) : hotkey;
}
// Create hotkey configuration with the provided options
function createHotkeyConfig(defaultHotkeysConfig, hotkeysConfigOpts) {
const result = {}; // Resulting hotkey configuration
const usedKeys = new Set(); // Set of used hotkeys
// Iterate through defaultHotkeysConfig keys
for (const key in defaultHotkeysConfig) { for (const key in defaultHotkeysConfig) {
if (typeof hotkeysConfigOpts[key] === "boolean") { const userValue = hotkeysConfigOpts[key]; // User-provided hotkey value
result[key] = hotkeysConfigOpts[key]; const defaultValue = defaultHotkeysConfig[key]; // Default hotkey value
continue;
} // Apply appropriate value for undefined, boolean, or object userValue
if ( if (
hotkeysConfigOpts[key] && userValue === undefined ||
isSingleLetter(hotkeysConfigOpts[key]) && typeof userValue === "boolean" ||
!usedKeys.has(hotkeysConfigOpts[key].toUpperCase()) typeof userValue === "object" ||
userValue === "disable"
) { ) {
// If the property passed the test and has not yet been used, add 'Key' before it and save it result[key] =
result[key] = "Key" + hotkeysConfigOpts[key].toUpperCase(); userValue === undefined ? defaultValue : userValue;
usedKeys.add(hotkeysConfigOpts[key].toUpperCase()); } else if (isValidHotkey(userValue)) {
const normalizedUserValue = normalizeHotkey(userValue);
// Check for conflicting hotkeys
if (!usedKeys.has(normalizedUserValue)) {
usedKeys.add(normalizedUserValue);
result[key] = normalizedUserValue;
} else {
console.error(
`Hotkey: ${formatHotkeyForDisplay(
userValue
)} for ${key} is repeated and conflicts with another hotkey. The default hotkey is used: ${formatHotkeyForDisplay(
defaultValue
)}`
);
result[key] = defaultValue;
}
} else { } else {
// If the property does not pass the test or has already been used, we keep the default value
console.error( console.error(
`Hotkey: ${hotkeysConfigOpts[key]} for ${key} is repeated and conflicts with another hotkey or is not 1 letter. The default hotkey is used: ${defaultHotkeysConfig[key][3]}` `Hotkey: ${formatHotkeyForDisplay(
userValue
)} for ${key} is not valid. The default hotkey is used: ${formatHotkeyForDisplay(
defaultValue
)}`
); );
result[key] = defaultHotkeysConfig[key]; result[key] = defaultValue;
} }
} }
return result; return result;
} }
// Disables functions in the config object based on the provided list of function names
function disableFunctions(config, disabledFunctions) {
// Bind the hasOwnProperty method to the functionMap object to avoid errors
const hasOwnProperty =
Object.prototype.hasOwnProperty.bind(functionMap);
// Loop through the disabledFunctions array and disable the corresponding functions in the config object
disabledFunctions.forEach(funcName => {
if (hasOwnProperty(funcName)) {
const key = functionMap[funcName];
config[key] = "disable";
}
});
// Return the updated config object
return config;
}
/** /**
* The restoreImgRedMask function displays a red mask around an image to indicate the aspect ratio. * The restoreImgRedMask function displays a red mask around an image to indicate the aspect ratio.
* If the image display property is set to 'none', the mask breaks. To fix this, the function * If the image display property is set to 'none', the mask breaks. To fix this, the function
@ -100,7 +167,9 @@ onUiLoaded(async() => {
imageARPreview.style.transform = ""; imageARPreview.style.transform = "";
if (parseFloat(mainTab.style.width) > 865) { if (parseFloat(mainTab.style.width) > 865) {
const transformString = mainTab.style.transform; const transformString = mainTab.style.transform;
const scaleMatch = transformString.match(/scale\(([-+]?[0-9]*\.?[0-9]+)\)/); const scaleMatch = transformString.match(
/scale\(([-+]?[0-9]*\.?[0-9]+)\)/
);
let zoom = 1; // default zoom let zoom = 1; // default zoom
if (scaleMatch && scaleMatch[1]) { if (scaleMatch && scaleMatch[1]) {
@ -124,31 +193,52 @@ onUiLoaded(async() => {
// Default config // Default config
const defaultHotkeysConfig = { const defaultHotkeysConfig = {
canvas_hotkey_zoom: "Alt",
canvas_hotkey_adjust: "Ctrl",
canvas_hotkey_reset: "KeyR", canvas_hotkey_reset: "KeyR",
canvas_hotkey_fullscreen: "KeyS", canvas_hotkey_fullscreen: "KeyS",
canvas_hotkey_move: "KeyF", canvas_hotkey_move: "KeyF",
canvas_hotkey_overlap: "KeyO", canvas_hotkey_overlap: "KeyO",
canvas_show_tooltip: true, canvas_disabled_functions: [],
canvas_swap_controls: false canvas_show_tooltip: true
}; };
// swap the actions for ctr + wheel and shift + wheel
const hotkeysConfig = createHotkeyConfig( const functionMap = {
"Zoom": "canvas_hotkey_zoom",
"Adjust brush size": "canvas_hotkey_adjust",
"Moving canvas": "canvas_hotkey_move",
"Fullscreen": "canvas_hotkey_fullscreen",
"Reset Zoom": "canvas_hotkey_reset",
"Overlap": "canvas_hotkey_overlap"
};
// Loading the configuration from opts
const preHotkeysConfig = createHotkeyConfig(
defaultHotkeysConfig, defaultHotkeysConfig,
hotkeysConfigOpts hotkeysConfigOpts
); );
// Disable functions that are not needed by the user
const hotkeysConfig = disableFunctions(
preHotkeysConfig,
preHotkeysConfig.canvas_disabled_functions
);
let isMoving = false; let isMoving = false;
let mouseX, mouseY; let mouseX, mouseY;
let activeElement; let activeElement;
const elements = Object.fromEntries(Object.keys(elementIDs).map((id) => [ const elements = Object.fromEntries(
id, Object.keys(elementIDs).map(id => [
gradioApp().querySelector(elementIDs[id]), id,
])); gradioApp().querySelector(elementIDs[id])
])
);
const elemData = {}; const elemData = {};
// Apply functionality to the range inputs. Restore redmask and correct for long images. // Apply functionality to the range inputs. Restore redmask and correct for long images.
const rangeInputs = elements.rangeGroup ? Array.from(elements.rangeGroup.querySelectorAll("input")) : const rangeInputs = elements.rangeGroup ?
Array.from(elements.rangeGroup.querySelectorAll("input")) :
[ [
gradioApp().querySelector("#img2img_width input[type='range']"), gradioApp().querySelector("#img2img_width input[type='range']"),
gradioApp().querySelector("#img2img_height input[type='range']") gradioApp().querySelector("#img2img_height input[type='range']")
@ -180,38 +270,56 @@ onUiLoaded(async() => {
const toolTipElemnt = const toolTipElemnt =
targetElement.querySelector(".image-container"); targetElement.querySelector(".image-container");
const tooltip = document.createElement("div"); const tooltip = document.createElement("div");
tooltip.className = "tooltip"; tooltip.className = "canvas-tooltip";
// Creating an item of information // Creating an item of information
const info = document.createElement("i"); const info = document.createElement("i");
info.className = "tooltip-info"; info.className = "canvas-tooltip-info";
info.textContent = ""; info.textContent = "";
// Create a container for the contents of the tooltip // Create a container for the contents of the tooltip
const tooltipContent = document.createElement("div"); const tooltipContent = document.createElement("div");
tooltipContent.className = "tooltip-content"; tooltipContent.className = "canvas-tooltip-content";
// Add info about hotkeys // Define an array with hotkey information and their actions
const zoomKey = hotkeysConfig.canvas_swap_controls ? "Ctrl" : "Shift"; const hotkeysInfo = [
const adjustKey = hotkeysConfig.canvas_swap_controls ? "Shift" : "Ctrl";
const hotkeys = [
{key: `${zoomKey} + wheel`, action: "Zoom canvas"},
{key: `${adjustKey} + wheel`, action: "Adjust brush size"},
{ {
key: hotkeysConfig.canvas_hotkey_reset.charAt(hotkeysConfig.canvas_hotkey_reset.length - 1), configKey: "canvas_hotkey_zoom",
action: "Reset zoom" action: "Zoom canvas",
keySuffix: " + wheel"
}, },
{ {
key: hotkeysConfig.canvas_hotkey_fullscreen.charAt(hotkeysConfig.canvas_hotkey_fullscreen.length - 1), configKey: "canvas_hotkey_adjust",
action: "Adjust brush size",
keySuffix: " + wheel"
},
{configKey: "canvas_hotkey_reset", action: "Reset zoom"},
{
configKey: "canvas_hotkey_fullscreen",
action: "Fullscreen mode" action: "Fullscreen mode"
}, },
{ {configKey: "canvas_hotkey_move", action: "Move canvas"},
key: hotkeysConfig.canvas_hotkey_move.charAt(hotkeysConfig.canvas_hotkey_move.length - 1), {configKey: "canvas_hotkey_overlap", action: "Overlap"}
action: "Move canvas"
}
]; ];
// Create hotkeys array with disabled property based on the config values
const hotkeys = hotkeysInfo.map(info => {
const configValue = hotkeysConfig[info.configKey];
const key = info.keySuffix ?
`${configValue}${info.keySuffix}` :
configValue.charAt(configValue.length - 1);
return {
key,
action: info.action,
disabled: configValue === "disable"
};
});
for (const hotkey of hotkeys) { for (const hotkey of hotkeys) {
if (hotkey.disabled) {
continue;
}
const p = document.createElement("p"); const p = document.createElement("p");
p.innerHTML = `<b>${hotkey.key}</b> - ${hotkey.action}`; p.innerHTML = `<b>${hotkey.key}</b> - ${hotkey.action}`;
tooltipContent.appendChild(p); tooltipContent.appendChild(p);
@ -346,10 +454,7 @@ onUiLoaded(async() => {
// Change the zoom level based on user interaction // Change the zoom level based on user interaction
function changeZoomLevel(operation, e) { function changeZoomLevel(operation, e) {
if ( if (isModifierKey(e, hotkeysConfig.canvas_hotkey_zoom)) {
(!hotkeysConfig.canvas_swap_controls && e.shiftKey) ||
(hotkeysConfig.canvas_swap_controls && e.ctrlKey)
) {
e.preventDefault(); e.preventDefault();
let zoomPosX, zoomPosY; let zoomPosX, zoomPosY;
@ -514,6 +619,13 @@ onUiLoaded(async() => {
event.preventDefault(); event.preventDefault();
action(event); action(event);
} }
if (
isModifierKey(event, hotkeysConfig.canvas_hotkey_zoom) ||
isModifierKey(event, hotkeysConfig.canvas_hotkey_adjust)
) {
event.preventDefault();
}
} }
// Get Mouse position // Get Mouse position
@ -564,11 +676,7 @@ onUiLoaded(async() => {
changeZoomLevel(operation, e); changeZoomLevel(operation, e);
// Handle brush size adjustment with ctrl key pressed // Handle brush size adjustment with ctrl key pressed
if ( if (isModifierKey(e, hotkeysConfig.canvas_hotkey_adjust)) {
(hotkeysConfig.canvas_swap_controls && e.shiftKey) ||
(!hotkeysConfig.canvas_swap_controls &&
(e.ctrlKey || e.metaKey))
) {
e.preventDefault(); e.preventDefault();
// Increase or decrease brush size based on scroll direction // Increase or decrease brush size based on scroll direction

View File

@ -1,10 +1,13 @@
import gradio as gr
from modules import shared from modules import shared
shared.options_templates.update(shared.options_section(('canvas_hotkey', "Canvas Hotkeys"), { shared.options_templates.update(shared.options_section(('canvas_hotkey', "Canvas Hotkeys"), {
"canvas_hotkey_move": shared.OptionInfo("F", "Moving the canvas"), "canvas_hotkey_zoom": shared.OptionInfo("Alt", "Zoom canvas", gr.Radio, {"choices": ["Shift","Ctrl", "Alt"]}).info("If you choose 'Shift' you cannot scroll horizontally, 'Alt' can cause a little trouble in firefox"),
"canvas_hotkey_adjust": shared.OptionInfo("Ctrl", "Adjust brush size", gr.Radio, {"choices": ["Shift","Ctrl", "Alt"]}).info("If you choose 'Shift' you cannot scroll horizontally, 'Alt' can cause a little trouble in firefox"),
"canvas_hotkey_move": shared.OptionInfo("F", "Moving the canvas").info("To work correctly in firefox, turn off 'Automatically search the page text when typing' in the browser settings"),
"canvas_hotkey_fullscreen": shared.OptionInfo("S", "Fullscreen Mode, maximizes the picture so that it fits into the screen and stretches it to its full width "), "canvas_hotkey_fullscreen": shared.OptionInfo("S", "Fullscreen Mode, maximizes the picture so that it fits into the screen and stretches it to its full width "),
"canvas_hotkey_reset": shared.OptionInfo("R", "Reset zoom and canvas positon"), "canvas_hotkey_reset": shared.OptionInfo("R", "Reset zoom and canvas positon"),
"canvas_hotkey_overlap": shared.OptionInfo("O", "Toggle overlap ( Technical button, neededs for testing )"), "canvas_hotkey_overlap": shared.OptionInfo("O", "Toggle overlap").info("Technical button, neededs for testing"),
"canvas_show_tooltip": shared.OptionInfo(True, "Enable tooltip on the canvas"), "canvas_show_tooltip": shared.OptionInfo(True, "Enable tooltip on the canvas"),
"canvas_swap_controls": shared.OptionInfo(False, "Swap hotkey combinations for Zoom and Adjust brush resize"), "canvas_disabled_functions": shared.OptionInfo(["Overlap"], "Disable function that you don't use", gr.CheckboxGroup, {"choices": ["Zoom","Adjust brush size", "Moving canvas","Fullscreen","Reset Zoom","Overlap"]}),
})) }))

View File

@ -1,4 +1,4 @@
.tooltip-info { .canvas-tooltip-info {
position: absolute; position: absolute;
top: 10px; top: 10px;
left: 10px; left: 10px;
@ -15,7 +15,7 @@
z-index: 100; z-index: 100;
} }
.tooltip-info::after { .canvas-tooltip-info::after {
content: ''; content: '';
display: block; display: block;
width: 2px; width: 2px;
@ -24,7 +24,7 @@
margin-top: 2px; margin-top: 2px;
} }
.tooltip-info::before { .canvas-tooltip-info::before {
content: ''; content: '';
display: block; display: block;
width: 2px; width: 2px;
@ -32,7 +32,7 @@
background-color: white; background-color: white;
} }
.tooltip-content { .canvas-tooltip-content {
display: none; display: none;
background-color: #f9f9f9; background-color: #f9f9f9;
color: #333; color: #333;
@ -50,7 +50,7 @@
z-index: 100; z-index: 100;
} }
.tooltip:hover .tooltip-content { .canvas-tooltip:hover .canvas-tooltip-content {
display: block; display: block;
animation: fadeIn 0.5s; animation: fadeIn 0.5s;
opacity: 1; opacity: 1;

View File

@ -72,3 +72,21 @@ function config_state_confirm_restore(_, config_state_name, config_restore_type)
} }
return [confirmed, config_state_name, config_restore_type]; return [confirmed, config_state_name, config_restore_type];
} }
function toggle_all_extensions(event) {
gradioApp().querySelectorAll('#extensions .extension_toggle').forEach(function(checkbox_el) {
checkbox_el.checked = event.target.checked;
});
}
function toggle_extension() {
let all_extensions_toggled = true;
for (const checkbox_el of gradioApp().querySelectorAll('#extensions .extension_toggle')) {
if (!checkbox_el.checked) {
all_extensions_toggled = false;
break;
}
}
gradioApp().querySelector('#extensions .all_extensions_toggle').checked = all_extensions_toggled;
}

View File

@ -15,7 +15,7 @@ var titles = {
"CFG Scale": "Classifier Free Guidance Scale - how strongly the image should conform to prompt - lower values produce more creative results", "CFG Scale": "Classifier Free Guidance Scale - how strongly the image should conform to prompt - lower values produce more creative results",
"Seed": "A value that determines the output of random number generator - if you create an image with same parameters and seed as another image, you'll get the same result", "Seed": "A value that determines the output of random number generator - if you create an image with same parameters and seed as another image, you'll get the same result",
"\u{1f3b2}\ufe0f": "Set seed to -1, which will cause a new random number to be used every time", "\u{1f3b2}\ufe0f": "Set seed to -1, which will cause a new random number to be used every time",
"\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed", "\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomized",
"\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.", "\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.",
"\u{1f4c2}": "Open images output directory", "\u{1f4c2}": "Open images output directory",
"\u{1f4be}": "Save style", "\u{1f4be}": "Save style",
@ -112,7 +112,7 @@ var titles = {
"Resize height to": "Resizes image to this height. If 0, height is inferred from either of two nearby sliders.", "Resize height to": "Resizes image to this height. If 0, height is inferred from either of two nearby sliders.",
"Multiplier for extra networks": "When adding extra network such as Hypernetwork or Lora to prompt, use this multiplier for it.", "Multiplier for extra networks": "When adding extra network such as Hypernetwork or Lora to prompt, use this multiplier for it.",
"Discard weights with matching name": "Regular expression; if weights's name matches it, the weights is not written to the resulting checkpoint. Use ^model_ema to discard EMA weights.", "Discard weights with matching name": "Regular expression; if weights's name matches it, the weights is not written to the resulting checkpoint. Use ^model_ema to discard EMA weights.",
"Extra networks tab order": "Comma-separated list of tab names; tabs listed here will appear in the extra networks UI first and in order lsited.", "Extra networks tab order": "Comma-separated list of tab names; tabs listed here will appear in the extra networks UI first and in order listed.",
"Negative Guidance minimum sigma": "Skip negative prompt for steps where image is already mostly denoised; the higher this value, the more skips there will be; provides increased performance in exchange for minor quality reduction." "Negative Guidance minimum sigma": "Skip negative prompt for steps where image is already mostly denoised; the higher this value, the more skips there will be; provides increased performance in exchange for minor quality reduction."
}; };

View File

@ -14,7 +14,7 @@ from fastapi.encoders import jsonable_encoder
from secrets import compare_digest from secrets import compare_digest
import modules.shared as shared import modules.shared as shared
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart
from modules.api import models from modules.api import models
from modules.shared import opts from modules.shared import opts
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
@ -22,7 +22,7 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_
from modules.textual_inversion.preprocess import preprocess from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
from PIL import PngImagePlugin,Image from PIL import PngImagePlugin,Image
from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights, checkpoint_alisases
from modules.sd_vae import vae_dict from modules.sd_vae import vae_dict
from modules.sd_models_config import find_checkpoint_config_near_filename from modules.sd_models_config import find_checkpoint_config_near_filename
from modules.realesrgan_model import get_realesrgan_models from modules.realesrgan_model import get_realesrgan_models
@ -32,13 +32,6 @@ import piexif
import piexif.helper import piexif.helper
def upscaler_to_index(name: str):
try:
return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
except Exception as e:
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in shared.sd_upscalers])}") from e
def script_name_to_index(name, scripts): def script_name_to_index(name, scripts):
try: try:
return [script.title().lower() for script in scripts].index(name.lower()) return [script.title().lower() for script in scripts].index(name.lower())
@ -209,6 +202,11 @@ class Api:
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList) self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo]) self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
if shared.cmd_opts.add_stop_route:
self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"])
self.add_api_route("/sdapi/v1/server-restart", self.restart_webui, methods=["POST"])
self.add_api_route("/sdapi/v1/server-stop", self.stop_webui, methods=["POST"])
self.default_script_arg_txt2img = [] self.default_script_arg_txt2img = []
self.default_script_arg_img2img = [] self.default_script_arg_img2img = []
@ -517,6 +515,10 @@ class Api:
return options return options
def set_config(self, req: Dict[str, Any]): def set_config(self, req: Dict[str, Any]):
checkpoint_name = req.get("sd_model_checkpoint", None)
if checkpoint_name is not None and checkpoint_name not in checkpoint_alisases:
raise RuntimeError(f"model {checkpoint_name!r} not found")
for k, v in req.items(): for k, v in req.items():
shared.opts.set(k, v) shared.opts.set(k, v)
@ -715,3 +717,15 @@ class Api:
def launch(self, server_name, port): def launch(self, server_name, port):
self.app.include_router(self.router) self.app.include_router(self.router)
uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=0) uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=0)
def kill_webui(self):
restart.stop_program()
def restart_webui(self):
if restart.is_restartable():
restart.restart_program()
return Response(status_code=501)
def stop_webui(request):
shared.state.server_command = "stop"
return Response("Stopping.")

View File

@ -274,10 +274,6 @@ class PromptStyleItem(BaseModel):
prompt: Optional[str] = Field(title="Prompt") prompt: Optional[str] = Field(title="Prompt")
negative_prompt: Optional[str] = Field(title="Negative Prompt") negative_prompt: Optional[str] = Field(title="Negative Prompt")
class ArtistItem(BaseModel):
name: str = Field(title="Name")
score: float = Field(title="Score")
category: str = Field(title="Category")
class EmbeddingItem(BaseModel): class EmbeddingItem(BaseModel):
step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available") step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")

View File

@ -1,3 +1,4 @@
from functools import wraps
import html import html
import threading import threading
import time import time
@ -18,6 +19,7 @@ def wrap_queued_call(func):
def wrap_gradio_gpu_call(func, extra_outputs=None): def wrap_gradio_gpu_call(func, extra_outputs=None):
@wraps(func)
def f(*args, **kwargs): def f(*args, **kwargs):
# if the first argument is a string that says "task(...)", it is treated as a job id # if the first argument is a string that says "task(...)", it is treated as a job id
@ -45,6 +47,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
def wrap_gradio_call(func, extra_outputs=None, add_stats=False): def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
@wraps(func)
def f(*args, extra_outputs_array=extra_outputs, **kwargs): def f(*args, extra_outputs_array=extra_outputs, **kwargs):
run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
if run_memmon: if run_memmon:

View File

@ -106,4 +106,4 @@ parser.add_argument("--skip-version-check", action='store_true', help="Do not ch
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False) parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False) parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy') parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server') parser.add_argument('--add-stop-route', action='store_true', help='enable server stop/restart/kill via api')

View File

@ -15,14 +15,11 @@ model_dir = "Codeformer"
model_path = os.path.join(models_path, model_dir) model_path = os.path.join(models_path, model_dir)
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth' model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
have_codeformer = False
codeformer = None codeformer = None
def setup_model(dirname): def setup_model(dirname):
global model_path os.makedirs(model_path, exist_ok=True)
if not os.path.exists(model_path):
os.makedirs(model_path)
path = modules.paths.paths.get("CodeFormer", None) path = modules.paths.paths.get("CodeFormer", None)
if path is None: if path is None:
@ -125,9 +122,6 @@ def setup_model(dirname):
return restored_img return restored_img
global have_codeformer
have_codeformer = True
global codeformer global codeformer
codeformer = FaceRestorerCodeFormer(dirname) codeformer = FaceRestorerCodeFormer(dirname)
shared.face_restorers.append(codeformer) shared.face_restorers.append(codeformer)

View File

@ -15,13 +15,6 @@ def has_mps() -> bool:
else: else:
return mac_specific.has_mps return mac_specific.has_mps
def extract_device_id(args, name):
for x in range(len(args)):
if name in args[x]:
return args[x + 1]
return None
def get_cuda_device_string(): def get_cuda_device_string():
from modules import shared from modules import shared

View File

@ -1,15 +1,13 @@
import os import sys
import numpy as np import numpy as np
import torch import torch
from PIL import Image from PIL import Image
from basicsr.utils.download_util import load_file_from_url
import modules.esrgan_model_arch as arch import modules.esrgan_model_arch as arch
from modules import modelloader, images, devices from modules import modelloader, images, devices
from modules.upscaler import Upscaler, UpscalerData
from modules.shared import opts from modules.shared import opts
from modules.upscaler import Upscaler, UpscalerData
def mod2normal(state_dict): def mod2normal(state_dict):
@ -134,7 +132,7 @@ class UpscalerESRGAN(Upscaler):
scaler_data = UpscalerData(self.model_name, self.model_url, self, 4) scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
scalers.append(scaler_data) scalers.append(scaler_data)
for file in model_paths: for file in model_paths:
if "http" in file: if file.startswith("http"):
name = self.model_name name = self.model_name
else: else:
name = modelloader.friendly_name(file) name = modelloader.friendly_name(file)
@ -143,26 +141,25 @@ class UpscalerESRGAN(Upscaler):
self.scalers.append(scaler_data) self.scalers.append(scaler_data)
def do_upscale(self, img, selected_model): def do_upscale(self, img, selected_model):
model = self.load_model(selected_model) try:
if model is None: model = self.load_model(selected_model)
except Exception as e:
print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr)
return img return img
model.to(devices.device_esrgan) model.to(devices.device_esrgan)
img = esrgan_upscale(model, img) img = esrgan_upscale(model, img)
return img return img
def load_model(self, path: str): def load_model(self, path: str):
if "http" in path: if path.startswith("http"):
filename = load_file_from_url( # TODO: this doesn't use `path` at all?
filename = modelloader.load_file_from_url(
url=self.model_url, url=self.model_url,
model_dir=self.model_download_path, model_dir=self.model_download_path,
file_name=f"{self.model_name}.pth", file_name=f"{self.model_name}.pth",
progress=True,
) )
else: else:
filename = path filename = path
if not os.path.exists(filename) or filename is None:
print(f"Unable to load {self.model_path} from {filename}")
return None
state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None) state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)

View File

@ -7,8 +7,7 @@ from modules.paths_internal import extensions_dir, extensions_builtin_dir, scrip
extensions = [] extensions = []
if not os.path.exists(extensions_dir): os.makedirs(extensions_dir, exist_ok=True)
os.makedirs(extensions_dir)
def active(): def active():

View File

@ -174,31 +174,6 @@ def send_image_and_dimensions(x):
return img, w, h return img, w, h
def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
"""Determines the config parameter name to use for the hypernet based on the parameters in the infotext.
Example: an infotext provides "Hypernet: ke-ta" and "Hypernet hash: 1234abcd". For the "Hypernet" config
parameter this means there should be an entry that looks like "ke-ta-10000(1234abcd)" to set it to.
If the infotext has no hash, then a hypernet with the same name will be selected instead.
"""
hypernet_name = hypernet_name.lower()
if hypernet_hash is not None:
# Try to match the hash in the name
for hypernet_key in shared.hypernetworks.keys():
result = re_hypernet_hash.search(hypernet_key)
if result is not None and result[1] == hypernet_hash:
return hypernet_key
else:
# Fall back to a hypernet with the same name
for hypernet_key in shared.hypernetworks.keys():
if hypernet_key.lower().startswith(hypernet_name):
return hypernet_key
return None
def restore_old_hires_fix_params(res): def restore_old_hires_fix_params(res):
"""for infotexts that specify old First pass size parameter, convert it into """for infotexts that specify old First pass size parameter, convert it into
width, height, and hr scale""" width, height, and hr scale"""
@ -332,10 +307,6 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
return res return res
settings_map = {}
infotext_to_setting_name_mapping = [ infotext_to_setting_name_mapping = [
('Clip skip', 'CLIP_stop_at_last_layers', ), ('Clip skip', 'CLIP_stop_at_last_layers', ),
('Conditional mask weight', 'inpainting_mask_weight'), ('Conditional mask weight', 'inpainting_mask_weight'),
@ -357,6 +328,7 @@ infotext_to_setting_name_mapping = [
('Token merging ratio hr', 'token_merging_ratio_hr'), ('Token merging ratio hr', 'token_merging_ratio_hr'),
('RNG', 'randn_source'), ('RNG', 'randn_source'),
('NGMS', 's_min_uncond'), ('NGMS', 's_min_uncond'),
('Pad conds', 'pad_cond_uncond'),
] ]

View File

@ -25,7 +25,7 @@ def gfpgann():
return None return None
models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN") models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
if len(models) == 1 and "http" in models[0]: if len(models) == 1 and models[0].startswith("http"):
model_file = models[0] model_file = models[0]
elif len(models) != 0: elif len(models) != 0:
latest_file = max(models, key=os.path.getctime) latest_file = max(models, key=os.path.getctime)
@ -70,11 +70,8 @@ gfpgan_constructor = None
def setup_model(dirname): def setup_model(dirname):
global model_path
if not os.path.exists(model_path):
os.makedirs(model_path)
try: try:
os.makedirs(model_path, exist_ok=True)
from gfpgan import GFPGANer from gfpgan import GFPGANer
from facexlib import detection, parsing # noqa: F401 from facexlib import detection, parsing # noqa: F401
global user_path global user_path

View File

@ -353,17 +353,6 @@ def load_hypernetworks(names, multipliers=None):
shared.loaded_hypernetworks.append(hypernetwork) shared.loaded_hypernetworks.append(hypernetwork)
def find_closest_hypernetwork_name(search: str):
if not search:
return None
search = search.lower()
applicable = [name for name in shared.hypernetworks if search in name.lower()]
if not applicable:
return None
applicable = sorted(applicable, key=lambda name: len(name))
return applicable[0]
def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None): def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None) hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)
@ -446,18 +435,6 @@ def statistics(data):
return total_information, recent_information return total_information, recent_information
def report_statistics(loss_info:dict):
keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x]))
for key in keys:
try:
print("Loss statistics for file " + key)
info, recent = statistics(list(loss_info[key]))
print(info)
print(recent)
except Exception as e:
print(e)
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None): def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
# Remove illegal characters from name. # Remove illegal characters from name.
name = "".join( x for x in name if (x.isalnum() or x in "._- ")) name = "".join( x for x in name if (x.isalnum() or x in "._- "))
@ -770,7 +747,6 @@ Last saved image: {html.escape(last_saved_image)}<br/>
pbar.leave = False pbar.leave = False
pbar.close() pbar.close()
hypernetwork.eval() hypernetwork.eval()
#report_statistics(loss_dict)
sd_hijack_checkpoint.remove() sd_hijack_checkpoint.remove()

View File

@ -372,8 +372,8 @@ class FilenameGenerator:
'hasprompt': lambda self, *args: self.hasprompt(*args), # accepts formats:[hasprompt<prompt1|default><prompt2>..] 'hasprompt': lambda self, *args: self.hasprompt(*args), # accepts formats:[hasprompt<prompt1|default><prompt2>..]
'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"], 'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT, 'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
'user': lambda self: self.p.user,
'vae_filename': lambda self: self.get_vae_filename(), 'vae_filename': lambda self: self.get_vae_filename(),
} }
default_time_format = '%Y%m%d%H%M%S' default_time_format = '%Y%m%d%H%M%S'

View File

@ -3,6 +3,7 @@ from pathlib import Path
import numpy as np import numpy as np
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
import gradio as gr
from modules import sd_samplers from modules import sd_samplers
from modules.generation_parameters_copypaste import create_override_settings_dict from modules.generation_parameters_copypaste import create_override_settings_dict
@ -97,7 +98,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
processed_image.save(os.path.join(output_dir, filename)) processed_image.save(os.path.join(output_dir, filename))
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args): def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, request: gr.Request, *args):
override_settings = create_override_settings_dict(override_settings_texts) override_settings = create_override_settings_dict(override_settings_texts)
is_batch = mode == 5 is_batch = mode == 5
@ -180,6 +181,8 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
p.scripts = modules.scripts.scripts_img2img p.scripts = modules.scripts.scripts_img2img
p.script_args = args p.script_args = args
p.user = request.username
if shared.cmd_opts.enable_console_prompts: if shared.cmd_opts.enable_console_prompts:
print(f"\nimg2img: {prompt}", file=shared.progress_print_out) print(f"\nimg2img: {prompt}", file=shared.progress_print_out)

View File

@ -147,10 +147,10 @@ def git_clone(url, dir, name, commithash=None):
return return
run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}") run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}") run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
return return
run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}") run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}", live=True)
if commithash is not None: if commithash is not None:
run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}") run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import os import os
import shutil import shutil
import importlib import importlib
@ -8,6 +10,29 @@ from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, Upscale
from modules.paths import script_path, models_path from modules.paths import script_path, models_path
def load_file_from_url(
url: str,
*,
model_dir: str,
progress: bool = True,
file_name: str | None = None,
) -> str:
"""Download a file from `url` into `model_dir`, using the file present if possible.
Returns the path to the downloaded file.
"""
os.makedirs(model_dir, exist_ok=True)
if not file_name:
parts = urlparse(url)
file_name = os.path.basename(parts.path)
cached_file = os.path.abspath(os.path.join(model_dir, file_name))
if not os.path.exists(cached_file):
print(f'Downloading: "{url}" to {cached_file}\n')
from torch.hub import download_url_to_file
download_url_to_file(url, cached_file, progress=progress)
return cached_file
def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list: def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list:
""" """
A one-and done loader to try finding the desired models in specified directories. A one-and done loader to try finding the desired models in specified directories.
@ -46,9 +71,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
if model_url is not None and len(output) == 0: if model_url is not None and len(output) == 0:
if download_name is not None: if download_name is not None:
from basicsr.utils.download_util import load_file_from_url output.append(load_file_from_url(model_url, model_dir=places[0], file_name=download_name))
dl = load_file_from_url(model_url, places[0], True, download_name)
output.append(dl)
else: else:
output.append(model_url) output.append(model_url)
@ -59,7 +82,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
def friendly_name(file: str): def friendly_name(file: str):
if "http" in file: if file.startswith("http"):
file = urlparse(file).path file = urlparse(file).path
file = os.path.basename(file) file = os.path.basename(file)
@ -95,8 +118,7 @@ def cleanup_models():
def move_files(src_path: str, dest_path: str, ext_filter: str = None): def move_files(src_path: str, dest_path: str, ext_filter: str = None):
try: try:
if not os.path.exists(dest_path): os.makedirs(dest_path, exist_ok=True)
os.makedirs(dest_path)
if os.path.exists(src_path): if os.path.exists(src_path):
for file in os.listdir(src_path): for file in os.listdir(src_path):
fullpath = os.path.join(src_path, file) fullpath = os.path.join(src_path, file)

View File

@ -38,17 +38,3 @@ for d, must_exist, what, options in path_dirs:
else: else:
sys.path.append(d) sys.path.append(d)
paths[what] = d paths[what] = d
class Prioritize:
def __init__(self, name):
self.name = name
self.path = None
def __enter__(self):
self.path = sys.path.copy()
sys.path = [paths[self.name]] + sys.path
def __exit__(self, exc_type, exc_val, exc_tb):
sys.path = self.path
self.path = None

View File

@ -184,6 +184,8 @@ class StableDiffusionProcessing:
self.uc = None self.uc = None
self.c = None self.c = None
self.user = None
@property @property
def sd_model(self): def sd_model(self):
return shared.sd_model return shared.sd_model
@ -549,7 +551,7 @@ def program_version():
return res return res
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0): def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False):
index = position_in_batch + iteration * p.batch_size index = position_in_batch + iteration * p.batch_size
clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers) clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
@ -585,13 +587,15 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond, "NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
**p.extra_generation_params, **p.extra_generation_params,
"Version": program_version() if opts.add_version_to_infotext else None, "Version": program_version() if opts.add_version_to_infotext else None,
"User": p.user if opts.add_user_name_to_info else None,
} }
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None]) generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
prompt_text = p.prompt if use_main_prompt else all_prompts[index]
negative_prompt_text = f"\nNegative prompt: {p.all_negative_prompts[index]}" if p.all_negative_prompts[index] else "" negative_prompt_text = f"\nNegative prompt: {p.all_negative_prompts[index]}" if p.all_negative_prompts[index] else ""
return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip() return f"{prompt_text}{negative_prompt_text}\n{generation_params_text}".strip()
def process_images(p: StableDiffusionProcessing) -> Processed: def process_images(p: StableDiffusionProcessing) -> Processed:
@ -663,8 +667,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
else: else:
p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))] p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
def infotext(iteration=0, position_in_batch=0): def infotext(iteration=0, position_in_batch=0, use_main_prompt=False):
return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch) return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch, use_main_prompt)
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings: if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
model_hijack.embedding_db.load_textual_inversion_embeddings() model_hijack.embedding_db.load_textual_inversion_embeddings()
@ -824,7 +828,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
grid = images.image_grid(output_images, p.batch_size) grid = images.image_grid(output_images, p.batch_size)
if opts.return_grid: if opts.return_grid:
text = infotext() text = infotext(use_main_prompt=True)
infotexts.insert(0, text) infotexts.insert(0, text)
if opts.enable_pnginfo: if opts.enable_pnginfo:
grid.info["parameters"] = text grid.info["parameters"] = text
@ -832,7 +836,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
index_of_first_image = 1 index_of_first_image = 1
if opts.grid_save: if opts.grid_save:
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True) images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True)
if not p.disable_extra_networks and p.extra_network_data: if not p.disable_extra_networks and p.extra_network_data:
extra_networks.deactivate(p, p.extra_network_data) extra_networks.deactivate(p, p.extra_network_data)

View File

@ -2,7 +2,6 @@ import os
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer from realesrgan import RealESRGANer
from modules.upscaler import Upscaler, UpscalerData from modules.upscaler import Upscaler, UpscalerData
@ -43,9 +42,10 @@ class UpscalerRealESRGAN(Upscaler):
if not self.enable: if not self.enable:
return img return img
info = self.load_model(path) try:
if not os.path.exists(info.local_data_path): info = self.load_model(path)
print(f"Unable to load RealESRGAN model: {info.name}") except Exception:
errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True)
return img return img
upsampler = RealESRGANer( upsampler = RealESRGANer(
@ -63,20 +63,17 @@ class UpscalerRealESRGAN(Upscaler):
return image return image
def load_model(self, path): def load_model(self, path):
try: for scaler in self.scalers:
info = next(iter([scaler for scaler in self.scalers if scaler.data_path == path]), None) if scaler.data_path == path:
if scaler.local_data_path.startswith("http"):
if info is None: scaler.local_data_path = modelloader.load_file_from_url(
print(f"Unable to find model info: {path}") scaler.data_path,
return None model_dir=self.model_download_path,
)
if info.local_data_path.startswith("http"): if not os.path.exists(scaler.local_data_path):
info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_download_path, progress=True) raise FileNotFoundError(f"RealESRGAN data missing: {scaler.local_data_path}")
return scaler
return info raise ValueError(f"Unable to find model info: {path}")
except Exception:
errors.report("Error making Real-ESRGAN models list", exc_info=True)
return None
def load_models(self, _): def load_models(self, _):
return get_realesrgan_models(self) return get_realesrgan_models(self)

View File

@ -1,6 +1,7 @@
import os import os
import re import re
import sys import sys
import inspect
from collections import namedtuple from collections import namedtuple
import gradio as gr import gradio as gr
@ -249,7 +250,7 @@ def load_scripts():
def register_scripts_from_module(module): def register_scripts_from_module(module):
for script_class in module.__dict__.values(): for script_class in module.__dict__.values():
if type(script_class) != type: if not inspect.isclass(script_class):
continue continue
if issubclass(script_class, Script): if issubclass(script_class, Script):

View File

@ -95,8 +95,7 @@ except Exception:
def setup_model(): def setup_model():
if not os.path.exists(model_path): os.makedirs(model_path, exist_ok=True)
os.makedirs(model_path)
enable_midas_autodownload() enable_midas_autodownload()
@ -248,7 +247,12 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None
_, extension = os.path.splitext(checkpoint_file) _, extension = os.path.splitext(checkpoint_file)
if extension.lower() == ".safetensors": if extension.lower() == ".safetensors":
device = map_location or shared.weight_load_location or devices.get_optimal_device_name() device = map_location or shared.weight_load_location or devices.get_optimal_device_name()
pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
if not shared.opts.disable_mmap_load_safetensors:
pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
else:
pl_sd = safetensors.torch.load(open(checkpoint_file, 'rb').read())
pl_sd = {k: v.to(device) for k, v in pl_sd.items()}
else: else:
pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location) pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)

View File

@ -69,6 +69,7 @@ class CFGDenoiser(torch.nn.Module):
self.init_latent = None self.init_latent = None
self.step = 0 self.step = 0
self.image_cfg_scale = None self.image_cfg_scale = None
self.padded_cond_uncond = False
def combine_denoised(self, x_out, conds_list, uncond, cond_scale): def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
denoised_uncond = x_out[-uncond.shape[0]:] denoised_uncond = x_out[-uncond.shape[0]:]
@ -133,15 +134,17 @@ class CFGDenoiser(torch.nn.Module):
x_in = x_in[:-batch_size] x_in = x_in[:-batch_size]
sigma_in = sigma_in[:-batch_size] sigma_in = sigma_in[:-batch_size]
# TODO add infotext entry self.padded_cond_uncond = False
if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]: if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
empty = shared.sd_model.cond_stage_model_empty_prompt empty = shared.sd_model.cond_stage_model_empty_prompt
num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1] num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
if num_repeats < 0: if num_repeats < 0:
tensor = torch.cat([tensor, empty.repeat((tensor.shape[0], -num_repeats, 1))], axis=1) tensor = torch.cat([tensor, empty.repeat((tensor.shape[0], -num_repeats, 1))], axis=1)
self.padded_cond_uncond = True
elif num_repeats > 0: elif num_repeats > 0:
uncond = torch.cat([uncond, empty.repeat((uncond.shape[0], num_repeats, 1))], axis=1) uncond = torch.cat([uncond, empty.repeat((uncond.shape[0], num_repeats, 1))], axis=1)
self.padded_cond_uncond = True
if tensor.shape[1] == uncond.shape[1] or skip_uncond: if tensor.shape[1] == uncond.shape[1] or skip_uncond:
if is_edit_model: if is_edit_model:
@ -405,6 +408,9 @@ class KDiffusionSampler:
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
if self.model_wrap_cfg.padded_cond_uncond:
p.extra_generation_params["Pad conds"] = True
return samples return samples
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
@ -438,5 +444,8 @@ class KDiffusionSampler:
's_min_uncond': self.s_min_uncond 's_min_uncond': self.s_min_uncond
}, disable=False, callback=self.callback_state, **extra_params_kwargs)) }, disable=False, callback=self.callback_state, **extra_params_kwargs))
if self.model_wrap_cfg.padded_cond_uncond:
p.extra_generation_params["Pad conds"] = True
return samples return samples

View File

@ -376,6 +376,7 @@ options_templates.update(options_section(('system', "System"), {
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."), "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
"print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."), "print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
"list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""), "list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
"disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
})) }))
options_templates.update(options_section(('training', "Training"), { options_templates.update(options_section(('training', "Training"), {
@ -409,7 +410,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"enable_emphasis": OptionInfo(True, "Enable emphasis").info("use (text) to make model pay more attention to text and [text] to make it pay less attention"), "enable_emphasis": OptionInfo(True, "Enable emphasis").info("use (text) to make model pay more attention to text and [text] to make it pay less attention"),
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
"comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"), "comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"),
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP nrtwork; 1 ignores none, 2 ignores one layer"), "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"),
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
"randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"), "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"),
})) }))
@ -493,6 +494,7 @@ options_templates.update(options_section(('ui', "User interface"), {
options_templates.update(options_section(('infotext', "Infotext"), { options_templates.update(options_section(('infotext', "Infotext"), {
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"), "add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
"add_user_name_to_info": OptionInfo(False, "Add user name to generation information when authenticated"),
"add_version_to_infotext": OptionInfo(True, "Add program version to generation information"), "add_version_to_infotext": OptionInfo(True, "Add program version to generation information"),
"disable_weights_auto_swap": OptionInfo(True, "Disregard checkpoint information from pasted infotext").info("when reading generation parameters from text into UI"), "disable_weights_auto_swap": OptionInfo(True, "Disregard checkpoint information from pasted infotext").info("when reading generation parameters from text into UI"),
"infotext_styles": OptionInfo("Apply if any", "Infer styles from prompts of pasted infotext", gr.Radio, {"choices": ["Ignore", "Apply", "Discard", "Apply if any"]}).info("when reading generation parameters from text into UI)").html("""<ul style='margin-left: 1.5em'> "infotext_styles": OptionInfo("Apply if any", "Infer styles from prompts of pasted infotext", gr.Radio, {"choices": ["Ignore", "Apply", "Discard", "Apply if any"]}).info("when reading generation parameters from text into UI)").html("""<ul style='margin-left: 1.5em'>

View File

@ -298,8 +298,7 @@ def download_and_cache_models(dirname):
download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true' download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'
model_file_name = 'face_detection_yunet.onnx' model_file_name = 'face_detection_yunet.onnx'
if not os.path.exists(dirname): os.makedirs(dirname, exist_ok=True)
os.makedirs(dirname)
cache_file = os.path.join(dirname, model_file_name) cache_file = os.path.join(dirname, model_file_name)
if not os.path.exists(cache_file): if not os.path.exists(cache_file):

View File

@ -2,11 +2,51 @@ import datetime
import json import json
import os import os
saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "clip_grad_mode", "clip_grad_value", "gradient_step", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file", "gradient_step", "latent_sampling_method"} saved_params_shared = {
saved_params_ti = {"embedding_name", "num_vectors_per_token", "save_embedding_every", "save_image_with_stored_embedding"} "batch_size",
saved_params_hypernet = {"hypernetwork_name", "layer_structure", "activation_func", "weight_init", "add_layer_norm", "use_dropout", "save_hypernetwork_every"} "clip_grad_mode",
"clip_grad_value",
"create_image_every",
"data_root",
"gradient_step",
"initial_step",
"latent_sampling_method",
"learn_rate",
"log_directory",
"model_hash",
"model_name",
"num_of_dataset_images",
"steps",
"template_file",
"training_height",
"training_width",
}
saved_params_ti = {
"embedding_name",
"num_vectors_per_token",
"save_embedding_every",
"save_image_with_stored_embedding",
}
saved_params_hypernet = {
"activation_func",
"add_layer_norm",
"hypernetwork_name",
"layer_structure",
"save_hypernetwork_every",
"use_dropout",
"weight_init",
}
saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet
saved_params_previews = {"preview_prompt", "preview_negative_prompt", "preview_steps", "preview_sampler_index", "preview_cfg_scale", "preview_seed", "preview_width", "preview_height"} saved_params_previews = {
"preview_cfg_scale",
"preview_height",
"preview_negative_prompt",
"preview_prompt",
"preview_sampler_index",
"preview_seed",
"preview_steps",
"preview_width",
}
def save_settings_to_file(log_directory, all_params): def save_settings_to_file(log_directory, all_params):

View File

@ -4,10 +4,10 @@ from modules.generation_parameters_copypaste import create_override_settings_dic
from modules.shared import opts, cmd_opts from modules.shared import opts, cmd_opts
import modules.shared as shared import modules.shared as shared
from modules.ui import plaintext_to_html from modules.ui import plaintext_to_html
import gradio as gr
def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args):
override_settings = create_override_settings_dict(override_settings_texts) override_settings = create_override_settings_dict(override_settings_texts)
p = processing.StableDiffusionProcessingTxt2Img( p = processing.StableDiffusionProcessingTxt2Img(
@ -48,6 +48,8 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
p.scripts = modules.scripts.scripts_txt2img p.scripts = modules.scripts.scripts_txt2img
p.script_args = args p.script_args = args
p.user = request.username
if cmd_opts.enable_console_prompts: if cmd_opts.enable_console_prompts:
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out) print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)

View File

@ -773,7 +773,7 @@ def create_ui():
selected_scale_tab = gr.State(value=0) selected_scale_tab = gr.State(value=0)
with gr.Tabs(): with gr.Tabs():
with gr.Tab(label="Resize to") as tab_scale_to: with gr.Tab(label="Resize to", elem_id="img2img_tab_resize_to") as tab_scale_to:
with FormRow(): with FormRow():
with gr.Column(elem_id="img2img_column_size", scale=4): with gr.Column(elem_id="img2img_column_size", scale=4):
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width") width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
@ -782,7 +782,7 @@ def create_ui():
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn") res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn") detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn")
with gr.Tab(label="Resize by") as tab_scale_by: with gr.Tab(label="Resize by", elem_id="img2img_tab_resize_by") as tab_scale_by:
scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale") scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale")
with FormRow(): with FormRow():

View File

@ -138,7 +138,10 @@ def extension_table():
<table id="extensions"> <table id="extensions">
<thead> <thead>
<tr> <tr>
<th><abbr title="Use checkbox to enable the extension; it will be enabled or disabled when you click apply button">Extension</abbr></th> <th>
<input class="gr-check-radio gr-checkbox all_extensions_toggle" type="checkbox" {'checked="checked"' if all(ext.enabled for ext in extensions.extensions) else ''} onchange="toggle_all_extensions(event)" />
<abbr title="Use checkbox to enable the extension; it will be enabled or disabled when you click apply button">Extension</abbr>
</th>
<th>URL</th> <th>URL</th>
<th>Branch</th> <th>Branch</th>
<th>Version</th> <th>Version</th>
@ -170,7 +173,7 @@ def extension_table():
code += f""" code += f"""
<tr> <tr>
<td><label{style}><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td> <td><label{style}><input class="gr-check-radio gr-checkbox extension_toggle" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''} onchange="toggle_extension(event)" />{html.escape(ext.name)}</label></td>
<td>{remote}</td> <td>{remote}</td>
<td>{ext.branch}</td> <td>{ext.branch}</td>
<td>{version_link}</td> <td>{version_link}</td>
@ -325,6 +328,11 @@ def normalize_git_url(url):
def install_extension_from_url(dirname, url, branch_name=None): def install_extension_from_url(dirname, url, branch_name=None):
check_access() check_access()
if isinstance(dirname, str):
dirname = dirname.strip()
if isinstance(url, str):
url = url.strip()
assert url, 'No URL specified' assert url, 'No URL specified'
if dirname is None or dirname == "": if dirname is None or dirname == "":
@ -563,9 +571,9 @@ def create_ui():
available_extensions_table = gr.HTML() available_extensions_table = gr.HTML()
refresh_available_extensions_button.click( refresh_available_extensions_button.click(
fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update()]), fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update(), gr.update()]),
inputs=[available_extensions_index, hide_tags, sort_column], inputs=[available_extensions_index, hide_tags, sort_column],
outputs=[available_extensions_index, available_extensions_table, hide_tags, install_result, search_extensions_text], outputs=[available_extensions_index, available_extensions_table, hide_tags, search_extensions_text, install_result],
) )
install_extension_button.click( install_extension_button.click(

View File

@ -11,7 +11,7 @@ import json
from threading import Thread from threading import Thread
from typing import Iterable from typing import Iterable
from fastapi import FastAPI, Response from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware from fastapi.middleware.gzip import GZipMiddleware
from packaging import version from packaging import version
@ -362,11 +362,6 @@ def api_only():
api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861) api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
def stop_route(request):
shared.state.server_command = "stop"
return Response("Stopping.")
def webui(): def webui():
launch_api = cmd_opts.api launch_api = cmd_opts.api
initialize() initialize()
@ -404,8 +399,6 @@ def webui():
"redoc_url": "/redoc", "redoc_url": "/redoc",
}, },
) )
if cmd_opts.add_stop_route:
app.add_route("/_stop", stop_route, methods=["POST"])
# after initial launch, disable --autolaunch for subsequent restarts # after initial launch, disable --autolaunch for subsequent restarts
cmd_opts.autolaunch = False cmd_opts.autolaunch = False