Merge branch 'AUTOMATIC1111:dev' into dev
This commit is contained in:
commit
9d8af4bd6a
2
.github/workflows/on_pull_request.yaml
vendored
2
.github/workflows/on_pull_request.yaml
vendored
@ -18,7 +18,7 @@ jobs:
|
|||||||
# not to have GHA download an (at the time of writing) 4 GB cache
|
# not to have GHA download an (at the time of writing) 4 GB cache
|
||||||
# of PyTorch and other dependencies.
|
# of PyTorch and other dependencies.
|
||||||
- name: Install Ruff
|
- name: Install Ruff
|
||||||
run: pip install ruff==0.0.265
|
run: pip install ruff==0.0.272
|
||||||
- name: Run Ruff
|
- name: Run Ruff
|
||||||
run: ruff .
|
run: ruff .
|
||||||
lint-js:
|
lint-js:
|
||||||
|
2
.github/workflows/run_tests.yaml
vendored
2
.github/workflows/run_tests.yaml
vendored
@ -50,7 +50,7 @@ jobs:
|
|||||||
python -m pytest -vv --junitxml=test/results.xml --cov . --cov-report=xml --verify-base-url test
|
python -m pytest -vv --junitxml=test/results.xml --cov . --cov-report=xml --verify-base-url test
|
||||||
- name: Kill test server
|
- name: Kill test server
|
||||||
if: always()
|
if: always()
|
||||||
run: curl -vv -XPOST http://127.0.0.1:7860/_stop && sleep 10
|
run: curl -vv -XPOST http://127.0.0.1:7860/sdapi/v1/server-stop && sleep 10
|
||||||
- name: Show coverage
|
- name: Show coverage
|
||||||
run: |
|
run: |
|
||||||
python -m coverage combine .coverage*
|
python -m coverage combine .coverage*
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
from basicsr.utils.download_util import load_file_from_url
|
from modules.modelloader import load_file_from_url
|
||||||
|
|
||||||
from modules.upscaler import Upscaler, UpscalerData
|
from modules.upscaler import Upscaler, UpscalerData
|
||||||
from ldsr_model_arch import LDSR
|
from ldsr_model_arch import LDSR
|
||||||
from modules import shared, script_callbacks, errors
|
from modules import shared, script_callbacks, errors
|
||||||
@ -43,20 +42,17 @@ class UpscalerLDSR(Upscaler):
|
|||||||
if local_safetensors_path is not None and os.path.exists(local_safetensors_path):
|
if local_safetensors_path is not None and os.path.exists(local_safetensors_path):
|
||||||
model = local_safetensors_path
|
model = local_safetensors_path
|
||||||
else:
|
else:
|
||||||
model = local_ckpt_path if local_ckpt_path is not None else load_file_from_url(url=self.model_url, model_dir=self.model_download_path, file_name="model.ckpt", progress=True)
|
model = local_ckpt_path or load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name="model.ckpt")
|
||||||
|
|
||||||
yaml = local_yaml_path if local_yaml_path is not None else load_file_from_url(url=self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml", progress=True)
|
yaml = local_yaml_path or load_file_from_url(self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml")
|
||||||
|
|
||||||
try:
|
return LDSR(model, yaml)
|
||||||
return LDSR(model, yaml)
|
|
||||||
except Exception:
|
|
||||||
errors.report("Error importing LDSR", exc_info=True)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def do_upscale(self, img, path):
|
def do_upscale(self, img, path):
|
||||||
ldsr = self.load_model(path)
|
try:
|
||||||
if ldsr is None:
|
ldsr = self.load_model(path)
|
||||||
print("NO LDSR!")
|
except Exception:
|
||||||
|
errors.report(f"Failed loading LDSR model {path}", exc_info=True)
|
||||||
return img
|
return img
|
||||||
ddim_steps = shared.opts.ldsr_steps
|
ddim_steps = shared.opts.ldsr_steps
|
||||||
return ldsr.super_resolution(img, ddim_steps, self.scale)
|
return ldsr.super_resolution(img, ddim_steps, self.scale)
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
import os.path
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import PIL.Image
|
import PIL.Image
|
||||||
@ -6,12 +5,11 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from basicsr.utils.download_util import load_file_from_url
|
|
||||||
|
|
||||||
import modules.upscaler
|
import modules.upscaler
|
||||||
from modules import devices, modelloader, script_callbacks, errors
|
from modules import devices, modelloader, script_callbacks, errors
|
||||||
from scunet_model_arch import SCUNet as net
|
from scunet_model_arch import SCUNet
|
||||||
|
|
||||||
|
from modules.modelloader import load_file_from_url
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
|
|
||||||
|
|
||||||
@ -28,7 +26,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
|
|||||||
scalers = []
|
scalers = []
|
||||||
add_model2 = True
|
add_model2 = True
|
||||||
for file in model_paths:
|
for file in model_paths:
|
||||||
if "http" in file:
|
if file.startswith("http"):
|
||||||
name = self.model_name
|
name = self.model_name
|
||||||
else:
|
else:
|
||||||
name = modelloader.friendly_name(file)
|
name = modelloader.friendly_name(file)
|
||||||
@ -89,9 +87,10 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
|
|||||||
|
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
model = self.load_model(selected_file)
|
try:
|
||||||
if model is None:
|
model = self.load_model(selected_file)
|
||||||
print(f"ScuNET: Unable to load model from {selected_file}", file=sys.stderr)
|
except Exception as e:
|
||||||
|
print(f"ScuNET: Unable to load model from {selected_file}: {e}", file=sys.stderr)
|
||||||
return img
|
return img
|
||||||
|
|
||||||
device = devices.get_device_for('scunet')
|
device = devices.get_device_for('scunet')
|
||||||
@ -119,15 +118,12 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
|
|||||||
|
|
||||||
def load_model(self, path: str):
|
def load_model(self, path: str):
|
||||||
device = devices.get_device_for('scunet')
|
device = devices.get_device_for('scunet')
|
||||||
if "http" in path:
|
if path.startswith("http"):
|
||||||
filename = load_file_from_url(url=self.model_url, model_dir=self.model_download_path, file_name="%s.pth" % self.name, progress=True)
|
# TODO: this doesn't use `path` at all?
|
||||||
|
filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
|
||||||
else:
|
else:
|
||||||
filename = path
|
filename = path
|
||||||
if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None:
|
model = SCUNet(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
|
||||||
print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr)
|
|
||||||
return None
|
|
||||||
|
|
||||||
model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
|
|
||||||
model.load_state_dict(torch.load(filename), strict=True)
|
model.load_state_dict(torch.load(filename), strict=True)
|
||||||
model.eval()
|
model.eval()
|
||||||
for _, v in model.named_parameters():
|
for _, v in model.named_parameters():
|
||||||
|
@ -1,17 +1,17 @@
|
|||||||
import os
|
import sys
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from basicsr.utils.download_util import load_file_from_url
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from modules import modelloader, devices, script_callbacks, shared
|
from modules import modelloader, devices, script_callbacks, shared
|
||||||
from modules.shared import opts, state
|
from modules.shared import opts, state
|
||||||
from swinir_model_arch import SwinIR as net
|
from swinir_model_arch import SwinIR
|
||||||
from swinir_model_arch_v2 import Swin2SR as net2
|
from swinir_model_arch_v2 import Swin2SR
|
||||||
from modules.upscaler import Upscaler, UpscalerData
|
from modules.upscaler import Upscaler, UpscalerData
|
||||||
|
|
||||||
|
SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"
|
||||||
|
|
||||||
device_swinir = devices.get_device_for('swinir')
|
device_swinir = devices.get_device_for('swinir')
|
||||||
|
|
||||||
@ -19,16 +19,14 @@ device_swinir = devices.get_device_for('swinir')
|
|||||||
class UpscalerSwinIR(Upscaler):
|
class UpscalerSwinIR(Upscaler):
|
||||||
def __init__(self, dirname):
|
def __init__(self, dirname):
|
||||||
self.name = "SwinIR"
|
self.name = "SwinIR"
|
||||||
self.model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0" \
|
self.model_url = SWINIR_MODEL_URL
|
||||||
"/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \
|
|
||||||
"-L_x4_GAN.pth "
|
|
||||||
self.model_name = "SwinIR 4x"
|
self.model_name = "SwinIR 4x"
|
||||||
self.user_path = dirname
|
self.user_path = dirname
|
||||||
super().__init__()
|
super().__init__()
|
||||||
scalers = []
|
scalers = []
|
||||||
model_files = self.find_models(ext_filter=[".pt", ".pth"])
|
model_files = self.find_models(ext_filter=[".pt", ".pth"])
|
||||||
for model in model_files:
|
for model in model_files:
|
||||||
if "http" in model:
|
if model.startswith("http"):
|
||||||
name = self.model_name
|
name = self.model_name
|
||||||
else:
|
else:
|
||||||
name = modelloader.friendly_name(model)
|
name = modelloader.friendly_name(model)
|
||||||
@ -37,8 +35,10 @@ class UpscalerSwinIR(Upscaler):
|
|||||||
self.scalers = scalers
|
self.scalers = scalers
|
||||||
|
|
||||||
def do_upscale(self, img, model_file):
|
def do_upscale(self, img, model_file):
|
||||||
model = self.load_model(model_file)
|
try:
|
||||||
if model is None:
|
model = self.load_model(model_file)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr)
|
||||||
return img
|
return img
|
||||||
model = model.to(device_swinir, dtype=devices.dtype)
|
model = model.to(device_swinir, dtype=devices.dtype)
|
||||||
img = upscale(img, model)
|
img = upscale(img, model)
|
||||||
@ -49,30 +49,31 @@ class UpscalerSwinIR(Upscaler):
|
|||||||
return img
|
return img
|
||||||
|
|
||||||
def load_model(self, path, scale=4):
|
def load_model(self, path, scale=4):
|
||||||
if "http" in path:
|
if path.startswith("http"):
|
||||||
dl_name = "%s%s" % (self.model_name.replace(" ", "_"), ".pth")
|
filename = modelloader.load_file_from_url(
|
||||||
filename = load_file_from_url(url=path, model_dir=self.model_download_path, file_name=dl_name, progress=True)
|
url=path,
|
||||||
|
model_dir=self.model_download_path,
|
||||||
|
file_name=f"{self.model_name.replace(' ', '_')}.pth",
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
filename = path
|
filename = path
|
||||||
if filename is None or not os.path.exists(filename):
|
|
||||||
return None
|
|
||||||
if filename.endswith(".v2.pth"):
|
if filename.endswith(".v2.pth"):
|
||||||
model = net2(
|
model = Swin2SR(
|
||||||
upscale=scale,
|
upscale=scale,
|
||||||
in_chans=3,
|
in_chans=3,
|
||||||
img_size=64,
|
img_size=64,
|
||||||
window_size=8,
|
window_size=8,
|
||||||
img_range=1.0,
|
img_range=1.0,
|
||||||
depths=[6, 6, 6, 6, 6, 6],
|
depths=[6, 6, 6, 6, 6, 6],
|
||||||
embed_dim=180,
|
embed_dim=180,
|
||||||
num_heads=[6, 6, 6, 6, 6, 6],
|
num_heads=[6, 6, 6, 6, 6, 6],
|
||||||
mlp_ratio=2,
|
mlp_ratio=2,
|
||||||
upsampler="nearest+conv",
|
upsampler="nearest+conv",
|
||||||
resi_connection="1conv",
|
resi_connection="1conv",
|
||||||
)
|
)
|
||||||
params = None
|
params = None
|
||||||
else:
|
else:
|
||||||
model = net(
|
model = SwinIR(
|
||||||
upscale=scale,
|
upscale=scale,
|
||||||
in_chans=3,
|
in_chans=3,
|
||||||
img_size=64,
|
img_size=64,
|
||||||
|
@ -4,12 +4,12 @@ onUiLoaded(async() => {
|
|||||||
inpaint: "#img2maskimg",
|
inpaint: "#img2maskimg",
|
||||||
inpaintSketch: "#inpaint_sketch",
|
inpaintSketch: "#inpaint_sketch",
|
||||||
rangeGroup: "#img2img_column_size",
|
rangeGroup: "#img2img_column_size",
|
||||||
sketch: "#img2img_sketch",
|
sketch: "#img2img_sketch"
|
||||||
};
|
};
|
||||||
const tabNameToElementId = {
|
const tabNameToElementId = {
|
||||||
"Inpaint sketch": elementIDs.inpaintSketch,
|
"Inpaint sketch": elementIDs.inpaintSketch,
|
||||||
"Inpaint": elementIDs.inpaint,
|
"Inpaint": elementIDs.inpaint,
|
||||||
"Sketch": elementIDs.sketch,
|
"Sketch": elementIDs.sketch
|
||||||
};
|
};
|
||||||
|
|
||||||
// Helper functions
|
// Helper functions
|
||||||
@ -42,43 +42,110 @@ onUiLoaded(async() => {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check is hotkey valid
|
// Function for defining the "Ctrl", "Shift" and "Alt" keys
|
||||||
function isSingleLetter(value) {
|
function isModifierKey(event, key) {
|
||||||
|
switch (key) {
|
||||||
|
case "Ctrl":
|
||||||
|
return event.ctrlKey;
|
||||||
|
case "Shift":
|
||||||
|
return event.shiftKey;
|
||||||
|
case "Alt":
|
||||||
|
return event.altKey;
|
||||||
|
default:
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if hotkey is valid
|
||||||
|
function isValidHotkey(value) {
|
||||||
|
const specialKeys = ["Ctrl", "Alt", "Shift", "Disable"];
|
||||||
return (
|
return (
|
||||||
typeof value === "string" && value.length === 1 && /[a-z]/i.test(value)
|
(typeof value === "string" &&
|
||||||
|
value.length === 1 &&
|
||||||
|
/[a-z]/i.test(value)) ||
|
||||||
|
specialKeys.includes(value)
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create hotkeyConfig from opts
|
// Normalize hotkey
|
||||||
function createHotkeyConfig(defaultHotkeysConfig, hotkeysConfigOpts) {
|
function normalizeHotkey(hotkey) {
|
||||||
const result = {};
|
return hotkey.length === 1 ? "Key" + hotkey.toUpperCase() : hotkey;
|
||||||
const usedKeys = new Set();
|
}
|
||||||
|
|
||||||
|
// Format hotkey for display
|
||||||
|
function formatHotkeyForDisplay(hotkey) {
|
||||||
|
return hotkey.startsWith("Key") ? hotkey.slice(3) : hotkey;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create hotkey configuration with the provided options
|
||||||
|
function createHotkeyConfig(defaultHotkeysConfig, hotkeysConfigOpts) {
|
||||||
|
const result = {}; // Resulting hotkey configuration
|
||||||
|
const usedKeys = new Set(); // Set of used hotkeys
|
||||||
|
|
||||||
|
// Iterate through defaultHotkeysConfig keys
|
||||||
for (const key in defaultHotkeysConfig) {
|
for (const key in defaultHotkeysConfig) {
|
||||||
if (typeof hotkeysConfigOpts[key] === "boolean") {
|
const userValue = hotkeysConfigOpts[key]; // User-provided hotkey value
|
||||||
result[key] = hotkeysConfigOpts[key];
|
const defaultValue = defaultHotkeysConfig[key]; // Default hotkey value
|
||||||
continue;
|
|
||||||
}
|
// Apply appropriate value for undefined, boolean, or object userValue
|
||||||
if (
|
if (
|
||||||
hotkeysConfigOpts[key] &&
|
userValue === undefined ||
|
||||||
isSingleLetter(hotkeysConfigOpts[key]) &&
|
typeof userValue === "boolean" ||
|
||||||
!usedKeys.has(hotkeysConfigOpts[key].toUpperCase())
|
typeof userValue === "object" ||
|
||||||
|
userValue === "disable"
|
||||||
) {
|
) {
|
||||||
// If the property passed the test and has not yet been used, add 'Key' before it and save it
|
result[key] =
|
||||||
result[key] = "Key" + hotkeysConfigOpts[key].toUpperCase();
|
userValue === undefined ? defaultValue : userValue;
|
||||||
usedKeys.add(hotkeysConfigOpts[key].toUpperCase());
|
} else if (isValidHotkey(userValue)) {
|
||||||
|
const normalizedUserValue = normalizeHotkey(userValue);
|
||||||
|
|
||||||
|
// Check for conflicting hotkeys
|
||||||
|
if (!usedKeys.has(normalizedUserValue)) {
|
||||||
|
usedKeys.add(normalizedUserValue);
|
||||||
|
result[key] = normalizedUserValue;
|
||||||
|
} else {
|
||||||
|
console.error(
|
||||||
|
`Hotkey: ${formatHotkeyForDisplay(
|
||||||
|
userValue
|
||||||
|
)} for ${key} is repeated and conflicts with another hotkey. The default hotkey is used: ${formatHotkeyForDisplay(
|
||||||
|
defaultValue
|
||||||
|
)}`
|
||||||
|
);
|
||||||
|
result[key] = defaultValue;
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
// If the property does not pass the test or has already been used, we keep the default value
|
|
||||||
console.error(
|
console.error(
|
||||||
`Hotkey: ${hotkeysConfigOpts[key]} for ${key} is repeated and conflicts with another hotkey or is not 1 letter. The default hotkey is used: ${defaultHotkeysConfig[key][3]}`
|
`Hotkey: ${formatHotkeyForDisplay(
|
||||||
|
userValue
|
||||||
|
)} for ${key} is not valid. The default hotkey is used: ${formatHotkeyForDisplay(
|
||||||
|
defaultValue
|
||||||
|
)}`
|
||||||
);
|
);
|
||||||
result[key] = defaultHotkeysConfig[key];
|
result[key] = defaultValue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Disables functions in the config object based on the provided list of function names
|
||||||
|
function disableFunctions(config, disabledFunctions) {
|
||||||
|
// Bind the hasOwnProperty method to the functionMap object to avoid errors
|
||||||
|
const hasOwnProperty =
|
||||||
|
Object.prototype.hasOwnProperty.bind(functionMap);
|
||||||
|
|
||||||
|
// Loop through the disabledFunctions array and disable the corresponding functions in the config object
|
||||||
|
disabledFunctions.forEach(funcName => {
|
||||||
|
if (hasOwnProperty(funcName)) {
|
||||||
|
const key = functionMap[funcName];
|
||||||
|
config[key] = "disable";
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Return the updated config object
|
||||||
|
return config;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The restoreImgRedMask function displays a red mask around an image to indicate the aspect ratio.
|
* The restoreImgRedMask function displays a red mask around an image to indicate the aspect ratio.
|
||||||
* If the image display property is set to 'none', the mask breaks. To fix this, the function
|
* If the image display property is set to 'none', the mask breaks. To fix this, the function
|
||||||
@ -100,7 +167,9 @@ onUiLoaded(async() => {
|
|||||||
imageARPreview.style.transform = "";
|
imageARPreview.style.transform = "";
|
||||||
if (parseFloat(mainTab.style.width) > 865) {
|
if (parseFloat(mainTab.style.width) > 865) {
|
||||||
const transformString = mainTab.style.transform;
|
const transformString = mainTab.style.transform;
|
||||||
const scaleMatch = transformString.match(/scale\(([-+]?[0-9]*\.?[0-9]+)\)/);
|
const scaleMatch = transformString.match(
|
||||||
|
/scale\(([-+]?[0-9]*\.?[0-9]+)\)/
|
||||||
|
);
|
||||||
let zoom = 1; // default zoom
|
let zoom = 1; // default zoom
|
||||||
|
|
||||||
if (scaleMatch && scaleMatch[1]) {
|
if (scaleMatch && scaleMatch[1]) {
|
||||||
@ -124,31 +193,52 @@ onUiLoaded(async() => {
|
|||||||
|
|
||||||
// Default config
|
// Default config
|
||||||
const defaultHotkeysConfig = {
|
const defaultHotkeysConfig = {
|
||||||
|
canvas_hotkey_zoom: "Alt",
|
||||||
|
canvas_hotkey_adjust: "Ctrl",
|
||||||
canvas_hotkey_reset: "KeyR",
|
canvas_hotkey_reset: "KeyR",
|
||||||
canvas_hotkey_fullscreen: "KeyS",
|
canvas_hotkey_fullscreen: "KeyS",
|
||||||
canvas_hotkey_move: "KeyF",
|
canvas_hotkey_move: "KeyF",
|
||||||
canvas_hotkey_overlap: "KeyO",
|
canvas_hotkey_overlap: "KeyO",
|
||||||
canvas_show_tooltip: true,
|
canvas_disabled_functions: [],
|
||||||
canvas_swap_controls: false
|
canvas_show_tooltip: true
|
||||||
};
|
};
|
||||||
// swap the actions for ctr + wheel and shift + wheel
|
|
||||||
const hotkeysConfig = createHotkeyConfig(
|
const functionMap = {
|
||||||
|
"Zoom": "canvas_hotkey_zoom",
|
||||||
|
"Adjust brush size": "canvas_hotkey_adjust",
|
||||||
|
"Moving canvas": "canvas_hotkey_move",
|
||||||
|
"Fullscreen": "canvas_hotkey_fullscreen",
|
||||||
|
"Reset Zoom": "canvas_hotkey_reset",
|
||||||
|
"Overlap": "canvas_hotkey_overlap"
|
||||||
|
};
|
||||||
|
|
||||||
|
// Loading the configuration from opts
|
||||||
|
const preHotkeysConfig = createHotkeyConfig(
|
||||||
defaultHotkeysConfig,
|
defaultHotkeysConfig,
|
||||||
hotkeysConfigOpts
|
hotkeysConfigOpts
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Disable functions that are not needed by the user
|
||||||
|
const hotkeysConfig = disableFunctions(
|
||||||
|
preHotkeysConfig,
|
||||||
|
preHotkeysConfig.canvas_disabled_functions
|
||||||
|
);
|
||||||
|
|
||||||
let isMoving = false;
|
let isMoving = false;
|
||||||
let mouseX, mouseY;
|
let mouseX, mouseY;
|
||||||
let activeElement;
|
let activeElement;
|
||||||
|
|
||||||
const elements = Object.fromEntries(Object.keys(elementIDs).map((id) => [
|
const elements = Object.fromEntries(
|
||||||
id,
|
Object.keys(elementIDs).map(id => [
|
||||||
gradioApp().querySelector(elementIDs[id]),
|
id,
|
||||||
]));
|
gradioApp().querySelector(elementIDs[id])
|
||||||
|
])
|
||||||
|
);
|
||||||
const elemData = {};
|
const elemData = {};
|
||||||
|
|
||||||
// Apply functionality to the range inputs. Restore redmask and correct for long images.
|
// Apply functionality to the range inputs. Restore redmask and correct for long images.
|
||||||
const rangeInputs = elements.rangeGroup ? Array.from(elements.rangeGroup.querySelectorAll("input")) :
|
const rangeInputs = elements.rangeGroup ?
|
||||||
|
Array.from(elements.rangeGroup.querySelectorAll("input")) :
|
||||||
[
|
[
|
||||||
gradioApp().querySelector("#img2img_width input[type='range']"),
|
gradioApp().querySelector("#img2img_width input[type='range']"),
|
||||||
gradioApp().querySelector("#img2img_height input[type='range']")
|
gradioApp().querySelector("#img2img_height input[type='range']")
|
||||||
@ -180,38 +270,56 @@ onUiLoaded(async() => {
|
|||||||
const toolTipElemnt =
|
const toolTipElemnt =
|
||||||
targetElement.querySelector(".image-container");
|
targetElement.querySelector(".image-container");
|
||||||
const tooltip = document.createElement("div");
|
const tooltip = document.createElement("div");
|
||||||
tooltip.className = "tooltip";
|
tooltip.className = "canvas-tooltip";
|
||||||
|
|
||||||
// Creating an item of information
|
// Creating an item of information
|
||||||
const info = document.createElement("i");
|
const info = document.createElement("i");
|
||||||
info.className = "tooltip-info";
|
info.className = "canvas-tooltip-info";
|
||||||
info.textContent = "";
|
info.textContent = "";
|
||||||
|
|
||||||
// Create a container for the contents of the tooltip
|
// Create a container for the contents of the tooltip
|
||||||
const tooltipContent = document.createElement("div");
|
const tooltipContent = document.createElement("div");
|
||||||
tooltipContent.className = "tooltip-content";
|
tooltipContent.className = "canvas-tooltip-content";
|
||||||
|
|
||||||
// Add info about hotkeys
|
// Define an array with hotkey information and their actions
|
||||||
const zoomKey = hotkeysConfig.canvas_swap_controls ? "Ctrl" : "Shift";
|
const hotkeysInfo = [
|
||||||
const adjustKey = hotkeysConfig.canvas_swap_controls ? "Shift" : "Ctrl";
|
|
||||||
|
|
||||||
const hotkeys = [
|
|
||||||
{key: `${zoomKey} + wheel`, action: "Zoom canvas"},
|
|
||||||
{key: `${adjustKey} + wheel`, action: "Adjust brush size"},
|
|
||||||
{
|
{
|
||||||
key: hotkeysConfig.canvas_hotkey_reset.charAt(hotkeysConfig.canvas_hotkey_reset.length - 1),
|
configKey: "canvas_hotkey_zoom",
|
||||||
action: "Reset zoom"
|
action: "Zoom canvas",
|
||||||
|
keySuffix: " + wheel"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
key: hotkeysConfig.canvas_hotkey_fullscreen.charAt(hotkeysConfig.canvas_hotkey_fullscreen.length - 1),
|
configKey: "canvas_hotkey_adjust",
|
||||||
|
action: "Adjust brush size",
|
||||||
|
keySuffix: " + wheel"
|
||||||
|
},
|
||||||
|
{configKey: "canvas_hotkey_reset", action: "Reset zoom"},
|
||||||
|
{
|
||||||
|
configKey: "canvas_hotkey_fullscreen",
|
||||||
action: "Fullscreen mode"
|
action: "Fullscreen mode"
|
||||||
},
|
},
|
||||||
{
|
{configKey: "canvas_hotkey_move", action: "Move canvas"},
|
||||||
key: hotkeysConfig.canvas_hotkey_move.charAt(hotkeysConfig.canvas_hotkey_move.length - 1),
|
{configKey: "canvas_hotkey_overlap", action: "Overlap"}
|
||||||
action: "Move canvas"
|
|
||||||
}
|
|
||||||
];
|
];
|
||||||
|
|
||||||
|
// Create hotkeys array with disabled property based on the config values
|
||||||
|
const hotkeys = hotkeysInfo.map(info => {
|
||||||
|
const configValue = hotkeysConfig[info.configKey];
|
||||||
|
const key = info.keySuffix ?
|
||||||
|
`${configValue}${info.keySuffix}` :
|
||||||
|
configValue.charAt(configValue.length - 1);
|
||||||
|
return {
|
||||||
|
key,
|
||||||
|
action: info.action,
|
||||||
|
disabled: configValue === "disable"
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
for (const hotkey of hotkeys) {
|
for (const hotkey of hotkeys) {
|
||||||
|
if (hotkey.disabled) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
const p = document.createElement("p");
|
const p = document.createElement("p");
|
||||||
p.innerHTML = `<b>${hotkey.key}</b> - ${hotkey.action}`;
|
p.innerHTML = `<b>${hotkey.key}</b> - ${hotkey.action}`;
|
||||||
tooltipContent.appendChild(p);
|
tooltipContent.appendChild(p);
|
||||||
@ -346,10 +454,7 @@ onUiLoaded(async() => {
|
|||||||
|
|
||||||
// Change the zoom level based on user interaction
|
// Change the zoom level based on user interaction
|
||||||
function changeZoomLevel(operation, e) {
|
function changeZoomLevel(operation, e) {
|
||||||
if (
|
if (isModifierKey(e, hotkeysConfig.canvas_hotkey_zoom)) {
|
||||||
(!hotkeysConfig.canvas_swap_controls && e.shiftKey) ||
|
|
||||||
(hotkeysConfig.canvas_swap_controls && e.ctrlKey)
|
|
||||||
) {
|
|
||||||
e.preventDefault();
|
e.preventDefault();
|
||||||
|
|
||||||
let zoomPosX, zoomPosY;
|
let zoomPosX, zoomPosY;
|
||||||
@ -514,6 +619,13 @@ onUiLoaded(async() => {
|
|||||||
event.preventDefault();
|
event.preventDefault();
|
||||||
action(event);
|
action(event);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (
|
||||||
|
isModifierKey(event, hotkeysConfig.canvas_hotkey_zoom) ||
|
||||||
|
isModifierKey(event, hotkeysConfig.canvas_hotkey_adjust)
|
||||||
|
) {
|
||||||
|
event.preventDefault();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get Mouse position
|
// Get Mouse position
|
||||||
@ -564,11 +676,7 @@ onUiLoaded(async() => {
|
|||||||
changeZoomLevel(operation, e);
|
changeZoomLevel(operation, e);
|
||||||
|
|
||||||
// Handle brush size adjustment with ctrl key pressed
|
// Handle brush size adjustment with ctrl key pressed
|
||||||
if (
|
if (isModifierKey(e, hotkeysConfig.canvas_hotkey_adjust)) {
|
||||||
(hotkeysConfig.canvas_swap_controls && e.shiftKey) ||
|
|
||||||
(!hotkeysConfig.canvas_swap_controls &&
|
|
||||||
(e.ctrlKey || e.metaKey))
|
|
||||||
) {
|
|
||||||
e.preventDefault();
|
e.preventDefault();
|
||||||
|
|
||||||
// Increase or decrease brush size based on scroll direction
|
// Increase or decrease brush size based on scroll direction
|
||||||
|
@ -1,10 +1,13 @@
|
|||||||
|
import gradio as gr
|
||||||
from modules import shared
|
from modules import shared
|
||||||
|
|
||||||
shared.options_templates.update(shared.options_section(('canvas_hotkey', "Canvas Hotkeys"), {
|
shared.options_templates.update(shared.options_section(('canvas_hotkey', "Canvas Hotkeys"), {
|
||||||
"canvas_hotkey_move": shared.OptionInfo("F", "Moving the canvas"),
|
"canvas_hotkey_zoom": shared.OptionInfo("Alt", "Zoom canvas", gr.Radio, {"choices": ["Shift","Ctrl", "Alt"]}).info("If you choose 'Shift' you cannot scroll horizontally, 'Alt' can cause a little trouble in firefox"),
|
||||||
|
"canvas_hotkey_adjust": shared.OptionInfo("Ctrl", "Adjust brush size", gr.Radio, {"choices": ["Shift","Ctrl", "Alt"]}).info("If you choose 'Shift' you cannot scroll horizontally, 'Alt' can cause a little trouble in firefox"),
|
||||||
|
"canvas_hotkey_move": shared.OptionInfo("F", "Moving the canvas").info("To work correctly in firefox, turn off 'Automatically search the page text when typing' in the browser settings"),
|
||||||
"canvas_hotkey_fullscreen": shared.OptionInfo("S", "Fullscreen Mode, maximizes the picture so that it fits into the screen and stretches it to its full width "),
|
"canvas_hotkey_fullscreen": shared.OptionInfo("S", "Fullscreen Mode, maximizes the picture so that it fits into the screen and stretches it to its full width "),
|
||||||
"canvas_hotkey_reset": shared.OptionInfo("R", "Reset zoom and canvas positon"),
|
"canvas_hotkey_reset": shared.OptionInfo("R", "Reset zoom and canvas positon"),
|
||||||
"canvas_hotkey_overlap": shared.OptionInfo("O", "Toggle overlap ( Technical button, neededs for testing )"),
|
"canvas_hotkey_overlap": shared.OptionInfo("O", "Toggle overlap").info("Technical button, neededs for testing"),
|
||||||
"canvas_show_tooltip": shared.OptionInfo(True, "Enable tooltip on the canvas"),
|
"canvas_show_tooltip": shared.OptionInfo(True, "Enable tooltip on the canvas"),
|
||||||
"canvas_swap_controls": shared.OptionInfo(False, "Swap hotkey combinations for Zoom and Adjust brush resize"),
|
"canvas_disabled_functions": shared.OptionInfo(["Overlap"], "Disable function that you don't use", gr.CheckboxGroup, {"choices": ["Zoom","Adjust brush size", "Moving canvas","Fullscreen","Reset Zoom","Overlap"]}),
|
||||||
}))
|
}))
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
.tooltip-info {
|
.canvas-tooltip-info {
|
||||||
position: absolute;
|
position: absolute;
|
||||||
top: 10px;
|
top: 10px;
|
||||||
left: 10px;
|
left: 10px;
|
||||||
@ -15,7 +15,7 @@
|
|||||||
z-index: 100;
|
z-index: 100;
|
||||||
}
|
}
|
||||||
|
|
||||||
.tooltip-info::after {
|
.canvas-tooltip-info::after {
|
||||||
content: '';
|
content: '';
|
||||||
display: block;
|
display: block;
|
||||||
width: 2px;
|
width: 2px;
|
||||||
@ -24,7 +24,7 @@
|
|||||||
margin-top: 2px;
|
margin-top: 2px;
|
||||||
}
|
}
|
||||||
|
|
||||||
.tooltip-info::before {
|
.canvas-tooltip-info::before {
|
||||||
content: '';
|
content: '';
|
||||||
display: block;
|
display: block;
|
||||||
width: 2px;
|
width: 2px;
|
||||||
@ -32,7 +32,7 @@
|
|||||||
background-color: white;
|
background-color: white;
|
||||||
}
|
}
|
||||||
|
|
||||||
.tooltip-content {
|
.canvas-tooltip-content {
|
||||||
display: none;
|
display: none;
|
||||||
background-color: #f9f9f9;
|
background-color: #f9f9f9;
|
||||||
color: #333;
|
color: #333;
|
||||||
@ -50,7 +50,7 @@
|
|||||||
z-index: 100;
|
z-index: 100;
|
||||||
}
|
}
|
||||||
|
|
||||||
.tooltip:hover .tooltip-content {
|
.canvas-tooltip:hover .canvas-tooltip-content {
|
||||||
display: block;
|
display: block;
|
||||||
animation: fadeIn 0.5s;
|
animation: fadeIn 0.5s;
|
||||||
opacity: 1;
|
opacity: 1;
|
||||||
|
@ -72,3 +72,21 @@ function config_state_confirm_restore(_, config_state_name, config_restore_type)
|
|||||||
}
|
}
|
||||||
return [confirmed, config_state_name, config_restore_type];
|
return [confirmed, config_state_name, config_restore_type];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function toggle_all_extensions(event) {
|
||||||
|
gradioApp().querySelectorAll('#extensions .extension_toggle').forEach(function(checkbox_el) {
|
||||||
|
checkbox_el.checked = event.target.checked;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
function toggle_extension() {
|
||||||
|
let all_extensions_toggled = true;
|
||||||
|
for (const checkbox_el of gradioApp().querySelectorAll('#extensions .extension_toggle')) {
|
||||||
|
if (!checkbox_el.checked) {
|
||||||
|
all_extensions_toggled = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
gradioApp().querySelector('#extensions .all_extensions_toggle').checked = all_extensions_toggled;
|
||||||
|
}
|
||||||
|
@ -15,7 +15,7 @@ var titles = {
|
|||||||
"CFG Scale": "Classifier Free Guidance Scale - how strongly the image should conform to prompt - lower values produce more creative results",
|
"CFG Scale": "Classifier Free Guidance Scale - how strongly the image should conform to prompt - lower values produce more creative results",
|
||||||
"Seed": "A value that determines the output of random number generator - if you create an image with same parameters and seed as another image, you'll get the same result",
|
"Seed": "A value that determines the output of random number generator - if you create an image with same parameters and seed as another image, you'll get the same result",
|
||||||
"\u{1f3b2}\ufe0f": "Set seed to -1, which will cause a new random number to be used every time",
|
"\u{1f3b2}\ufe0f": "Set seed to -1, which will cause a new random number to be used every time",
|
||||||
"\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed",
|
"\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomized",
|
||||||
"\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.",
|
"\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.",
|
||||||
"\u{1f4c2}": "Open images output directory",
|
"\u{1f4c2}": "Open images output directory",
|
||||||
"\u{1f4be}": "Save style",
|
"\u{1f4be}": "Save style",
|
||||||
@ -112,7 +112,7 @@ var titles = {
|
|||||||
"Resize height to": "Resizes image to this height. If 0, height is inferred from either of two nearby sliders.",
|
"Resize height to": "Resizes image to this height. If 0, height is inferred from either of two nearby sliders.",
|
||||||
"Multiplier for extra networks": "When adding extra network such as Hypernetwork or Lora to prompt, use this multiplier for it.",
|
"Multiplier for extra networks": "When adding extra network such as Hypernetwork or Lora to prompt, use this multiplier for it.",
|
||||||
"Discard weights with matching name": "Regular expression; if weights's name matches it, the weights is not written to the resulting checkpoint. Use ^model_ema to discard EMA weights.",
|
"Discard weights with matching name": "Regular expression; if weights's name matches it, the weights is not written to the resulting checkpoint. Use ^model_ema to discard EMA weights.",
|
||||||
"Extra networks tab order": "Comma-separated list of tab names; tabs listed here will appear in the extra networks UI first and in order lsited.",
|
"Extra networks tab order": "Comma-separated list of tab names; tabs listed here will appear in the extra networks UI first and in order listed.",
|
||||||
"Negative Guidance minimum sigma": "Skip negative prompt for steps where image is already mostly denoised; the higher this value, the more skips there will be; provides increased performance in exchange for minor quality reduction."
|
"Negative Guidance minimum sigma": "Skip negative prompt for steps where image is already mostly denoised; the higher this value, the more skips there will be; provides increased performance in exchange for minor quality reduction."
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -14,7 +14,7 @@ from fastapi.encoders import jsonable_encoder
|
|||||||
from secrets import compare_digest
|
from secrets import compare_digest
|
||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors
|
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart
|
||||||
from modules.api import models
|
from modules.api import models
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
||||||
@ -22,7 +22,7 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_
|
|||||||
from modules.textual_inversion.preprocess import preprocess
|
from modules.textual_inversion.preprocess import preprocess
|
||||||
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
||||||
from PIL import PngImagePlugin,Image
|
from PIL import PngImagePlugin,Image
|
||||||
from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights
|
from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights, checkpoint_alisases
|
||||||
from modules.sd_vae import vae_dict
|
from modules.sd_vae import vae_dict
|
||||||
from modules.sd_models_config import find_checkpoint_config_near_filename
|
from modules.sd_models_config import find_checkpoint_config_near_filename
|
||||||
from modules.realesrgan_model import get_realesrgan_models
|
from modules.realesrgan_model import get_realesrgan_models
|
||||||
@ -32,13 +32,6 @@ import piexif
|
|||||||
import piexif.helper
|
import piexif.helper
|
||||||
|
|
||||||
|
|
||||||
def upscaler_to_index(name: str):
|
|
||||||
try:
|
|
||||||
return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
|
|
||||||
except Exception as e:
|
|
||||||
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in shared.sd_upscalers])}") from e
|
|
||||||
|
|
||||||
|
|
||||||
def script_name_to_index(name, scripts):
|
def script_name_to_index(name, scripts):
|
||||||
try:
|
try:
|
||||||
return [script.title().lower() for script in scripts].index(name.lower())
|
return [script.title().lower() for script in scripts].index(name.lower())
|
||||||
@ -209,6 +202,11 @@ class Api:
|
|||||||
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
|
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
|
||||||
self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
|
self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
|
||||||
|
|
||||||
|
if shared.cmd_opts.add_stop_route:
|
||||||
|
self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"])
|
||||||
|
self.add_api_route("/sdapi/v1/server-restart", self.restart_webui, methods=["POST"])
|
||||||
|
self.add_api_route("/sdapi/v1/server-stop", self.stop_webui, methods=["POST"])
|
||||||
|
|
||||||
self.default_script_arg_txt2img = []
|
self.default_script_arg_txt2img = []
|
||||||
self.default_script_arg_img2img = []
|
self.default_script_arg_img2img = []
|
||||||
|
|
||||||
@ -517,6 +515,10 @@ class Api:
|
|||||||
return options
|
return options
|
||||||
|
|
||||||
def set_config(self, req: Dict[str, Any]):
|
def set_config(self, req: Dict[str, Any]):
|
||||||
|
checkpoint_name = req.get("sd_model_checkpoint", None)
|
||||||
|
if checkpoint_name is not None and checkpoint_name not in checkpoint_alisases:
|
||||||
|
raise RuntimeError(f"model {checkpoint_name!r} not found")
|
||||||
|
|
||||||
for k, v in req.items():
|
for k, v in req.items():
|
||||||
shared.opts.set(k, v)
|
shared.opts.set(k, v)
|
||||||
|
|
||||||
@ -715,3 +717,15 @@ class Api:
|
|||||||
def launch(self, server_name, port):
|
def launch(self, server_name, port):
|
||||||
self.app.include_router(self.router)
|
self.app.include_router(self.router)
|
||||||
uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=0)
|
uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=0)
|
||||||
|
|
||||||
|
def kill_webui(self):
|
||||||
|
restart.stop_program()
|
||||||
|
|
||||||
|
def restart_webui(self):
|
||||||
|
if restart.is_restartable():
|
||||||
|
restart.restart_program()
|
||||||
|
return Response(status_code=501)
|
||||||
|
|
||||||
|
def stop_webui(request):
|
||||||
|
shared.state.server_command = "stop"
|
||||||
|
return Response("Stopping.")
|
||||||
|
@ -274,10 +274,6 @@ class PromptStyleItem(BaseModel):
|
|||||||
prompt: Optional[str] = Field(title="Prompt")
|
prompt: Optional[str] = Field(title="Prompt")
|
||||||
negative_prompt: Optional[str] = Field(title="Negative Prompt")
|
negative_prompt: Optional[str] = Field(title="Negative Prompt")
|
||||||
|
|
||||||
class ArtistItem(BaseModel):
|
|
||||||
name: str = Field(title="Name")
|
|
||||||
score: float = Field(title="Score")
|
|
||||||
category: str = Field(title="Category")
|
|
||||||
|
|
||||||
class EmbeddingItem(BaseModel):
|
class EmbeddingItem(BaseModel):
|
||||||
step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
|
step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
from functools import wraps
|
||||||
import html
|
import html
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
@ -18,6 +19,7 @@ def wrap_queued_call(func):
|
|||||||
|
|
||||||
|
|
||||||
def wrap_gradio_gpu_call(func, extra_outputs=None):
|
def wrap_gradio_gpu_call(func, extra_outputs=None):
|
||||||
|
@wraps(func)
|
||||||
def f(*args, **kwargs):
|
def f(*args, **kwargs):
|
||||||
|
|
||||||
# if the first argument is a string that says "task(...)", it is treated as a job id
|
# if the first argument is a string that says "task(...)", it is treated as a job id
|
||||||
@ -45,6 +47,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
|
|||||||
|
|
||||||
|
|
||||||
def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
|
def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
|
||||||
|
@wraps(func)
|
||||||
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
|
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
|
||||||
run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
|
run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
|
||||||
if run_memmon:
|
if run_memmon:
|
||||||
|
@ -106,4 +106,4 @@ parser.add_argument("--skip-version-check", action='store_true', help="Do not ch
|
|||||||
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
|
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
|
||||||
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
|
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
|
||||||
parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
|
parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
|
||||||
parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server')
|
parser.add_argument('--add-stop-route', action='store_true', help='enable server stop/restart/kill via api')
|
||||||
|
@ -15,14 +15,11 @@ model_dir = "Codeformer"
|
|||||||
model_path = os.path.join(models_path, model_dir)
|
model_path = os.path.join(models_path, model_dir)
|
||||||
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
|
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
|
||||||
|
|
||||||
have_codeformer = False
|
|
||||||
codeformer = None
|
codeformer = None
|
||||||
|
|
||||||
|
|
||||||
def setup_model(dirname):
|
def setup_model(dirname):
|
||||||
global model_path
|
os.makedirs(model_path, exist_ok=True)
|
||||||
if not os.path.exists(model_path):
|
|
||||||
os.makedirs(model_path)
|
|
||||||
|
|
||||||
path = modules.paths.paths.get("CodeFormer", None)
|
path = modules.paths.paths.get("CodeFormer", None)
|
||||||
if path is None:
|
if path is None:
|
||||||
@ -125,9 +122,6 @@ def setup_model(dirname):
|
|||||||
|
|
||||||
return restored_img
|
return restored_img
|
||||||
|
|
||||||
global have_codeformer
|
|
||||||
have_codeformer = True
|
|
||||||
|
|
||||||
global codeformer
|
global codeformer
|
||||||
codeformer = FaceRestorerCodeFormer(dirname)
|
codeformer = FaceRestorerCodeFormer(dirname)
|
||||||
shared.face_restorers.append(codeformer)
|
shared.face_restorers.append(codeformer)
|
||||||
|
@ -15,13 +15,6 @@ def has_mps() -> bool:
|
|||||||
else:
|
else:
|
||||||
return mac_specific.has_mps
|
return mac_specific.has_mps
|
||||||
|
|
||||||
def extract_device_id(args, name):
|
|
||||||
for x in range(len(args)):
|
|
||||||
if name in args[x]:
|
|
||||||
return args[x + 1]
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def get_cuda_device_string():
|
def get_cuda_device_string():
|
||||||
from modules import shared
|
from modules import shared
|
||||||
|
@ -1,15 +1,13 @@
|
|||||||
import os
|
import sys
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from basicsr.utils.download_util import load_file_from_url
|
|
||||||
|
|
||||||
import modules.esrgan_model_arch as arch
|
import modules.esrgan_model_arch as arch
|
||||||
from modules import modelloader, images, devices
|
from modules import modelloader, images, devices
|
||||||
from modules.upscaler import Upscaler, UpscalerData
|
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
|
from modules.upscaler import Upscaler, UpscalerData
|
||||||
|
|
||||||
|
|
||||||
def mod2normal(state_dict):
|
def mod2normal(state_dict):
|
||||||
@ -134,7 +132,7 @@ class UpscalerESRGAN(Upscaler):
|
|||||||
scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
|
scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
|
||||||
scalers.append(scaler_data)
|
scalers.append(scaler_data)
|
||||||
for file in model_paths:
|
for file in model_paths:
|
||||||
if "http" in file:
|
if file.startswith("http"):
|
||||||
name = self.model_name
|
name = self.model_name
|
||||||
else:
|
else:
|
||||||
name = modelloader.friendly_name(file)
|
name = modelloader.friendly_name(file)
|
||||||
@ -143,26 +141,25 @@ class UpscalerESRGAN(Upscaler):
|
|||||||
self.scalers.append(scaler_data)
|
self.scalers.append(scaler_data)
|
||||||
|
|
||||||
def do_upscale(self, img, selected_model):
|
def do_upscale(self, img, selected_model):
|
||||||
model = self.load_model(selected_model)
|
try:
|
||||||
if model is None:
|
model = self.load_model(selected_model)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr)
|
||||||
return img
|
return img
|
||||||
model.to(devices.device_esrgan)
|
model.to(devices.device_esrgan)
|
||||||
img = esrgan_upscale(model, img)
|
img = esrgan_upscale(model, img)
|
||||||
return img
|
return img
|
||||||
|
|
||||||
def load_model(self, path: str):
|
def load_model(self, path: str):
|
||||||
if "http" in path:
|
if path.startswith("http"):
|
||||||
filename = load_file_from_url(
|
# TODO: this doesn't use `path` at all?
|
||||||
|
filename = modelloader.load_file_from_url(
|
||||||
url=self.model_url,
|
url=self.model_url,
|
||||||
model_dir=self.model_download_path,
|
model_dir=self.model_download_path,
|
||||||
file_name=f"{self.model_name}.pth",
|
file_name=f"{self.model_name}.pth",
|
||||||
progress=True,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
filename = path
|
filename = path
|
||||||
if not os.path.exists(filename) or filename is None:
|
|
||||||
print(f"Unable to load {self.model_path} from {filename}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
|
state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
|
||||||
|
|
||||||
|
@ -7,8 +7,7 @@ from modules.paths_internal import extensions_dir, extensions_builtin_dir, scrip
|
|||||||
|
|
||||||
extensions = []
|
extensions = []
|
||||||
|
|
||||||
if not os.path.exists(extensions_dir):
|
os.makedirs(extensions_dir, exist_ok=True)
|
||||||
os.makedirs(extensions_dir)
|
|
||||||
|
|
||||||
|
|
||||||
def active():
|
def active():
|
||||||
|
@ -174,31 +174,6 @@ def send_image_and_dimensions(x):
|
|||||||
return img, w, h
|
return img, w, h
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
|
|
||||||
"""Determines the config parameter name to use for the hypernet based on the parameters in the infotext.
|
|
||||||
|
|
||||||
Example: an infotext provides "Hypernet: ke-ta" and "Hypernet hash: 1234abcd". For the "Hypernet" config
|
|
||||||
parameter this means there should be an entry that looks like "ke-ta-10000(1234abcd)" to set it to.
|
|
||||||
|
|
||||||
If the infotext has no hash, then a hypernet with the same name will be selected instead.
|
|
||||||
"""
|
|
||||||
hypernet_name = hypernet_name.lower()
|
|
||||||
if hypernet_hash is not None:
|
|
||||||
# Try to match the hash in the name
|
|
||||||
for hypernet_key in shared.hypernetworks.keys():
|
|
||||||
result = re_hypernet_hash.search(hypernet_key)
|
|
||||||
if result is not None and result[1] == hypernet_hash:
|
|
||||||
return hypernet_key
|
|
||||||
else:
|
|
||||||
# Fall back to a hypernet with the same name
|
|
||||||
for hypernet_key in shared.hypernetworks.keys():
|
|
||||||
if hypernet_key.lower().startswith(hypernet_name):
|
|
||||||
return hypernet_key
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def restore_old_hires_fix_params(res):
|
def restore_old_hires_fix_params(res):
|
||||||
"""for infotexts that specify old First pass size parameter, convert it into
|
"""for infotexts that specify old First pass size parameter, convert it into
|
||||||
width, height, and hr scale"""
|
width, height, and hr scale"""
|
||||||
@ -332,10 +307,6 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
settings_map = {}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
infotext_to_setting_name_mapping = [
|
infotext_to_setting_name_mapping = [
|
||||||
('Clip skip', 'CLIP_stop_at_last_layers', ),
|
('Clip skip', 'CLIP_stop_at_last_layers', ),
|
||||||
('Conditional mask weight', 'inpainting_mask_weight'),
|
('Conditional mask weight', 'inpainting_mask_weight'),
|
||||||
@ -357,6 +328,7 @@ infotext_to_setting_name_mapping = [
|
|||||||
('Token merging ratio hr', 'token_merging_ratio_hr'),
|
('Token merging ratio hr', 'token_merging_ratio_hr'),
|
||||||
('RNG', 'randn_source'),
|
('RNG', 'randn_source'),
|
||||||
('NGMS', 's_min_uncond'),
|
('NGMS', 's_min_uncond'),
|
||||||
|
('Pad conds', 'pad_cond_uncond'),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -25,7 +25,7 @@ def gfpgann():
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
|
models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
|
||||||
if len(models) == 1 and "http" in models[0]:
|
if len(models) == 1 and models[0].startswith("http"):
|
||||||
model_file = models[0]
|
model_file = models[0]
|
||||||
elif len(models) != 0:
|
elif len(models) != 0:
|
||||||
latest_file = max(models, key=os.path.getctime)
|
latest_file = max(models, key=os.path.getctime)
|
||||||
@ -70,11 +70,8 @@ gfpgan_constructor = None
|
|||||||
|
|
||||||
|
|
||||||
def setup_model(dirname):
|
def setup_model(dirname):
|
||||||
global model_path
|
|
||||||
if not os.path.exists(model_path):
|
|
||||||
os.makedirs(model_path)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
os.makedirs(model_path, exist_ok=True)
|
||||||
from gfpgan import GFPGANer
|
from gfpgan import GFPGANer
|
||||||
from facexlib import detection, parsing # noqa: F401
|
from facexlib import detection, parsing # noqa: F401
|
||||||
global user_path
|
global user_path
|
||||||
|
@ -353,17 +353,6 @@ def load_hypernetworks(names, multipliers=None):
|
|||||||
shared.loaded_hypernetworks.append(hypernetwork)
|
shared.loaded_hypernetworks.append(hypernetwork)
|
||||||
|
|
||||||
|
|
||||||
def find_closest_hypernetwork_name(search: str):
|
|
||||||
if not search:
|
|
||||||
return None
|
|
||||||
search = search.lower()
|
|
||||||
applicable = [name for name in shared.hypernetworks if search in name.lower()]
|
|
||||||
if not applicable:
|
|
||||||
return None
|
|
||||||
applicable = sorted(applicable, key=lambda name: len(name))
|
|
||||||
return applicable[0]
|
|
||||||
|
|
||||||
|
|
||||||
def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
|
def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
|
||||||
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)
|
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)
|
||||||
|
|
||||||
@ -446,18 +435,6 @@ def statistics(data):
|
|||||||
return total_information, recent_information
|
return total_information, recent_information
|
||||||
|
|
||||||
|
|
||||||
def report_statistics(loss_info:dict):
|
|
||||||
keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x]))
|
|
||||||
for key in keys:
|
|
||||||
try:
|
|
||||||
print("Loss statistics for file " + key)
|
|
||||||
info, recent = statistics(list(loss_info[key]))
|
|
||||||
print(info)
|
|
||||||
print(recent)
|
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
|
|
||||||
|
|
||||||
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
|
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
|
||||||
# Remove illegal characters from name.
|
# Remove illegal characters from name.
|
||||||
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
|
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
|
||||||
@ -770,7 +747,6 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
|||||||
pbar.leave = False
|
pbar.leave = False
|
||||||
pbar.close()
|
pbar.close()
|
||||||
hypernetwork.eval()
|
hypernetwork.eval()
|
||||||
#report_statistics(loss_dict)
|
|
||||||
sd_hijack_checkpoint.remove()
|
sd_hijack_checkpoint.remove()
|
||||||
|
|
||||||
|
|
||||||
|
@ -372,8 +372,8 @@ class FilenameGenerator:
|
|||||||
'hasprompt': lambda self, *args: self.hasprompt(*args), # accepts formats:[hasprompt<prompt1|default><prompt2>..]
|
'hasprompt': lambda self, *args: self.hasprompt(*args), # accepts formats:[hasprompt<prompt1|default><prompt2>..]
|
||||||
'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
|
'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
|
||||||
'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
|
'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
|
||||||
|
'user': lambda self: self.p.user,
|
||||||
'vae_filename': lambda self: self.get_vae_filename(),
|
'vae_filename': lambda self: self.get_vae_filename(),
|
||||||
|
|
||||||
}
|
}
|
||||||
default_time_format = '%Y%m%d%H%M%S'
|
default_time_format = '%Y%m%d%H%M%S'
|
||||||
|
|
||||||
|
@ -3,6 +3,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
|
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
from modules import sd_samplers
|
from modules import sd_samplers
|
||||||
from modules.generation_parameters_copypaste import create_override_settings_dict
|
from modules.generation_parameters_copypaste import create_override_settings_dict
|
||||||
@ -97,7 +98,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
|
|||||||
processed_image.save(os.path.join(output_dir, filename))
|
processed_image.save(os.path.join(output_dir, filename))
|
||||||
|
|
||||||
|
|
||||||
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
|
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, request: gr.Request, *args):
|
||||||
override_settings = create_override_settings_dict(override_settings_texts)
|
override_settings = create_override_settings_dict(override_settings_texts)
|
||||||
|
|
||||||
is_batch = mode == 5
|
is_batch = mode == 5
|
||||||
@ -180,6 +181,8 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
|
|||||||
p.scripts = modules.scripts.scripts_img2img
|
p.scripts = modules.scripts.scripts_img2img
|
||||||
p.script_args = args
|
p.script_args = args
|
||||||
|
|
||||||
|
p.user = request.username
|
||||||
|
|
||||||
if shared.cmd_opts.enable_console_prompts:
|
if shared.cmd_opts.enable_console_prompts:
|
||||||
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
|
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
|
||||||
|
|
||||||
|
@ -147,10 +147,10 @@ def git_clone(url, dir, name, commithash=None):
|
|||||||
return
|
return
|
||||||
|
|
||||||
run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
|
run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
|
||||||
run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
|
run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
|
||||||
return
|
return
|
||||||
|
|
||||||
run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}")
|
run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}", live=True)
|
||||||
|
|
||||||
if commithash is not None:
|
if commithash is not None:
|
||||||
run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
|
run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import importlib
|
import importlib
|
||||||
@ -8,6 +10,29 @@ from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, Upscale
|
|||||||
from modules.paths import script_path, models_path
|
from modules.paths import script_path, models_path
|
||||||
|
|
||||||
|
|
||||||
|
def load_file_from_url(
|
||||||
|
url: str,
|
||||||
|
*,
|
||||||
|
model_dir: str,
|
||||||
|
progress: bool = True,
|
||||||
|
file_name: str | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Download a file from `url` into `model_dir`, using the file present if possible.
|
||||||
|
|
||||||
|
Returns the path to the downloaded file.
|
||||||
|
"""
|
||||||
|
os.makedirs(model_dir, exist_ok=True)
|
||||||
|
if not file_name:
|
||||||
|
parts = urlparse(url)
|
||||||
|
file_name = os.path.basename(parts.path)
|
||||||
|
cached_file = os.path.abspath(os.path.join(model_dir, file_name))
|
||||||
|
if not os.path.exists(cached_file):
|
||||||
|
print(f'Downloading: "{url}" to {cached_file}\n')
|
||||||
|
from torch.hub import download_url_to_file
|
||||||
|
download_url_to_file(url, cached_file, progress=progress)
|
||||||
|
return cached_file
|
||||||
|
|
||||||
|
|
||||||
def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list:
|
def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list:
|
||||||
"""
|
"""
|
||||||
A one-and done loader to try finding the desired models in specified directories.
|
A one-and done loader to try finding the desired models in specified directories.
|
||||||
@ -46,9 +71,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
|
|||||||
|
|
||||||
if model_url is not None and len(output) == 0:
|
if model_url is not None and len(output) == 0:
|
||||||
if download_name is not None:
|
if download_name is not None:
|
||||||
from basicsr.utils.download_util import load_file_from_url
|
output.append(load_file_from_url(model_url, model_dir=places[0], file_name=download_name))
|
||||||
dl = load_file_from_url(model_url, places[0], True, download_name)
|
|
||||||
output.append(dl)
|
|
||||||
else:
|
else:
|
||||||
output.append(model_url)
|
output.append(model_url)
|
||||||
|
|
||||||
@ -59,7 +82,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
|
|||||||
|
|
||||||
|
|
||||||
def friendly_name(file: str):
|
def friendly_name(file: str):
|
||||||
if "http" in file:
|
if file.startswith("http"):
|
||||||
file = urlparse(file).path
|
file = urlparse(file).path
|
||||||
|
|
||||||
file = os.path.basename(file)
|
file = os.path.basename(file)
|
||||||
@ -95,8 +118,7 @@ def cleanup_models():
|
|||||||
|
|
||||||
def move_files(src_path: str, dest_path: str, ext_filter: str = None):
|
def move_files(src_path: str, dest_path: str, ext_filter: str = None):
|
||||||
try:
|
try:
|
||||||
if not os.path.exists(dest_path):
|
os.makedirs(dest_path, exist_ok=True)
|
||||||
os.makedirs(dest_path)
|
|
||||||
if os.path.exists(src_path):
|
if os.path.exists(src_path):
|
||||||
for file in os.listdir(src_path):
|
for file in os.listdir(src_path):
|
||||||
fullpath = os.path.join(src_path, file)
|
fullpath = os.path.join(src_path, file)
|
||||||
|
@ -38,17 +38,3 @@ for d, must_exist, what, options in path_dirs:
|
|||||||
else:
|
else:
|
||||||
sys.path.append(d)
|
sys.path.append(d)
|
||||||
paths[what] = d
|
paths[what] = d
|
||||||
|
|
||||||
|
|
||||||
class Prioritize:
|
|
||||||
def __init__(self, name):
|
|
||||||
self.name = name
|
|
||||||
self.path = None
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
self.path = sys.path.copy()
|
|
||||||
sys.path = [paths[self.name]] + sys.path
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
||||||
sys.path = self.path
|
|
||||||
self.path = None
|
|
||||||
|
@ -184,6 +184,8 @@ class StableDiffusionProcessing:
|
|||||||
self.uc = None
|
self.uc = None
|
||||||
self.c = None
|
self.c = None
|
||||||
|
|
||||||
|
self.user = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def sd_model(self):
|
def sd_model(self):
|
||||||
return shared.sd_model
|
return shared.sd_model
|
||||||
@ -549,7 +551,7 @@ def program_version():
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0):
|
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False):
|
||||||
index = position_in_batch + iteration * p.batch_size
|
index = position_in_batch + iteration * p.batch_size
|
||||||
|
|
||||||
clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
|
clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
|
||||||
@ -585,13 +587,15 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
|||||||
"NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
|
"NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
|
||||||
**p.extra_generation_params,
|
**p.extra_generation_params,
|
||||||
"Version": program_version() if opts.add_version_to_infotext else None,
|
"Version": program_version() if opts.add_version_to_infotext else None,
|
||||||
|
"User": p.user if opts.add_user_name_to_info else None,
|
||||||
}
|
}
|
||||||
|
|
||||||
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
|
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
|
||||||
|
|
||||||
|
prompt_text = p.prompt if use_main_prompt else all_prompts[index]
|
||||||
negative_prompt_text = f"\nNegative prompt: {p.all_negative_prompts[index]}" if p.all_negative_prompts[index] else ""
|
negative_prompt_text = f"\nNegative prompt: {p.all_negative_prompts[index]}" if p.all_negative_prompts[index] else ""
|
||||||
|
|
||||||
return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
|
return f"{prompt_text}{negative_prompt_text}\n{generation_params_text}".strip()
|
||||||
|
|
||||||
|
|
||||||
def process_images(p: StableDiffusionProcessing) -> Processed:
|
def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||||
@ -663,8 +667,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
else:
|
else:
|
||||||
p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
|
p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
|
||||||
|
|
||||||
def infotext(iteration=0, position_in_batch=0):
|
def infotext(iteration=0, position_in_batch=0, use_main_prompt=False):
|
||||||
return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch)
|
return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch, use_main_prompt)
|
||||||
|
|
||||||
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
|
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
|
||||||
model_hijack.embedding_db.load_textual_inversion_embeddings()
|
model_hijack.embedding_db.load_textual_inversion_embeddings()
|
||||||
@ -824,7 +828,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
grid = images.image_grid(output_images, p.batch_size)
|
grid = images.image_grid(output_images, p.batch_size)
|
||||||
|
|
||||||
if opts.return_grid:
|
if opts.return_grid:
|
||||||
text = infotext()
|
text = infotext(use_main_prompt=True)
|
||||||
infotexts.insert(0, text)
|
infotexts.insert(0, text)
|
||||||
if opts.enable_pnginfo:
|
if opts.enable_pnginfo:
|
||||||
grid.info["parameters"] = text
|
grid.info["parameters"] = text
|
||||||
@ -832,7 +836,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
index_of_first_image = 1
|
index_of_first_image = 1
|
||||||
|
|
||||||
if opts.grid_save:
|
if opts.grid_save:
|
||||||
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
|
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True)
|
||||||
|
|
||||||
if not p.disable_extra_networks and p.extra_network_data:
|
if not p.disable_extra_networks and p.extra_network_data:
|
||||||
extra_networks.deactivate(p, p.extra_network_data)
|
extra_networks.deactivate(p, p.extra_network_data)
|
||||||
|
@ -2,7 +2,6 @@ import os
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from basicsr.utils.download_util import load_file_from_url
|
|
||||||
from realesrgan import RealESRGANer
|
from realesrgan import RealESRGANer
|
||||||
|
|
||||||
from modules.upscaler import Upscaler, UpscalerData
|
from modules.upscaler import Upscaler, UpscalerData
|
||||||
@ -43,9 +42,10 @@ class UpscalerRealESRGAN(Upscaler):
|
|||||||
if not self.enable:
|
if not self.enable:
|
||||||
return img
|
return img
|
||||||
|
|
||||||
info = self.load_model(path)
|
try:
|
||||||
if not os.path.exists(info.local_data_path):
|
info = self.load_model(path)
|
||||||
print(f"Unable to load RealESRGAN model: {info.name}")
|
except Exception:
|
||||||
|
errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True)
|
||||||
return img
|
return img
|
||||||
|
|
||||||
upsampler = RealESRGANer(
|
upsampler = RealESRGANer(
|
||||||
@ -63,20 +63,17 @@ class UpscalerRealESRGAN(Upscaler):
|
|||||||
return image
|
return image
|
||||||
|
|
||||||
def load_model(self, path):
|
def load_model(self, path):
|
||||||
try:
|
for scaler in self.scalers:
|
||||||
info = next(iter([scaler for scaler in self.scalers if scaler.data_path == path]), None)
|
if scaler.data_path == path:
|
||||||
|
if scaler.local_data_path.startswith("http"):
|
||||||
if info is None:
|
scaler.local_data_path = modelloader.load_file_from_url(
|
||||||
print(f"Unable to find model info: {path}")
|
scaler.data_path,
|
||||||
return None
|
model_dir=self.model_download_path,
|
||||||
|
)
|
||||||
if info.local_data_path.startswith("http"):
|
if not os.path.exists(scaler.local_data_path):
|
||||||
info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_download_path, progress=True)
|
raise FileNotFoundError(f"RealESRGAN data missing: {scaler.local_data_path}")
|
||||||
|
return scaler
|
||||||
return info
|
raise ValueError(f"Unable to find model info: {path}")
|
||||||
except Exception:
|
|
||||||
errors.report("Error making Real-ESRGAN models list", exc_info=True)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def load_models(self, _):
|
def load_models(self, _):
|
||||||
return get_realesrgan_models(self)
|
return get_realesrgan_models(self)
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
|
import inspect
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
@ -249,7 +250,7 @@ def load_scripts():
|
|||||||
|
|
||||||
def register_scripts_from_module(module):
|
def register_scripts_from_module(module):
|
||||||
for script_class in module.__dict__.values():
|
for script_class in module.__dict__.values():
|
||||||
if type(script_class) != type:
|
if not inspect.isclass(script_class):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if issubclass(script_class, Script):
|
if issubclass(script_class, Script):
|
||||||
|
@ -95,8 +95,7 @@ except Exception:
|
|||||||
|
|
||||||
|
|
||||||
def setup_model():
|
def setup_model():
|
||||||
if not os.path.exists(model_path):
|
os.makedirs(model_path, exist_ok=True)
|
||||||
os.makedirs(model_path)
|
|
||||||
|
|
||||||
enable_midas_autodownload()
|
enable_midas_autodownload()
|
||||||
|
|
||||||
@ -248,7 +247,12 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None
|
|||||||
_, extension = os.path.splitext(checkpoint_file)
|
_, extension = os.path.splitext(checkpoint_file)
|
||||||
if extension.lower() == ".safetensors":
|
if extension.lower() == ".safetensors":
|
||||||
device = map_location or shared.weight_load_location or devices.get_optimal_device_name()
|
device = map_location or shared.weight_load_location or devices.get_optimal_device_name()
|
||||||
pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
|
|
||||||
|
if not shared.opts.disable_mmap_load_safetensors:
|
||||||
|
pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
|
||||||
|
else:
|
||||||
|
pl_sd = safetensors.torch.load(open(checkpoint_file, 'rb').read())
|
||||||
|
pl_sd = {k: v.to(device) for k, v in pl_sd.items()}
|
||||||
else:
|
else:
|
||||||
pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
|
pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
|
||||||
|
|
||||||
|
@ -69,6 +69,7 @@ class CFGDenoiser(torch.nn.Module):
|
|||||||
self.init_latent = None
|
self.init_latent = None
|
||||||
self.step = 0
|
self.step = 0
|
||||||
self.image_cfg_scale = None
|
self.image_cfg_scale = None
|
||||||
|
self.padded_cond_uncond = False
|
||||||
|
|
||||||
def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
|
def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
|
||||||
denoised_uncond = x_out[-uncond.shape[0]:]
|
denoised_uncond = x_out[-uncond.shape[0]:]
|
||||||
@ -133,15 +134,17 @@ class CFGDenoiser(torch.nn.Module):
|
|||||||
x_in = x_in[:-batch_size]
|
x_in = x_in[:-batch_size]
|
||||||
sigma_in = sigma_in[:-batch_size]
|
sigma_in = sigma_in[:-batch_size]
|
||||||
|
|
||||||
# TODO add infotext entry
|
self.padded_cond_uncond = False
|
||||||
if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
|
if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
|
||||||
empty = shared.sd_model.cond_stage_model_empty_prompt
|
empty = shared.sd_model.cond_stage_model_empty_prompt
|
||||||
num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
|
num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
|
||||||
|
|
||||||
if num_repeats < 0:
|
if num_repeats < 0:
|
||||||
tensor = torch.cat([tensor, empty.repeat((tensor.shape[0], -num_repeats, 1))], axis=1)
|
tensor = torch.cat([tensor, empty.repeat((tensor.shape[0], -num_repeats, 1))], axis=1)
|
||||||
|
self.padded_cond_uncond = True
|
||||||
elif num_repeats > 0:
|
elif num_repeats > 0:
|
||||||
uncond = torch.cat([uncond, empty.repeat((uncond.shape[0], num_repeats, 1))], axis=1)
|
uncond = torch.cat([uncond, empty.repeat((uncond.shape[0], num_repeats, 1))], axis=1)
|
||||||
|
self.padded_cond_uncond = True
|
||||||
|
|
||||||
if tensor.shape[1] == uncond.shape[1] or skip_uncond:
|
if tensor.shape[1] == uncond.shape[1] or skip_uncond:
|
||||||
if is_edit_model:
|
if is_edit_model:
|
||||||
@ -405,6 +408,9 @@ class KDiffusionSampler:
|
|||||||
|
|
||||||
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||||
|
|
||||||
|
if self.model_wrap_cfg.padded_cond_uncond:
|
||||||
|
p.extra_generation_params["Pad conds"] = True
|
||||||
|
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||||
@ -438,5 +444,8 @@ class KDiffusionSampler:
|
|||||||
's_min_uncond': self.s_min_uncond
|
's_min_uncond': self.s_min_uncond
|
||||||
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||||
|
|
||||||
|
if self.model_wrap_cfg.padded_cond_uncond:
|
||||||
|
p.extra_generation_params["Pad conds"] = True
|
||||||
|
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
@ -376,6 +376,7 @@ options_templates.update(options_section(('system', "System"), {
|
|||||||
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
|
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
|
||||||
"print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
|
"print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
|
||||||
"list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
|
"list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
|
||||||
|
"disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('training', "Training"), {
|
options_templates.update(options_section(('training', "Training"), {
|
||||||
@ -409,7 +410,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
|||||||
"enable_emphasis": OptionInfo(True, "Enable emphasis").info("use (text) to make model pay more attention to text and [text] to make it pay less attention"),
|
"enable_emphasis": OptionInfo(True, "Enable emphasis").info("use (text) to make model pay more attention to text and [text] to make it pay less attention"),
|
||||||
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
|
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
|
||||||
"comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"),
|
"comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"),
|
||||||
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP nrtwork; 1 ignores none, 2 ignores one layer"),
|
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"),
|
||||||
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
|
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
|
||||||
"randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"),
|
"randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"),
|
||||||
}))
|
}))
|
||||||
@ -493,6 +494,7 @@ options_templates.update(options_section(('ui', "User interface"), {
|
|||||||
options_templates.update(options_section(('infotext', "Infotext"), {
|
options_templates.update(options_section(('infotext', "Infotext"), {
|
||||||
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
|
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
|
||||||
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
|
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
|
||||||
|
"add_user_name_to_info": OptionInfo(False, "Add user name to generation information when authenticated"),
|
||||||
"add_version_to_infotext": OptionInfo(True, "Add program version to generation information"),
|
"add_version_to_infotext": OptionInfo(True, "Add program version to generation information"),
|
||||||
"disable_weights_auto_swap": OptionInfo(True, "Disregard checkpoint information from pasted infotext").info("when reading generation parameters from text into UI"),
|
"disable_weights_auto_swap": OptionInfo(True, "Disregard checkpoint information from pasted infotext").info("when reading generation parameters from text into UI"),
|
||||||
"infotext_styles": OptionInfo("Apply if any", "Infer styles from prompts of pasted infotext", gr.Radio, {"choices": ["Ignore", "Apply", "Discard", "Apply if any"]}).info("when reading generation parameters from text into UI)").html("""<ul style='margin-left: 1.5em'>
|
"infotext_styles": OptionInfo("Apply if any", "Infer styles from prompts of pasted infotext", gr.Radio, {"choices": ["Ignore", "Apply", "Discard", "Apply if any"]}).info("when reading generation parameters from text into UI)").html("""<ul style='margin-left: 1.5em'>
|
||||||
|
@ -298,8 +298,7 @@ def download_and_cache_models(dirname):
|
|||||||
download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'
|
download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'
|
||||||
model_file_name = 'face_detection_yunet.onnx'
|
model_file_name = 'face_detection_yunet.onnx'
|
||||||
|
|
||||||
if not os.path.exists(dirname):
|
os.makedirs(dirname, exist_ok=True)
|
||||||
os.makedirs(dirname)
|
|
||||||
|
|
||||||
cache_file = os.path.join(dirname, model_file_name)
|
cache_file = os.path.join(dirname, model_file_name)
|
||||||
if not os.path.exists(cache_file):
|
if not os.path.exists(cache_file):
|
||||||
|
@ -2,11 +2,51 @@ import datetime
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
|
||||||
saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "clip_grad_mode", "clip_grad_value", "gradient_step", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file", "gradient_step", "latent_sampling_method"}
|
saved_params_shared = {
|
||||||
saved_params_ti = {"embedding_name", "num_vectors_per_token", "save_embedding_every", "save_image_with_stored_embedding"}
|
"batch_size",
|
||||||
saved_params_hypernet = {"hypernetwork_name", "layer_structure", "activation_func", "weight_init", "add_layer_norm", "use_dropout", "save_hypernetwork_every"}
|
"clip_grad_mode",
|
||||||
|
"clip_grad_value",
|
||||||
|
"create_image_every",
|
||||||
|
"data_root",
|
||||||
|
"gradient_step",
|
||||||
|
"initial_step",
|
||||||
|
"latent_sampling_method",
|
||||||
|
"learn_rate",
|
||||||
|
"log_directory",
|
||||||
|
"model_hash",
|
||||||
|
"model_name",
|
||||||
|
"num_of_dataset_images",
|
||||||
|
"steps",
|
||||||
|
"template_file",
|
||||||
|
"training_height",
|
||||||
|
"training_width",
|
||||||
|
}
|
||||||
|
saved_params_ti = {
|
||||||
|
"embedding_name",
|
||||||
|
"num_vectors_per_token",
|
||||||
|
"save_embedding_every",
|
||||||
|
"save_image_with_stored_embedding",
|
||||||
|
}
|
||||||
|
saved_params_hypernet = {
|
||||||
|
"activation_func",
|
||||||
|
"add_layer_norm",
|
||||||
|
"hypernetwork_name",
|
||||||
|
"layer_structure",
|
||||||
|
"save_hypernetwork_every",
|
||||||
|
"use_dropout",
|
||||||
|
"weight_init",
|
||||||
|
}
|
||||||
saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet
|
saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet
|
||||||
saved_params_previews = {"preview_prompt", "preview_negative_prompt", "preview_steps", "preview_sampler_index", "preview_cfg_scale", "preview_seed", "preview_width", "preview_height"}
|
saved_params_previews = {
|
||||||
|
"preview_cfg_scale",
|
||||||
|
"preview_height",
|
||||||
|
"preview_negative_prompt",
|
||||||
|
"preview_prompt",
|
||||||
|
"preview_sampler_index",
|
||||||
|
"preview_seed",
|
||||||
|
"preview_steps",
|
||||||
|
"preview_width",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def save_settings_to_file(log_directory, all_params):
|
def save_settings_to_file(log_directory, all_params):
|
||||||
|
@ -4,10 +4,10 @@ from modules.generation_parameters_copypaste import create_override_settings_dic
|
|||||||
from modules.shared import opts, cmd_opts
|
from modules.shared import opts, cmd_opts
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules.ui import plaintext_to_html
|
from modules.ui import plaintext_to_html
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
|
||||||
|
def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
|
||||||
def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args):
|
|
||||||
override_settings = create_override_settings_dict(override_settings_texts)
|
override_settings = create_override_settings_dict(override_settings_texts)
|
||||||
|
|
||||||
p = processing.StableDiffusionProcessingTxt2Img(
|
p = processing.StableDiffusionProcessingTxt2Img(
|
||||||
@ -48,6 +48,8 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
|
|||||||
p.scripts = modules.scripts.scripts_txt2img
|
p.scripts = modules.scripts.scripts_txt2img
|
||||||
p.script_args = args
|
p.script_args = args
|
||||||
|
|
||||||
|
p.user = request.username
|
||||||
|
|
||||||
if cmd_opts.enable_console_prompts:
|
if cmd_opts.enable_console_prompts:
|
||||||
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
|
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
|
||||||
|
|
||||||
|
@ -773,7 +773,7 @@ def create_ui():
|
|||||||
selected_scale_tab = gr.State(value=0)
|
selected_scale_tab = gr.State(value=0)
|
||||||
|
|
||||||
with gr.Tabs():
|
with gr.Tabs():
|
||||||
with gr.Tab(label="Resize to") as tab_scale_to:
|
with gr.Tab(label="Resize to", elem_id="img2img_tab_resize_to") as tab_scale_to:
|
||||||
with FormRow():
|
with FormRow():
|
||||||
with gr.Column(elem_id="img2img_column_size", scale=4):
|
with gr.Column(elem_id="img2img_column_size", scale=4):
|
||||||
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
|
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
|
||||||
@ -782,7 +782,7 @@ def create_ui():
|
|||||||
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
|
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
|
||||||
detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn")
|
detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn")
|
||||||
|
|
||||||
with gr.Tab(label="Resize by") as tab_scale_by:
|
with gr.Tab(label="Resize by", elem_id="img2img_tab_resize_by") as tab_scale_by:
|
||||||
scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale")
|
scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale")
|
||||||
|
|
||||||
with FormRow():
|
with FormRow():
|
||||||
|
@ -138,7 +138,10 @@ def extension_table():
|
|||||||
<table id="extensions">
|
<table id="extensions">
|
||||||
<thead>
|
<thead>
|
||||||
<tr>
|
<tr>
|
||||||
<th><abbr title="Use checkbox to enable the extension; it will be enabled or disabled when you click apply button">Extension</abbr></th>
|
<th>
|
||||||
|
<input class="gr-check-radio gr-checkbox all_extensions_toggle" type="checkbox" {'checked="checked"' if all(ext.enabled for ext in extensions.extensions) else ''} onchange="toggle_all_extensions(event)" />
|
||||||
|
<abbr title="Use checkbox to enable the extension; it will be enabled or disabled when you click apply button">Extension</abbr>
|
||||||
|
</th>
|
||||||
<th>URL</th>
|
<th>URL</th>
|
||||||
<th>Branch</th>
|
<th>Branch</th>
|
||||||
<th>Version</th>
|
<th>Version</th>
|
||||||
@ -170,7 +173,7 @@ def extension_table():
|
|||||||
|
|
||||||
code += f"""
|
code += f"""
|
||||||
<tr>
|
<tr>
|
||||||
<td><label{style}><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td>
|
<td><label{style}><input class="gr-check-radio gr-checkbox extension_toggle" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''} onchange="toggle_extension(event)" />{html.escape(ext.name)}</label></td>
|
||||||
<td>{remote}</td>
|
<td>{remote}</td>
|
||||||
<td>{ext.branch}</td>
|
<td>{ext.branch}</td>
|
||||||
<td>{version_link}</td>
|
<td>{version_link}</td>
|
||||||
@ -325,6 +328,11 @@ def normalize_git_url(url):
|
|||||||
def install_extension_from_url(dirname, url, branch_name=None):
|
def install_extension_from_url(dirname, url, branch_name=None):
|
||||||
check_access()
|
check_access()
|
||||||
|
|
||||||
|
if isinstance(dirname, str):
|
||||||
|
dirname = dirname.strip()
|
||||||
|
if isinstance(url, str):
|
||||||
|
url = url.strip()
|
||||||
|
|
||||||
assert url, 'No URL specified'
|
assert url, 'No URL specified'
|
||||||
|
|
||||||
if dirname is None or dirname == "":
|
if dirname is None or dirname == "":
|
||||||
@ -563,9 +571,9 @@ def create_ui():
|
|||||||
available_extensions_table = gr.HTML()
|
available_extensions_table = gr.HTML()
|
||||||
|
|
||||||
refresh_available_extensions_button.click(
|
refresh_available_extensions_button.click(
|
||||||
fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update()]),
|
fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update(), gr.update()]),
|
||||||
inputs=[available_extensions_index, hide_tags, sort_column],
|
inputs=[available_extensions_index, hide_tags, sort_column],
|
||||||
outputs=[available_extensions_index, available_extensions_table, hide_tags, install_result, search_extensions_text],
|
outputs=[available_extensions_index, available_extensions_table, hide_tags, search_extensions_text, install_result],
|
||||||
)
|
)
|
||||||
|
|
||||||
install_extension_button.click(
|
install_extension_button.click(
|
||||||
|
9
webui.py
9
webui.py
@ -11,7 +11,7 @@ import json
|
|||||||
from threading import Thread
|
from threading import Thread
|
||||||
from typing import Iterable
|
from typing import Iterable
|
||||||
|
|
||||||
from fastapi import FastAPI, Response
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.middleware.gzip import GZipMiddleware
|
from fastapi.middleware.gzip import GZipMiddleware
|
||||||
from packaging import version
|
from packaging import version
|
||||||
@ -362,11 +362,6 @@ def api_only():
|
|||||||
api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
|
api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
|
||||||
|
|
||||||
|
|
||||||
def stop_route(request):
|
|
||||||
shared.state.server_command = "stop"
|
|
||||||
return Response("Stopping.")
|
|
||||||
|
|
||||||
|
|
||||||
def webui():
|
def webui():
|
||||||
launch_api = cmd_opts.api
|
launch_api = cmd_opts.api
|
||||||
initialize()
|
initialize()
|
||||||
@ -404,8 +399,6 @@ def webui():
|
|||||||
"redoc_url": "/redoc",
|
"redoc_url": "/redoc",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
if cmd_opts.add_stop_route:
|
|
||||||
app.add_route("/_stop", stop_route, methods=["POST"])
|
|
||||||
|
|
||||||
# after initial launch, disable --autolaunch for subsequent restarts
|
# after initial launch, disable --autolaunch for subsequent restarts
|
||||||
cmd_opts.autolaunch = False
|
cmd_opts.autolaunch = False
|
||||||
|
Loading…
Reference in New Issue
Block a user