Merge branch 'dev' into img2img-save
This commit is contained in:
commit
e0218c4f22
2
.github/workflows/on_pull_request.yaml
vendored
2
.github/workflows/on_pull_request.yaml
vendored
@ -18,7 +18,7 @@ jobs:
|
|||||||
# not to have GHA download an (at the time of writing) 4 GB cache
|
# not to have GHA download an (at the time of writing) 4 GB cache
|
||||||
# of PyTorch and other dependencies.
|
# of PyTorch and other dependencies.
|
||||||
- name: Install Ruff
|
- name: Install Ruff
|
||||||
run: pip install ruff==0.0.265
|
run: pip install ruff==0.0.272
|
||||||
- name: Run Ruff
|
- name: Run Ruff
|
||||||
run: ruff .
|
run: ruff .
|
||||||
lint-js:
|
lint-js:
|
||||||
|
4
.github/workflows/run_tests.yaml
vendored
4
.github/workflows/run_tests.yaml
vendored
@ -42,7 +42,7 @@ jobs:
|
|||||||
--no-half
|
--no-half
|
||||||
--disable-opt-split-attention
|
--disable-opt-split-attention
|
||||||
--use-cpu all
|
--use-cpu all
|
||||||
--add-stop-route
|
--api-server-stop
|
||||||
2>&1 | tee output.txt &
|
2>&1 | tee output.txt &
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
@ -50,7 +50,7 @@ jobs:
|
|||||||
python -m pytest -vv --junitxml=test/results.xml --cov . --cov-report=xml --verify-base-url test
|
python -m pytest -vv --junitxml=test/results.xml --cov . --cov-report=xml --verify-base-url test
|
||||||
- name: Kill test server
|
- name: Kill test server
|
||||||
if: always()
|
if: always()
|
||||||
run: curl -vv -XPOST http://127.0.0.1:7860/_stop && sleep 10
|
run: curl -vv -XPOST http://127.0.0.1:7860/sdapi/v1/server-stop && sleep 10
|
||||||
- name: Show coverage
|
- name: Show coverage
|
||||||
run: |
|
run: |
|
||||||
python -m coverage combine .coverage*
|
python -m coverage combine .coverage*
|
||||||
|
@ -135,8 +135,11 @@ Find the instructions [here](https://github.com/AUTOMATIC1111/stable-diffusion-w
|
|||||||
Here's how to add code to this repo: [Contributing](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing)
|
Here's how to add code to this repo: [Contributing](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing)
|
||||||
|
|
||||||
## Documentation
|
## Documentation
|
||||||
|
|
||||||
The documentation was moved from this README over to the project's [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki).
|
The documentation was moved from this README over to the project's [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki).
|
||||||
|
|
||||||
|
For the purposes of getting Google and other search engines to crawl the wiki, here's a link to the (not for humans) [crawlable wiki](https://github-wiki-see.page/m/AUTOMATIC1111/stable-diffusion-webui/wiki).
|
||||||
|
|
||||||
## Credits
|
## Credits
|
||||||
Licenses for borrowed code can be found in `Settings -> Licenses` screen, and also in `html/licenses.html` file.
|
Licenses for borrowed code can be found in `Settings -> Licenses` screen, and also in `html/licenses.html` file.
|
||||||
|
|
||||||
|
@ -12,7 +12,7 @@ import safetensors.torch
|
|||||||
|
|
||||||
from ldm.models.diffusion.ddim import DDIMSampler
|
from ldm.models.diffusion.ddim import DDIMSampler
|
||||||
from ldm.util import instantiate_from_config, ismap
|
from ldm.util import instantiate_from_config, ismap
|
||||||
from modules import shared, sd_hijack
|
from modules import shared, sd_hijack, devices
|
||||||
|
|
||||||
cached_ldsr_model: torch.nn.Module = None
|
cached_ldsr_model: torch.nn.Module = None
|
||||||
|
|
||||||
@ -112,8 +112,7 @@ class LDSR:
|
|||||||
|
|
||||||
|
|
||||||
gc.collect()
|
gc.collect()
|
||||||
if torch.cuda.is_available:
|
devices.torch_gc()
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
im_og = image
|
im_og = image
|
||||||
width_og, height_og = im_og.size
|
width_og, height_og = im_og.size
|
||||||
@ -150,8 +149,7 @@ class LDSR:
|
|||||||
|
|
||||||
del model
|
del model
|
||||||
gc.collect()
|
gc.collect()
|
||||||
if torch.cuda.is_available:
|
devices.torch_gc()
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
return a
|
return a
|
||||||
|
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
from basicsr.utils.download_util import load_file_from_url
|
from modules.modelloader import load_file_from_url
|
||||||
|
|
||||||
from modules.upscaler import Upscaler, UpscalerData
|
from modules.upscaler import Upscaler, UpscalerData
|
||||||
from ldsr_model_arch import LDSR
|
from ldsr_model_arch import LDSR
|
||||||
from modules import shared, script_callbacks, errors
|
from modules import shared, script_callbacks, errors
|
||||||
@ -43,20 +42,17 @@ class UpscalerLDSR(Upscaler):
|
|||||||
if local_safetensors_path is not None and os.path.exists(local_safetensors_path):
|
if local_safetensors_path is not None and os.path.exists(local_safetensors_path):
|
||||||
model = local_safetensors_path
|
model = local_safetensors_path
|
||||||
else:
|
else:
|
||||||
model = local_ckpt_path if local_ckpt_path is not None else load_file_from_url(url=self.model_url, model_dir=self.model_download_path, file_name="model.ckpt", progress=True)
|
model = local_ckpt_path or load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name="model.ckpt")
|
||||||
|
|
||||||
yaml = local_yaml_path if local_yaml_path is not None else load_file_from_url(url=self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml", progress=True)
|
yaml = local_yaml_path or load_file_from_url(self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml")
|
||||||
|
|
||||||
try:
|
return LDSR(model, yaml)
|
||||||
return LDSR(model, yaml)
|
|
||||||
except Exception:
|
|
||||||
errors.report("Error importing LDSR", exc_info=True)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def do_upscale(self, img, path):
|
def do_upscale(self, img, path):
|
||||||
ldsr = self.load_model(path)
|
try:
|
||||||
if ldsr is None:
|
ldsr = self.load_model(path)
|
||||||
print("NO LDSR!")
|
except Exception:
|
||||||
|
errors.report(f"Failed loading LDSR model {path}", exc_info=True)
|
||||||
return img
|
return img
|
||||||
ddim_steps = shared.opts.ldsr_steps
|
ddim_steps = shared.opts.ldsr_steps
|
||||||
return ldsr.super_resolution(img, ddim_steps, self.scale)
|
return ldsr.super_resolution(img, ddim_steps, self.scale)
|
||||||
|
@ -443,7 +443,7 @@ def list_available_loras():
|
|||||||
os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
|
os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
|
||||||
|
|
||||||
candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
|
candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
|
||||||
for filename in sorted(candidates, key=str.lower):
|
for filename in candidates:
|
||||||
if os.path.isdir(filename):
|
if os.path.isdir(filename):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
import os.path
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import PIL.Image
|
import PIL.Image
|
||||||
@ -6,12 +5,11 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from basicsr.utils.download_util import load_file_from_url
|
|
||||||
|
|
||||||
import modules.upscaler
|
import modules.upscaler
|
||||||
from modules import devices, modelloader, script_callbacks, errors
|
from modules import devices, modelloader, script_callbacks, errors
|
||||||
from scunet_model_arch import SCUNet as net
|
from scunet_model_arch import SCUNet
|
||||||
|
|
||||||
|
from modules.modelloader import load_file_from_url
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
|
|
||||||
|
|
||||||
@ -28,7 +26,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
|
|||||||
scalers = []
|
scalers = []
|
||||||
add_model2 = True
|
add_model2 = True
|
||||||
for file in model_paths:
|
for file in model_paths:
|
||||||
if "http" in file:
|
if file.startswith("http"):
|
||||||
name = self.model_name
|
name = self.model_name
|
||||||
else:
|
else:
|
||||||
name = modelloader.friendly_name(file)
|
name = modelloader.friendly_name(file)
|
||||||
@ -87,11 +85,12 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
|
|||||||
|
|
||||||
def do_upscale(self, img: PIL.Image.Image, selected_file):
|
def do_upscale(self, img: PIL.Image.Image, selected_file):
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
devices.torch_gc()
|
||||||
|
|
||||||
model = self.load_model(selected_file)
|
try:
|
||||||
if model is None:
|
model = self.load_model(selected_file)
|
||||||
print(f"ScuNET: Unable to load model from {selected_file}", file=sys.stderr)
|
except Exception as e:
|
||||||
|
print(f"ScuNET: Unable to load model from {selected_file}: {e}", file=sys.stderr)
|
||||||
return img
|
return img
|
||||||
|
|
||||||
device = devices.get_device_for('scunet')
|
device = devices.get_device_for('scunet')
|
||||||
@ -111,7 +110,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
|
|||||||
torch_output = torch_output[:, :h * 1, :w * 1] # remove padding, if any
|
torch_output = torch_output[:, :h * 1, :w * 1] # remove padding, if any
|
||||||
np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy()
|
np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy()
|
||||||
del torch_img, torch_output
|
del torch_img, torch_output
|
||||||
torch.cuda.empty_cache()
|
devices.torch_gc()
|
||||||
|
|
||||||
output = np_output.transpose((1, 2, 0)) # CHW to HWC
|
output = np_output.transpose((1, 2, 0)) # CHW to HWC
|
||||||
output = output[:, :, ::-1] # BGR to RGB
|
output = output[:, :, ::-1] # BGR to RGB
|
||||||
@ -119,15 +118,12 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
|
|||||||
|
|
||||||
def load_model(self, path: str):
|
def load_model(self, path: str):
|
||||||
device = devices.get_device_for('scunet')
|
device = devices.get_device_for('scunet')
|
||||||
if "http" in path:
|
if path.startswith("http"):
|
||||||
filename = load_file_from_url(url=self.model_url, model_dir=self.model_download_path, file_name="%s.pth" % self.name, progress=True)
|
# TODO: this doesn't use `path` at all?
|
||||||
|
filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
|
||||||
else:
|
else:
|
||||||
filename = path
|
filename = path
|
||||||
if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None:
|
model = SCUNet(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
|
||||||
print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr)
|
|
||||||
return None
|
|
||||||
|
|
||||||
model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
|
|
||||||
model.load_state_dict(torch.load(filename), strict=True)
|
model.load_state_dict(torch.load(filename), strict=True)
|
||||||
model.eval()
|
model.eval()
|
||||||
for _, v in model.named_parameters():
|
for _, v in model.named_parameters():
|
||||||
|
@ -1,34 +1,35 @@
|
|||||||
import os
|
import sys
|
||||||
|
import platform
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from basicsr.utils.download_util import load_file_from_url
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from modules import modelloader, devices, script_callbacks, shared
|
from modules import modelloader, devices, script_callbacks, shared
|
||||||
from modules.shared import opts, state
|
from modules.shared import opts, state
|
||||||
from swinir_model_arch import SwinIR as net
|
from swinir_model_arch import SwinIR
|
||||||
from swinir_model_arch_v2 import Swin2SR as net2
|
from swinir_model_arch_v2 import Swin2SR
|
||||||
from modules.upscaler import Upscaler, UpscalerData
|
from modules.upscaler import Upscaler, UpscalerData
|
||||||
|
|
||||||
|
SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"
|
||||||
|
|
||||||
device_swinir = devices.get_device_for('swinir')
|
device_swinir = devices.get_device_for('swinir')
|
||||||
|
|
||||||
|
|
||||||
class UpscalerSwinIR(Upscaler):
|
class UpscalerSwinIR(Upscaler):
|
||||||
def __init__(self, dirname):
|
def __init__(self, dirname):
|
||||||
|
self._cached_model = None # keep the model when SWIN_torch_compile is on to prevent re-compile every runs
|
||||||
|
self._cached_model_config = None # to clear '_cached_model' when changing model (v1/v2) or settings
|
||||||
self.name = "SwinIR"
|
self.name = "SwinIR"
|
||||||
self.model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0" \
|
self.model_url = SWINIR_MODEL_URL
|
||||||
"/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \
|
|
||||||
"-L_x4_GAN.pth "
|
|
||||||
self.model_name = "SwinIR 4x"
|
self.model_name = "SwinIR 4x"
|
||||||
self.user_path = dirname
|
self.user_path = dirname
|
||||||
super().__init__()
|
super().__init__()
|
||||||
scalers = []
|
scalers = []
|
||||||
model_files = self.find_models(ext_filter=[".pt", ".pth"])
|
model_files = self.find_models(ext_filter=[".pt", ".pth"])
|
||||||
for model in model_files:
|
for model in model_files:
|
||||||
if "http" in model:
|
if model.startswith("http"):
|
||||||
name = self.model_name
|
name = self.model_name
|
||||||
else:
|
else:
|
||||||
name = modelloader.friendly_name(model)
|
name = modelloader.friendly_name(model)
|
||||||
@ -37,42 +38,54 @@ class UpscalerSwinIR(Upscaler):
|
|||||||
self.scalers = scalers
|
self.scalers = scalers
|
||||||
|
|
||||||
def do_upscale(self, img, model_file):
|
def do_upscale(self, img, model_file):
|
||||||
model = self.load_model(model_file)
|
use_compile = hasattr(opts, 'SWIN_torch_compile') and opts.SWIN_torch_compile \
|
||||||
if model is None:
|
and int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows"
|
||||||
return img
|
current_config = (model_file, opts.SWIN_tile)
|
||||||
model = model.to(device_swinir, dtype=devices.dtype)
|
|
||||||
|
if use_compile and self._cached_model_config == current_config:
|
||||||
|
model = self._cached_model
|
||||||
|
else:
|
||||||
|
self._cached_model = None
|
||||||
|
try:
|
||||||
|
model = self.load_model(model_file)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr)
|
||||||
|
return img
|
||||||
|
model = model.to(device_swinir, dtype=devices.dtype)
|
||||||
|
if use_compile:
|
||||||
|
model = torch.compile(model)
|
||||||
|
self._cached_model = model
|
||||||
|
self._cached_model_config = current_config
|
||||||
img = upscale(img, model)
|
img = upscale(img, model)
|
||||||
try:
|
devices.torch_gc()
|
||||||
torch.cuda.empty_cache()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
return img
|
return img
|
||||||
|
|
||||||
def load_model(self, path, scale=4):
|
def load_model(self, path, scale=4):
|
||||||
if "http" in path:
|
if path.startswith("http"):
|
||||||
dl_name = "%s%s" % (self.model_name.replace(" ", "_"), ".pth")
|
filename = modelloader.load_file_from_url(
|
||||||
filename = load_file_from_url(url=path, model_dir=self.model_download_path, file_name=dl_name, progress=True)
|
url=path,
|
||||||
|
model_dir=self.model_download_path,
|
||||||
|
file_name=f"{self.model_name.replace(' ', '_')}.pth",
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
filename = path
|
filename = path
|
||||||
if filename is None or not os.path.exists(filename):
|
|
||||||
return None
|
|
||||||
if filename.endswith(".v2.pth"):
|
if filename.endswith(".v2.pth"):
|
||||||
model = net2(
|
model = Swin2SR(
|
||||||
upscale=scale,
|
upscale=scale,
|
||||||
in_chans=3,
|
in_chans=3,
|
||||||
img_size=64,
|
img_size=64,
|
||||||
window_size=8,
|
window_size=8,
|
||||||
img_range=1.0,
|
img_range=1.0,
|
||||||
depths=[6, 6, 6, 6, 6, 6],
|
depths=[6, 6, 6, 6, 6, 6],
|
||||||
embed_dim=180,
|
embed_dim=180,
|
||||||
num_heads=[6, 6, 6, 6, 6, 6],
|
num_heads=[6, 6, 6, 6, 6, 6],
|
||||||
mlp_ratio=2,
|
mlp_ratio=2,
|
||||||
upsampler="nearest+conv",
|
upsampler="nearest+conv",
|
||||||
resi_connection="1conv",
|
resi_connection="1conv",
|
||||||
)
|
)
|
||||||
params = None
|
params = None
|
||||||
else:
|
else:
|
||||||
model = net(
|
model = SwinIR(
|
||||||
upscale=scale,
|
upscale=scale,
|
||||||
in_chans=3,
|
in_chans=3,
|
||||||
img_size=64,
|
img_size=64,
|
||||||
@ -172,6 +185,8 @@ def on_ui_settings():
|
|||||||
|
|
||||||
shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")))
|
shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")))
|
||||||
shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling")))
|
shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling")))
|
||||||
|
if int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows": # torch.compile() require pytorch 2.0 or above, and not on Windows
|
||||||
|
shared.opts.add_option("SWIN_torch_compile", shared.OptionInfo(False, "Use torch.compile to accelerate SwinIR.", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")).info("Takes longer on first run"))
|
||||||
|
|
||||||
|
|
||||||
script_callbacks.on_ui_settings(on_ui_settings)
|
script_callbacks.on_ui_settings(on_ui_settings)
|
||||||
|
@ -200,7 +200,8 @@ onUiLoaded(async() => {
|
|||||||
canvas_hotkey_move: "KeyF",
|
canvas_hotkey_move: "KeyF",
|
||||||
canvas_hotkey_overlap: "KeyO",
|
canvas_hotkey_overlap: "KeyO",
|
||||||
canvas_disabled_functions: [],
|
canvas_disabled_functions: [],
|
||||||
canvas_show_tooltip: true
|
canvas_show_tooltip: true,
|
||||||
|
canvas_blur_prompt: false
|
||||||
};
|
};
|
||||||
|
|
||||||
const functionMap = {
|
const functionMap = {
|
||||||
@ -608,6 +609,19 @@ onUiLoaded(async() => {
|
|||||||
|
|
||||||
// Handle keydown events
|
// Handle keydown events
|
||||||
function handleKeyDown(event) {
|
function handleKeyDown(event) {
|
||||||
|
// Disable key locks to make pasting from the buffer work correctly
|
||||||
|
if ((event.ctrlKey && event.code === 'KeyV') || (event.ctrlKey && event.code === 'KeyC') || event.code === "F5") {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// before activating shortcut, ensure user is not actively typing in an input field
|
||||||
|
if (!hotkeysConfig.canvas_blur_prompt) {
|
||||||
|
if (event.target.nodeName === 'TEXTAREA' || event.target.nodeName === 'INPUT') {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
const hotkeyActions = {
|
const hotkeyActions = {
|
||||||
[hotkeysConfig.canvas_hotkey_reset]: resetZoom,
|
[hotkeysConfig.canvas_hotkey_reset]: resetZoom,
|
||||||
[hotkeysConfig.canvas_hotkey_overlap]: toggleOverlap,
|
[hotkeysConfig.canvas_hotkey_overlap]: toggleOverlap,
|
||||||
@ -686,6 +700,20 @@ onUiLoaded(async() => {
|
|||||||
|
|
||||||
// Handle the move event for pan functionality. Updates the panX and panY variables and applies the new transform to the target element.
|
// Handle the move event for pan functionality. Updates the panX and panY variables and applies the new transform to the target element.
|
||||||
function handleMoveKeyDown(e) {
|
function handleMoveKeyDown(e) {
|
||||||
|
|
||||||
|
// Disable key locks to make pasting from the buffer work correctly
|
||||||
|
if ((e.ctrlKey && e.code === 'KeyV') || (e.ctrlKey && event.code === 'KeyC') || e.code === "F5") {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// before activating shortcut, ensure user is not actively typing in an input field
|
||||||
|
if (!hotkeysConfig.canvas_blur_prompt) {
|
||||||
|
if (e.target.nodeName === 'TEXTAREA' || e.target.nodeName === 'INPUT') {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
if (e.code === hotkeysConfig.canvas_hotkey_move) {
|
if (e.code === hotkeysConfig.canvas_hotkey_move) {
|
||||||
if (!e.ctrlKey && !e.metaKey && isKeyDownHandlerAttached) {
|
if (!e.ctrlKey && !e.metaKey && isKeyDownHandlerAttached) {
|
||||||
e.preventDefault();
|
e.preventDefault();
|
||||||
|
@ -9,5 +9,6 @@ shared.options_templates.update(shared.options_section(('canvas_hotkey', "Canvas
|
|||||||
"canvas_hotkey_reset": shared.OptionInfo("R", "Reset zoom and canvas positon"),
|
"canvas_hotkey_reset": shared.OptionInfo("R", "Reset zoom and canvas positon"),
|
||||||
"canvas_hotkey_overlap": shared.OptionInfo("O", "Toggle overlap").info("Technical button, neededs for testing"),
|
"canvas_hotkey_overlap": shared.OptionInfo("O", "Toggle overlap").info("Technical button, neededs for testing"),
|
||||||
"canvas_show_tooltip": shared.OptionInfo(True, "Enable tooltip on the canvas"),
|
"canvas_show_tooltip": shared.OptionInfo(True, "Enable tooltip on the canvas"),
|
||||||
|
"canvas_blur_prompt": shared.OptionInfo(False, "Take the focus off the prompt when working with a canvas"),
|
||||||
"canvas_disabled_functions": shared.OptionInfo(["Overlap"], "Disable function that you don't use", gr.CheckboxGroup, {"choices": ["Zoom","Adjust brush size", "Moving canvas","Fullscreen","Reset Zoom","Overlap"]}),
|
"canvas_disabled_functions": shared.OptionInfo(["Overlap"], "Disable function that you don't use", gr.CheckboxGroup, {"choices": ["Zoom","Adjust brush size", "Moving canvas","Fullscreen","Reset Zoom","Overlap"]}),
|
||||||
}))
|
}))
|
||||||
|
@ -100,11 +100,12 @@ function keyupEditAttention(event) {
|
|||||||
if (String(weight).length == 1) weight += ".0";
|
if (String(weight).length == 1) weight += ".0";
|
||||||
|
|
||||||
if (closeCharacter == ')' && weight == 1) {
|
if (closeCharacter == ')' && weight == 1) {
|
||||||
text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + 5);
|
var endParenPos = text.substring(selectionEnd).indexOf(')');
|
||||||
|
text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + endParenPos + 1);
|
||||||
selectionStart--;
|
selectionStart--;
|
||||||
selectionEnd--;
|
selectionEnd--;
|
||||||
} else {
|
} else {
|
||||||
text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1);
|
text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + end);
|
||||||
}
|
}
|
||||||
|
|
||||||
target.focus();
|
target.focus();
|
||||||
|
41
javascript/edit-order.js
Normal file
41
javascript/edit-order.js
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
/* alt+left/right moves text in prompt */
|
||||||
|
|
||||||
|
function keyupEditOrder(event) {
|
||||||
|
if (!opts.keyedit_move) return;
|
||||||
|
|
||||||
|
let target = event.originalTarget || event.composedPath()[0];
|
||||||
|
if (!target.matches("*:is([id*='_toprow'] [id*='_prompt'], .prompt) textarea")) return;
|
||||||
|
if (!event.altKey) return;
|
||||||
|
|
||||||
|
let isLeft = event.key == "ArrowLeft";
|
||||||
|
let isRight = event.key == "ArrowRight";
|
||||||
|
if (!isLeft && !isRight) return;
|
||||||
|
event.preventDefault();
|
||||||
|
|
||||||
|
let selectionStart = target.selectionStart;
|
||||||
|
let selectionEnd = target.selectionEnd;
|
||||||
|
let text = target.value;
|
||||||
|
let items = text.split(",");
|
||||||
|
let indexStart = (text.slice(0, selectionStart).match(/,/g) || []).length;
|
||||||
|
let indexEnd = (text.slice(0, selectionEnd).match(/,/g) || []).length;
|
||||||
|
let range = indexEnd - indexStart + 1;
|
||||||
|
|
||||||
|
if (isLeft && indexStart > 0) {
|
||||||
|
items.splice(indexStart - 1, 0, ...items.splice(indexStart, range));
|
||||||
|
target.value = items.join();
|
||||||
|
target.selectionStart = items.slice(0, indexStart - 1).join().length + (indexStart == 1 ? 0 : 1);
|
||||||
|
target.selectionEnd = items.slice(0, indexEnd).join().length;
|
||||||
|
} else if (isRight && indexEnd < items.length - 1) {
|
||||||
|
items.splice(indexStart + 1, 0, ...items.splice(indexStart, range));
|
||||||
|
target.value = items.join();
|
||||||
|
target.selectionStart = items.slice(0, indexStart + 1).join().length + 1;
|
||||||
|
target.selectionEnd = items.slice(0, indexEnd + 2).join().length;
|
||||||
|
}
|
||||||
|
|
||||||
|
event.preventDefault();
|
||||||
|
updateInput(target);
|
||||||
|
}
|
||||||
|
|
||||||
|
addEventListener('keydown', (event) => {
|
||||||
|
keyupEditOrder(event);
|
||||||
|
});
|
@ -72,3 +72,21 @@ function config_state_confirm_restore(_, config_state_name, config_restore_type)
|
|||||||
}
|
}
|
||||||
return [confirmed, config_state_name, config_restore_type];
|
return [confirmed, config_state_name, config_restore_type];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function toggle_all_extensions(event) {
|
||||||
|
gradioApp().querySelectorAll('#extensions .extension_toggle').forEach(function(checkbox_el) {
|
||||||
|
checkbox_el.checked = event.target.checked;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
function toggle_extension() {
|
||||||
|
let all_extensions_toggled = true;
|
||||||
|
for (const checkbox_el of gradioApp().querySelectorAll('#extensions .extension_toggle')) {
|
||||||
|
if (!checkbox_el.checked) {
|
||||||
|
all_extensions_toggled = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
gradioApp().querySelector('#extensions .all_extensions_toggle').checked = all_extensions_toggled;
|
||||||
|
}
|
||||||
|
@ -14,7 +14,7 @@ from fastapi.encoders import jsonable_encoder
|
|||||||
from secrets import compare_digest
|
from secrets import compare_digest
|
||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors
|
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart
|
||||||
from modules.api import models
|
from modules.api import models
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
||||||
@ -22,7 +22,7 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_
|
|||||||
from modules.textual_inversion.preprocess import preprocess
|
from modules.textual_inversion.preprocess import preprocess
|
||||||
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
||||||
from PIL import PngImagePlugin,Image
|
from PIL import PngImagePlugin,Image
|
||||||
from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights
|
from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights, checkpoint_aliases
|
||||||
from modules.sd_vae import vae_dict
|
from modules.sd_vae import vae_dict
|
||||||
from modules.sd_models_config import find_checkpoint_config_near_filename
|
from modules.sd_models_config import find_checkpoint_config_near_filename
|
||||||
from modules.realesrgan_model import get_realesrgan_models
|
from modules.realesrgan_model import get_realesrgan_models
|
||||||
@ -30,13 +30,7 @@ from modules import devices
|
|||||||
from typing import Dict, List, Any
|
from typing import Dict, List, Any
|
||||||
import piexif
|
import piexif
|
||||||
import piexif.helper
|
import piexif.helper
|
||||||
|
from contextlib import closing
|
||||||
|
|
||||||
def upscaler_to_index(name: str):
|
|
||||||
try:
|
|
||||||
return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
|
|
||||||
except Exception as e:
|
|
||||||
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in shared.sd_upscalers])}") from e
|
|
||||||
|
|
||||||
|
|
||||||
def script_name_to_index(name, scripts):
|
def script_name_to_index(name, scripts):
|
||||||
@ -84,6 +78,8 @@ def encode_pil_to_base64(image):
|
|||||||
image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality)
|
image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality)
|
||||||
|
|
||||||
elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"):
|
elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"):
|
||||||
|
if image.mode == "RGBA":
|
||||||
|
image = image.convert("RGB")
|
||||||
parameters = image.info.get('parameters', None)
|
parameters = image.info.get('parameters', None)
|
||||||
exif_bytes = piexif.dump({
|
exif_bytes = piexif.dump({
|
||||||
"Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") }
|
"Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") }
|
||||||
@ -209,6 +205,11 @@ class Api:
|
|||||||
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
|
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
|
||||||
self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
|
self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
|
||||||
|
|
||||||
|
if shared.cmd_opts.api_server_stop:
|
||||||
|
self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"])
|
||||||
|
self.add_api_route("/sdapi/v1/server-restart", self.restart_webui, methods=["POST"])
|
||||||
|
self.add_api_route("/sdapi/v1/server-stop", self.stop_webui, methods=["POST"])
|
||||||
|
|
||||||
self.default_script_arg_txt2img = []
|
self.default_script_arg_txt2img = []
|
||||||
self.default_script_arg_img2img = []
|
self.default_script_arg_img2img = []
|
||||||
|
|
||||||
@ -324,19 +325,19 @@ class Api:
|
|||||||
args.pop('save_images', None)
|
args.pop('save_images', None)
|
||||||
|
|
||||||
with self.queue_lock:
|
with self.queue_lock:
|
||||||
p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)
|
with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p:
|
||||||
p.scripts = script_runner
|
p.scripts = script_runner
|
||||||
p.outpath_grids = opts.outdir_txt2img_grids
|
p.outpath_grids = opts.outdir_txt2img_grids
|
||||||
p.outpath_samples = opts.outdir_txt2img_samples
|
p.outpath_samples = opts.outdir_txt2img_samples
|
||||||
|
|
||||||
shared.state.begin()
|
shared.state.begin(job="scripts_txt2img")
|
||||||
if selectable_scripts is not None:
|
if selectable_scripts is not None:
|
||||||
p.script_args = script_args
|
p.script_args = script_args
|
||||||
processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
|
processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
|
||||||
else:
|
else:
|
||||||
p.script_args = tuple(script_args) # Need to pass args as tuple here
|
p.script_args = tuple(script_args) # Need to pass args as tuple here
|
||||||
processed = process_images(p)
|
processed = process_images(p)
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
|
|
||||||
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
|
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
|
||||||
|
|
||||||
@ -380,20 +381,20 @@ class Api:
|
|||||||
args.pop('save_images', None)
|
args.pop('save_images', None)
|
||||||
|
|
||||||
with self.queue_lock:
|
with self.queue_lock:
|
||||||
p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
|
with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p:
|
||||||
p.init_images = [decode_base64_to_image(x) for x in init_images]
|
p.init_images = [decode_base64_to_image(x) for x in init_images]
|
||||||
p.scripts = script_runner
|
p.scripts = script_runner
|
||||||
p.outpath_grids = opts.outdir_img2img_grids
|
p.outpath_grids = opts.outdir_img2img_grids
|
||||||
p.outpath_samples = opts.outdir_img2img_samples
|
p.outpath_samples = opts.outdir_img2img_samples
|
||||||
|
|
||||||
shared.state.begin()
|
shared.state.begin(job="scripts_img2img")
|
||||||
if selectable_scripts is not None:
|
if selectable_scripts is not None:
|
||||||
p.script_args = script_args
|
p.script_args = script_args
|
||||||
processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
|
processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
|
||||||
else:
|
else:
|
||||||
p.script_args = tuple(script_args) # Need to pass args as tuple here
|
p.script_args = tuple(script_args) # Need to pass args as tuple here
|
||||||
processed = process_images(p)
|
processed = process_images(p)
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
|
|
||||||
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
|
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
|
||||||
|
|
||||||
@ -517,6 +518,10 @@ class Api:
|
|||||||
return options
|
return options
|
||||||
|
|
||||||
def set_config(self, req: Dict[str, Any]):
|
def set_config(self, req: Dict[str, Any]):
|
||||||
|
checkpoint_name = req.get("sd_model_checkpoint", None)
|
||||||
|
if checkpoint_name is not None and checkpoint_name not in checkpoint_aliases:
|
||||||
|
raise RuntimeError(f"model {checkpoint_name!r} not found")
|
||||||
|
|
||||||
for k, v in req.items():
|
for k, v in req.items():
|
||||||
shared.opts.set(k, v)
|
shared.opts.set(k, v)
|
||||||
|
|
||||||
@ -598,44 +603,42 @@ class Api:
|
|||||||
|
|
||||||
def create_embedding(self, args: dict):
|
def create_embedding(self, args: dict):
|
||||||
try:
|
try:
|
||||||
shared.state.begin()
|
shared.state.begin(job="create_embedding")
|
||||||
filename = create_embedding(**args) # create empty embedding
|
filename = create_embedding(**args) # create empty embedding
|
||||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
|
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
|
||||||
shared.state.end()
|
|
||||||
return models.CreateResponse(info=f"create embedding filename: {filename}")
|
return models.CreateResponse(info=f"create embedding filename: {filename}")
|
||||||
except AssertionError as e:
|
except AssertionError as e:
|
||||||
shared.state.end()
|
|
||||||
return models.TrainResponse(info=f"create embedding error: {e}")
|
return models.TrainResponse(info=f"create embedding error: {e}")
|
||||||
|
finally:
|
||||||
|
shared.state.end()
|
||||||
|
|
||||||
|
|
||||||
def create_hypernetwork(self, args: dict):
|
def create_hypernetwork(self, args: dict):
|
||||||
try:
|
try:
|
||||||
shared.state.begin()
|
shared.state.begin(job="create_hypernetwork")
|
||||||
filename = create_hypernetwork(**args) # create empty embedding
|
filename = create_hypernetwork(**args) # create empty embedding
|
||||||
shared.state.end()
|
|
||||||
return models.CreateResponse(info=f"create hypernetwork filename: {filename}")
|
return models.CreateResponse(info=f"create hypernetwork filename: {filename}")
|
||||||
except AssertionError as e:
|
except AssertionError as e:
|
||||||
shared.state.end()
|
|
||||||
return models.TrainResponse(info=f"create hypernetwork error: {e}")
|
return models.TrainResponse(info=f"create hypernetwork error: {e}")
|
||||||
|
finally:
|
||||||
|
shared.state.end()
|
||||||
|
|
||||||
def preprocess(self, args: dict):
|
def preprocess(self, args: dict):
|
||||||
try:
|
try:
|
||||||
shared.state.begin()
|
shared.state.begin(job="preprocess")
|
||||||
preprocess(**args) # quick operation unless blip/booru interrogation is enabled
|
preprocess(**args) # quick operation unless blip/booru interrogation is enabled
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return models.PreprocessResponse(info = 'preprocess complete')
|
return models.PreprocessResponse(info='preprocess complete')
|
||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
shared.state.end()
|
|
||||||
return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}")
|
return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}")
|
||||||
except AssertionError as e:
|
except Exception as e:
|
||||||
shared.state.end()
|
|
||||||
return models.PreprocessResponse(info=f"preprocess error: {e}")
|
return models.PreprocessResponse(info=f"preprocess error: {e}")
|
||||||
except FileNotFoundError as e:
|
finally:
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return models.PreprocessResponse(info=f'preprocess error: {e}')
|
|
||||||
|
|
||||||
def train_embedding(self, args: dict):
|
def train_embedding(self, args: dict):
|
||||||
try:
|
try:
|
||||||
shared.state.begin()
|
shared.state.begin(job="train_embedding")
|
||||||
apply_optimizations = shared.opts.training_xattention_optimizations
|
apply_optimizations = shared.opts.training_xattention_optimizations
|
||||||
error = None
|
error = None
|
||||||
filename = ''
|
filename = ''
|
||||||
@ -648,15 +651,15 @@ class Api:
|
|||||||
finally:
|
finally:
|
||||||
if not apply_optimizations:
|
if not apply_optimizations:
|
||||||
sd_hijack.apply_optimizations()
|
sd_hijack.apply_optimizations()
|
||||||
shared.state.end()
|
|
||||||
return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
|
return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
|
||||||
except AssertionError as msg:
|
except Exception as msg:
|
||||||
shared.state.end()
|
|
||||||
return models.TrainResponse(info=f"train embedding error: {msg}")
|
return models.TrainResponse(info=f"train embedding error: {msg}")
|
||||||
|
finally:
|
||||||
|
shared.state.end()
|
||||||
|
|
||||||
def train_hypernetwork(self, args: dict):
|
def train_hypernetwork(self, args: dict):
|
||||||
try:
|
try:
|
||||||
shared.state.begin()
|
shared.state.begin(job="train_hypernetwork")
|
||||||
shared.loaded_hypernetworks = []
|
shared.loaded_hypernetworks = []
|
||||||
apply_optimizations = shared.opts.training_xattention_optimizations
|
apply_optimizations = shared.opts.training_xattention_optimizations
|
||||||
error = None
|
error = None
|
||||||
@ -674,9 +677,10 @@ class Api:
|
|||||||
sd_hijack.apply_optimizations()
|
sd_hijack.apply_optimizations()
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
|
return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
|
||||||
except AssertionError:
|
except Exception as exc:
|
||||||
|
return models.TrainResponse(info=f"train embedding error: {exc}")
|
||||||
|
finally:
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return models.TrainResponse(info=f"train embedding error: {error}")
|
|
||||||
|
|
||||||
def get_memory(self):
|
def get_memory(self):
|
||||||
try:
|
try:
|
||||||
@ -716,3 +720,15 @@ class Api:
|
|||||||
def launch(self, server_name, port):
|
def launch(self, server_name, port):
|
||||||
self.app.include_router(self.router)
|
self.app.include_router(self.router)
|
||||||
uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=0)
|
uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=0)
|
||||||
|
|
||||||
|
def kill_webui(self):
|
||||||
|
restart.stop_program()
|
||||||
|
|
||||||
|
def restart_webui(self):
|
||||||
|
if restart.is_restartable():
|
||||||
|
restart.restart_program()
|
||||||
|
return Response(status_code=501)
|
||||||
|
|
||||||
|
def stop_webui(request):
|
||||||
|
shared.state.server_command = "stop"
|
||||||
|
return Response("Stopping.")
|
||||||
|
@ -274,10 +274,6 @@ class PromptStyleItem(BaseModel):
|
|||||||
prompt: Optional[str] = Field(title="Prompt")
|
prompt: Optional[str] = Field(title="Prompt")
|
||||||
negative_prompt: Optional[str] = Field(title="Negative Prompt")
|
negative_prompt: Optional[str] = Field(title="Negative Prompt")
|
||||||
|
|
||||||
class ArtistItem(BaseModel):
|
|
||||||
name: str = Field(title="Name")
|
|
||||||
score: float = Field(title="Score")
|
|
||||||
category: str = Field(title="Category")
|
|
||||||
|
|
||||||
class EmbeddingItem(BaseModel):
|
class EmbeddingItem(BaseModel):
|
||||||
step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
|
step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
from functools import wraps
|
||||||
import html
|
import html
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
@ -18,6 +19,7 @@ def wrap_queued_call(func):
|
|||||||
|
|
||||||
|
|
||||||
def wrap_gradio_gpu_call(func, extra_outputs=None):
|
def wrap_gradio_gpu_call(func, extra_outputs=None):
|
||||||
|
@wraps(func)
|
||||||
def f(*args, **kwargs):
|
def f(*args, **kwargs):
|
||||||
|
|
||||||
# if the first argument is a string that says "task(...)", it is treated as a job id
|
# if the first argument is a string that says "task(...)", it is treated as a job id
|
||||||
@ -28,7 +30,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
|
|||||||
id_task = None
|
id_task = None
|
||||||
|
|
||||||
with queue_lock:
|
with queue_lock:
|
||||||
shared.state.begin()
|
shared.state.begin(job=id_task)
|
||||||
progress.start_task(id_task)
|
progress.start_task(id_task)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -45,6 +47,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
|
|||||||
|
|
||||||
|
|
||||||
def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
|
def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
|
||||||
|
@wraps(func)
|
||||||
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
|
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
|
||||||
run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
|
run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
|
||||||
if run_memmon:
|
if run_memmon:
|
||||||
|
@ -106,4 +106,4 @@ parser.add_argument("--skip-version-check", action='store_true', help="Do not ch
|
|||||||
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
|
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
|
||||||
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
|
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
|
||||||
parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
|
parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
|
||||||
parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server')
|
parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api')
|
||||||
|
@ -15,7 +15,6 @@ model_dir = "Codeformer"
|
|||||||
model_path = os.path.join(models_path, model_dir)
|
model_path = os.path.join(models_path, model_dir)
|
||||||
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
|
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
|
||||||
|
|
||||||
have_codeformer = False
|
|
||||||
codeformer = None
|
codeformer = None
|
||||||
|
|
||||||
|
|
||||||
@ -100,7 +99,7 @@ def setup_model(dirname):
|
|||||||
output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0]
|
output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0]
|
||||||
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
|
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
|
||||||
del output
|
del output
|
||||||
torch.cuda.empty_cache()
|
devices.torch_gc()
|
||||||
except Exception:
|
except Exception:
|
||||||
errors.report('Failed inference for CodeFormer', exc_info=True)
|
errors.report('Failed inference for CodeFormer', exc_info=True)
|
||||||
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
|
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
|
||||||
@ -123,9 +122,6 @@ def setup_model(dirname):
|
|||||||
|
|
||||||
return restored_img
|
return restored_img
|
||||||
|
|
||||||
global have_codeformer
|
|
||||||
have_codeformer = True
|
|
||||||
|
|
||||||
global codeformer
|
global codeformer
|
||||||
codeformer = FaceRestorerCodeFormer(dirname)
|
codeformer = FaceRestorerCodeFormer(dirname)
|
||||||
shared.face_restorers.append(codeformer)
|
shared.face_restorers.append(codeformer)
|
||||||
|
@ -15,13 +15,6 @@ def has_mps() -> bool:
|
|||||||
else:
|
else:
|
||||||
return mac_specific.has_mps
|
return mac_specific.has_mps
|
||||||
|
|
||||||
def extract_device_id(args, name):
|
|
||||||
for x in range(len(args)):
|
|
||||||
if name in args[x]:
|
|
||||||
return args[x + 1]
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def get_cuda_device_string():
|
def get_cuda_device_string():
|
||||||
from modules import shared
|
from modules import shared
|
||||||
@ -56,11 +49,15 @@ def get_device_for(task):
|
|||||||
|
|
||||||
|
|
||||||
def torch_gc():
|
def torch_gc():
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
with torch.cuda.device(get_cuda_device_string()):
|
with torch.cuda.device(get_cuda_device_string()):
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
torch.cuda.ipc_collect()
|
torch.cuda.ipc_collect()
|
||||||
|
|
||||||
|
if has_mps():
|
||||||
|
mac_specific.torch_mps_gc()
|
||||||
|
|
||||||
|
|
||||||
def enable_tf32():
|
def enable_tf32():
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
|
@ -1,15 +1,13 @@
|
|||||||
import os
|
import sys
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from basicsr.utils.download_util import load_file_from_url
|
|
||||||
|
|
||||||
import modules.esrgan_model_arch as arch
|
import modules.esrgan_model_arch as arch
|
||||||
from modules import modelloader, images, devices
|
from modules import modelloader, images, devices
|
||||||
from modules.upscaler import Upscaler, UpscalerData
|
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
|
from modules.upscaler import Upscaler, UpscalerData
|
||||||
|
|
||||||
|
|
||||||
def mod2normal(state_dict):
|
def mod2normal(state_dict):
|
||||||
@ -134,7 +132,7 @@ class UpscalerESRGAN(Upscaler):
|
|||||||
scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
|
scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
|
||||||
scalers.append(scaler_data)
|
scalers.append(scaler_data)
|
||||||
for file in model_paths:
|
for file in model_paths:
|
||||||
if "http" in file:
|
if file.startswith("http"):
|
||||||
name = self.model_name
|
name = self.model_name
|
||||||
else:
|
else:
|
||||||
name = modelloader.friendly_name(file)
|
name = modelloader.friendly_name(file)
|
||||||
@ -143,26 +141,25 @@ class UpscalerESRGAN(Upscaler):
|
|||||||
self.scalers.append(scaler_data)
|
self.scalers.append(scaler_data)
|
||||||
|
|
||||||
def do_upscale(self, img, selected_model):
|
def do_upscale(self, img, selected_model):
|
||||||
model = self.load_model(selected_model)
|
try:
|
||||||
if model is None:
|
model = self.load_model(selected_model)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr)
|
||||||
return img
|
return img
|
||||||
model.to(devices.device_esrgan)
|
model.to(devices.device_esrgan)
|
||||||
img = esrgan_upscale(model, img)
|
img = esrgan_upscale(model, img)
|
||||||
return img
|
return img
|
||||||
|
|
||||||
def load_model(self, path: str):
|
def load_model(self, path: str):
|
||||||
if "http" in path:
|
if path.startswith("http"):
|
||||||
filename = load_file_from_url(
|
# TODO: this doesn't use `path` at all?
|
||||||
|
filename = modelloader.load_file_from_url(
|
||||||
url=self.model_url,
|
url=self.model_url,
|
||||||
model_dir=self.model_download_path,
|
model_dir=self.model_download_path,
|
||||||
file_name=f"{self.model_name}.pth",
|
file_name=f"{self.model_name}.pth",
|
||||||
progress=True,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
filename = path
|
filename = path
|
||||||
if not os.path.exists(filename) or filename is None:
|
|
||||||
print(f"Unable to load {self.model_path} from {filename}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
|
state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
|
||||||
|
|
||||||
|
@ -103,6 +103,9 @@ def activate(p, extra_network_data):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
errors.display(e, f"activating extra network {extra_network_name}")
|
errors.display(e, f"activating extra network {extra_network_name}")
|
||||||
|
|
||||||
|
if p.scripts is not None:
|
||||||
|
p.scripts.after_extra_networks_activate(p, batch_number=p.iteration, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds, extra_network_data=extra_network_data)
|
||||||
|
|
||||||
|
|
||||||
def deactivate(p, extra_network_data):
|
def deactivate(p, extra_network_data):
|
||||||
"""call deactivate for extra networks in extra_network_data in specified order, then call
|
"""call deactivate for extra networks in extra_network_data in specified order, then call
|
||||||
|
@ -73,8 +73,7 @@ def to_half(tensor, enable):
|
|||||||
|
|
||||||
|
|
||||||
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata):
|
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata):
|
||||||
shared.state.begin()
|
shared.state.begin(job="model-merge")
|
||||||
shared.state.job = 'model-merge'
|
|
||||||
|
|
||||||
def fail(message):
|
def fail(message):
|
||||||
shared.state.textinfo = message
|
shared.state.textinfo = message
|
||||||
|
@ -174,31 +174,6 @@ def send_image_and_dimensions(x):
|
|||||||
return img, w, h
|
return img, w, h
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
|
|
||||||
"""Determines the config parameter name to use for the hypernet based on the parameters in the infotext.
|
|
||||||
|
|
||||||
Example: an infotext provides "Hypernet: ke-ta" and "Hypernet hash: 1234abcd". For the "Hypernet" config
|
|
||||||
parameter this means there should be an entry that looks like "ke-ta-10000(1234abcd)" to set it to.
|
|
||||||
|
|
||||||
If the infotext has no hash, then a hypernet with the same name will be selected instead.
|
|
||||||
"""
|
|
||||||
hypernet_name = hypernet_name.lower()
|
|
||||||
if hypernet_hash is not None:
|
|
||||||
# Try to match the hash in the name
|
|
||||||
for hypernet_key in shared.hypernetworks.keys():
|
|
||||||
result = re_hypernet_hash.search(hypernet_key)
|
|
||||||
if result is not None and result[1] == hypernet_hash:
|
|
||||||
return hypernet_key
|
|
||||||
else:
|
|
||||||
# Fall back to a hypernet with the same name
|
|
||||||
for hypernet_key in shared.hypernetworks.keys():
|
|
||||||
if hypernet_key.lower().startswith(hypernet_name):
|
|
||||||
return hypernet_key
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def restore_old_hires_fix_params(res):
|
def restore_old_hires_fix_params(res):
|
||||||
"""for infotexts that specify old First pass size parameter, convert it into
|
"""for infotexts that specify old First pass size parameter, convert it into
|
||||||
width, height, and hr scale"""
|
width, height, and hr scale"""
|
||||||
@ -332,10 +307,6 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
settings_map = {}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
infotext_to_setting_name_mapping = [
|
infotext_to_setting_name_mapping = [
|
||||||
('Clip skip', 'CLIP_stop_at_last_layers', ),
|
('Clip skip', 'CLIP_stop_at_last_layers', ),
|
||||||
('Conditional mask weight', 'inpainting_mask_weight'),
|
('Conditional mask weight', 'inpainting_mask_weight'),
|
||||||
|
@ -25,7 +25,7 @@ def gfpgann():
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
|
models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
|
||||||
if len(models) == 1 and "http" in models[0]:
|
if len(models) == 1 and models[0].startswith("http"):
|
||||||
model_file = models[0]
|
model_file = models[0]
|
||||||
elif len(models) != 0:
|
elif len(models) != 0:
|
||||||
latest_file = max(models, key=os.path.getctime)
|
latest_file = max(models, key=os.path.getctime)
|
||||||
|
@ -3,6 +3,7 @@ import glob
|
|||||||
import html
|
import html
|
||||||
import os
|
import os
|
||||||
import inspect
|
import inspect
|
||||||
|
from contextlib import closing
|
||||||
|
|
||||||
import modules.textual_inversion.dataset
|
import modules.textual_inversion.dataset
|
||||||
import torch
|
import torch
|
||||||
@ -353,17 +354,6 @@ def load_hypernetworks(names, multipliers=None):
|
|||||||
shared.loaded_hypernetworks.append(hypernetwork)
|
shared.loaded_hypernetworks.append(hypernetwork)
|
||||||
|
|
||||||
|
|
||||||
def find_closest_hypernetwork_name(search: str):
|
|
||||||
if not search:
|
|
||||||
return None
|
|
||||||
search = search.lower()
|
|
||||||
applicable = [name for name in shared.hypernetworks if search in name.lower()]
|
|
||||||
if not applicable:
|
|
||||||
return None
|
|
||||||
applicable = sorted(applicable, key=lambda name: len(name))
|
|
||||||
return applicable[0]
|
|
||||||
|
|
||||||
|
|
||||||
def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
|
def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
|
||||||
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)
|
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)
|
||||||
|
|
||||||
@ -446,18 +436,6 @@ def statistics(data):
|
|||||||
return total_information, recent_information
|
return total_information, recent_information
|
||||||
|
|
||||||
|
|
||||||
def report_statistics(loss_info:dict):
|
|
||||||
keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x]))
|
|
||||||
for key in keys:
|
|
||||||
try:
|
|
||||||
print("Loss statistics for file " + key)
|
|
||||||
info, recent = statistics(list(loss_info[key]))
|
|
||||||
print(info)
|
|
||||||
print(recent)
|
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
|
|
||||||
|
|
||||||
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
|
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
|
||||||
# Remove illegal characters from name.
|
# Remove illegal characters from name.
|
||||||
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
|
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
|
||||||
@ -734,8 +712,9 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||||||
|
|
||||||
preview_text = p.prompt
|
preview_text = p.prompt
|
||||||
|
|
||||||
processed = processing.process_images(p)
|
with closing(p):
|
||||||
image = processed.images[0] if len(processed.images) > 0 else None
|
processed = processing.process_images(p)
|
||||||
|
image = processed.images[0] if len(processed.images) > 0 else None
|
||||||
|
|
||||||
if unload:
|
if unload:
|
||||||
shared.sd_model.cond_stage_model.to(devices.cpu)
|
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||||
@ -770,7 +749,6 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
|||||||
pbar.leave = False
|
pbar.leave = False
|
||||||
pbar.close()
|
pbar.close()
|
||||||
hypernetwork.eval()
|
hypernetwork.eval()
|
||||||
#report_statistics(loss_dict)
|
|
||||||
sd_hijack_checkpoint.remove()
|
sd_hijack_checkpoint.remove()
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
import pytz
|
import pytz
|
||||||
@ -10,7 +12,7 @@ import re
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import piexif
|
import piexif
|
||||||
import piexif.helper
|
import piexif.helper
|
||||||
from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
|
from PIL import Image, ImageFont, ImageDraw, ImageColor, PngImagePlugin
|
||||||
import string
|
import string
|
||||||
import json
|
import json
|
||||||
import hashlib
|
import hashlib
|
||||||
@ -139,6 +141,11 @@ class GridAnnotation:
|
|||||||
|
|
||||||
|
|
||||||
def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
|
def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
|
||||||
|
|
||||||
|
color_active = ImageColor.getcolor(opts.grid_text_active_color, 'RGB')
|
||||||
|
color_inactive = ImageColor.getcolor(opts.grid_text_inactive_color, 'RGB')
|
||||||
|
color_background = ImageColor.getcolor(opts.grid_background_color, 'RGB')
|
||||||
|
|
||||||
def wrap(drawing, text, font, line_length):
|
def wrap(drawing, text, font, line_length):
|
||||||
lines = ['']
|
lines = ['']
|
||||||
for word in text.split():
|
for word in text.split():
|
||||||
@ -168,9 +175,6 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
|
|||||||
|
|
||||||
fnt = get_font(fontsize)
|
fnt = get_font(fontsize)
|
||||||
|
|
||||||
color_active = (0, 0, 0)
|
|
||||||
color_inactive = (153, 153, 153)
|
|
||||||
|
|
||||||
pad_left = 0 if sum([sum([len(line.text) for line in lines]) for lines in ver_texts]) == 0 else width * 3 // 4
|
pad_left = 0 if sum([sum([len(line.text) for line in lines]) for lines in ver_texts]) == 0 else width * 3 // 4
|
||||||
|
|
||||||
cols = im.width // width
|
cols = im.width // width
|
||||||
@ -179,7 +183,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
|
|||||||
assert cols == len(hor_texts), f'bad number of horizontal texts: {len(hor_texts)}; must be {cols}'
|
assert cols == len(hor_texts), f'bad number of horizontal texts: {len(hor_texts)}; must be {cols}'
|
||||||
assert rows == len(ver_texts), f'bad number of vertical texts: {len(ver_texts)}; must be {rows}'
|
assert rows == len(ver_texts), f'bad number of vertical texts: {len(ver_texts)}; must be {rows}'
|
||||||
|
|
||||||
calc_img = Image.new("RGB", (1, 1), "white")
|
calc_img = Image.new("RGB", (1, 1), color_background)
|
||||||
calc_d = ImageDraw.Draw(calc_img)
|
calc_d = ImageDraw.Draw(calc_img)
|
||||||
|
|
||||||
for texts, allowed_width in zip(hor_texts + ver_texts, [width] * len(hor_texts) + [pad_left] * len(ver_texts)):
|
for texts, allowed_width in zip(hor_texts + ver_texts, [width] * len(hor_texts) + [pad_left] * len(ver_texts)):
|
||||||
@ -200,7 +204,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
|
|||||||
|
|
||||||
pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2
|
pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2
|
||||||
|
|
||||||
result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), "white")
|
result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), color_background)
|
||||||
|
|
||||||
for row in range(rows):
|
for row in range(rows):
|
||||||
for col in range(cols):
|
for col in range(cols):
|
||||||
@ -372,8 +376,8 @@ class FilenameGenerator:
|
|||||||
'hasprompt': lambda self, *args: self.hasprompt(*args), # accepts formats:[hasprompt<prompt1|default><prompt2>..]
|
'hasprompt': lambda self, *args: self.hasprompt(*args), # accepts formats:[hasprompt<prompt1|default><prompt2>..]
|
||||||
'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
|
'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
|
||||||
'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
|
'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
|
||||||
|
'user': lambda self: self.p.user,
|
||||||
'vae_filename': lambda self: self.get_vae_filename(),
|
'vae_filename': lambda self: self.get_vae_filename(),
|
||||||
|
|
||||||
}
|
}
|
||||||
default_time_format = '%Y%m%d%H%M%S'
|
default_time_format = '%Y%m%d%H%M%S'
|
||||||
|
|
||||||
@ -497,13 +501,23 @@ def get_next_sequence_number(path, basename):
|
|||||||
return result + 1
|
return result + 1
|
||||||
|
|
||||||
|
|
||||||
def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_pnginfo=None):
|
def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_pnginfo=None, pnginfo_section_name='parameters'):
|
||||||
|
"""
|
||||||
|
Saves image to filename, including geninfo as text information for generation info.
|
||||||
|
For PNG images, geninfo is added to existing pnginfo dictionary using the pnginfo_section_name argument as key.
|
||||||
|
For JPG images, there's no dictionary and geninfo just replaces the EXIF description.
|
||||||
|
"""
|
||||||
|
|
||||||
if extension is None:
|
if extension is None:
|
||||||
extension = os.path.splitext(filename)[1]
|
extension = os.path.splitext(filename)[1]
|
||||||
|
|
||||||
image_format = Image.registered_extensions()[extension]
|
image_format = Image.registered_extensions()[extension]
|
||||||
|
|
||||||
if extension.lower() == '.png':
|
if extension.lower() == '.png':
|
||||||
|
existing_pnginfo = existing_pnginfo or {}
|
||||||
|
if opts.enable_pnginfo:
|
||||||
|
existing_pnginfo[pnginfo_section_name] = geninfo
|
||||||
|
|
||||||
if opts.enable_pnginfo:
|
if opts.enable_pnginfo:
|
||||||
pnginfo_data = PngImagePlugin.PngInfo()
|
pnginfo_data = PngImagePlugin.PngInfo()
|
||||||
for k, v in (existing_pnginfo or {}).items():
|
for k, v in (existing_pnginfo or {}).items():
|
||||||
@ -622,7 +636,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
|||||||
"""
|
"""
|
||||||
temp_file_path = f"{filename_without_extension}.tmp"
|
temp_file_path = f"{filename_without_extension}.tmp"
|
||||||
|
|
||||||
save_image_with_geninfo(image_to_save, info, temp_file_path, extension, params.pnginfo)
|
save_image_with_geninfo(image_to_save, info, temp_file_path, extension, existing_pnginfo=params.pnginfo, pnginfo_section_name=pnginfo_section_name)
|
||||||
|
|
||||||
os.replace(temp_file_path, filename_without_extension + extension)
|
os.replace(temp_file_path, filename_without_extension + extension)
|
||||||
|
|
||||||
@ -639,12 +653,18 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
|||||||
oversize = image.width > opts.target_side_length or image.height > opts.target_side_length
|
oversize = image.width > opts.target_side_length or image.height > opts.target_side_length
|
||||||
if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > opts.img_downscale_threshold * 1024 * 1024):
|
if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > opts.img_downscale_threshold * 1024 * 1024):
|
||||||
ratio = image.width / image.height
|
ratio = image.width / image.height
|
||||||
|
resize_to = None
|
||||||
if oversize and ratio > 1:
|
if oversize and ratio > 1:
|
||||||
image = image.resize((round(opts.target_side_length), round(image.height * opts.target_side_length / image.width)), LANCZOS)
|
resize_to = round(opts.target_side_length), round(image.height * opts.target_side_length / image.width)
|
||||||
elif oversize:
|
elif oversize:
|
||||||
image = image.resize((round(image.width * opts.target_side_length / image.height), round(opts.target_side_length)), LANCZOS)
|
resize_to = round(image.width * opts.target_side_length / image.height), round(opts.target_side_length)
|
||||||
|
|
||||||
|
if resize_to is not None:
|
||||||
|
try:
|
||||||
|
# Resizing image with LANCZOS could throw an exception if e.g. image mode is I;16
|
||||||
|
image = image.resize(resize_to, LANCZOS)
|
||||||
|
except Exception:
|
||||||
|
image = image.resize(resize_to)
|
||||||
try:
|
try:
|
||||||
_atomically_save_image(image, fullfn_without_extension, ".jpg")
|
_atomically_save_image(image, fullfn_without_extension, ".jpg")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -662,8 +682,15 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
|||||||
return fullfn, txt_fullfn
|
return fullfn, txt_fullfn
|
||||||
|
|
||||||
|
|
||||||
def read_info_from_image(image):
|
IGNORED_INFO_KEYS = {
|
||||||
items = image.info or {}
|
'jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
|
||||||
|
'loop', 'background', 'timestamp', 'duration', 'progressive', 'progression',
|
||||||
|
'icc_profile', 'chromaticity', 'photoshop',
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def read_info_from_image(image: Image.Image) -> tuple[str | None, dict]:
|
||||||
|
items = (image.info or {}).copy()
|
||||||
|
|
||||||
geninfo = items.pop('parameters', None)
|
geninfo = items.pop('parameters', None)
|
||||||
|
|
||||||
@ -679,9 +706,7 @@ def read_info_from_image(image):
|
|||||||
items['exif comment'] = exif_comment
|
items['exif comment'] = exif_comment
|
||||||
geninfo = exif_comment
|
geninfo = exif_comment
|
||||||
|
|
||||||
for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
|
for field in IGNORED_INFO_KEYS:
|
||||||
'loop', 'background', 'timestamp', 'duration', 'progressive', 'progression',
|
|
||||||
'icc_profile', 'chromaticity']:
|
|
||||||
items.pop(field, None)
|
items.pop(field, None)
|
||||||
|
|
||||||
if items.get("Software", None) == "NovelAI":
|
if items.get("Software", None) == "NovelAI":
|
||||||
|
@ -1,11 +1,13 @@
|
|||||||
import os
|
import os
|
||||||
|
from contextlib import closing
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
|
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
from modules import sd_samplers
|
from modules import sd_samplers, images as imgutil
|
||||||
from modules.generation_parameters_copypaste import create_override_settings_dict
|
from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters
|
||||||
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
|
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
|
||||||
from modules.shared import opts, state
|
from modules.shared import opts, state
|
||||||
from modules.images import save_image
|
from modules.images import save_image
|
||||||
@ -15,10 +17,10 @@ from modules.ui import plaintext_to_html
|
|||||||
import modules.scripts
|
import modules.scripts
|
||||||
|
|
||||||
|
|
||||||
def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0):
|
def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None):
|
||||||
processing.fix_seed(p)
|
processing.fix_seed(p)
|
||||||
|
|
||||||
images = shared.listfiles(input_dir)
|
images = list(shared.walk_files(input_dir, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp")))
|
||||||
|
|
||||||
is_inpaint_batch = False
|
is_inpaint_batch = False
|
||||||
if inpaint_mask_dir:
|
if inpaint_mask_dir:
|
||||||
@ -37,6 +39,14 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
|
|||||||
|
|
||||||
state.job_count = len(images) * p.n_iter
|
state.job_count = len(images) * p.n_iter
|
||||||
|
|
||||||
|
# extract "default" params to use in case getting png info fails
|
||||||
|
prompt = p.prompt
|
||||||
|
negative_prompt = p.negative_prompt
|
||||||
|
seed = p.seed
|
||||||
|
cfg_scale = p.cfg_scale
|
||||||
|
sampler_name = p.sampler_name
|
||||||
|
steps = p.steps
|
||||||
|
|
||||||
for i, image in enumerate(images):
|
for i, image in enumerate(images):
|
||||||
state.job = f"{i+1} out of {len(images)}"
|
state.job = f"{i+1} out of {len(images)}"
|
||||||
if state.skipped:
|
if state.skipped:
|
||||||
@ -80,6 +90,25 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
|
|||||||
mask_image = Image.open(mask_image_path)
|
mask_image = Image.open(mask_image_path)
|
||||||
p.image_mask = mask_image
|
p.image_mask = mask_image
|
||||||
|
|
||||||
|
if use_png_info:
|
||||||
|
try:
|
||||||
|
info_img = img
|
||||||
|
if png_info_dir:
|
||||||
|
info_img_path = os.path.join(png_info_dir, os.path.basename(image))
|
||||||
|
info_img = Image.open(info_img_path)
|
||||||
|
geninfo, _ = imgutil.read_info_from_image(info_img)
|
||||||
|
parsed_parameters = parse_generation_parameters(geninfo)
|
||||||
|
parsed_parameters = {k: v for k, v in parsed_parameters.items() if k in (png_info_props or {})}
|
||||||
|
except Exception:
|
||||||
|
parsed_parameters = {}
|
||||||
|
|
||||||
|
p.prompt = prompt + (" " + parsed_parameters["Prompt"] if "Prompt" in parsed_parameters else "")
|
||||||
|
p.negative_prompt = negative_prompt + (" " + parsed_parameters["Negative prompt"] if "Negative prompt" in parsed_parameters else "")
|
||||||
|
p.seed = int(parsed_parameters.get("Seed", seed))
|
||||||
|
p.cfg_scale = float(parsed_parameters.get("CFG scale", cfg_scale))
|
||||||
|
p.sampler_name = parsed_parameters.get("Sampler", sampler_name)
|
||||||
|
p.steps = int(parsed_parameters.get("Steps", steps))
|
||||||
|
|
||||||
proc = modules.scripts.scripts_img2img.run(p, *args)
|
proc = modules.scripts.scripts_img2img.run(p, *args)
|
||||||
if proc is None:
|
if proc is None:
|
||||||
proc = process_images(p)
|
proc = process_images(p)
|
||||||
@ -87,18 +116,19 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
|
|||||||
for n, processed_image in enumerate(proc.images):
|
for n, processed_image in enumerate(proc.images):
|
||||||
filename = image_path.stem
|
filename = image_path.stem
|
||||||
infotext = proc.infotext(p, n)
|
infotext = proc.infotext(p, n)
|
||||||
|
relpath = os.path.dirname(os.path.relpath(image, input_dir))
|
||||||
|
|
||||||
if n > 0:
|
if n > 0:
|
||||||
filename += f"-{n}"
|
filename += f"-{n}"
|
||||||
|
|
||||||
if not save_normally:
|
if not save_normally:
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(os.path.join(output_dir, relpath), exist_ok=True)
|
||||||
if processed_image.mode == 'RGBA':
|
if processed_image.mode == 'RGBA':
|
||||||
processed_image = processed_image.convert("RGB")
|
processed_image = processed_image.convert("RGB")
|
||||||
save_image(processed_image, output_dir, None, extension=opts.samples_format, info=infotext, forced_filename=filename, save_to_dirs=False)
|
save_image(processed_image, os.path.join(output_dir, relpath), None, extension=opts.samples_format, info=infotext, forced_filename=filename, save_to_dirs=False)
|
||||||
|
|
||||||
|
|
||||||
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
|
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
|
||||||
override_settings = create_override_settings_dict(override_settings_texts)
|
override_settings = create_override_settings_dict(override_settings_texts)
|
||||||
|
|
||||||
is_batch = mode == 5
|
is_batch = mode == 5
|
||||||
@ -181,24 +211,25 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
|
|||||||
p.scripts = modules.scripts.scripts_img2img
|
p.scripts = modules.scripts.scripts_img2img
|
||||||
p.script_args = args
|
p.script_args = args
|
||||||
|
|
||||||
|
p.user = request.username
|
||||||
|
|
||||||
if shared.cmd_opts.enable_console_prompts:
|
if shared.cmd_opts.enable_console_prompts:
|
||||||
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
|
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
|
||||||
|
|
||||||
if mask:
|
if mask:
|
||||||
p.extra_generation_params["Mask blur"] = mask_blur
|
p.extra_generation_params["Mask blur"] = mask_blur
|
||||||
|
|
||||||
if is_batch:
|
with closing(p):
|
||||||
assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
|
if is_batch:
|
||||||
|
assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
|
||||||
|
|
||||||
process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by)
|
process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir)
|
||||||
|
|
||||||
processed = Processed(p, [], p.seed, "")
|
processed = Processed(p, [], p.seed, "")
|
||||||
else:
|
else:
|
||||||
processed = modules.scripts.scripts_img2img.run(p, *args)
|
processed = modules.scripts.scripts_img2img.run(p, *args)
|
||||||
if processed is None:
|
if processed is None:
|
||||||
processed = process_images(p)
|
processed = process_images(p)
|
||||||
|
|
||||||
p.close()
|
|
||||||
|
|
||||||
shared.total_tqdm.clear()
|
shared.total_tqdm.clear()
|
||||||
|
|
||||||
|
@ -184,8 +184,7 @@ class InterrogateModels:
|
|||||||
|
|
||||||
def interrogate(self, pil_image):
|
def interrogate(self, pil_image):
|
||||||
res = ""
|
res = ""
|
||||||
shared.state.begin()
|
shared.state.begin(job="interrogate")
|
||||||
shared.state.job = 'interrogate'
|
|
||||||
try:
|
try:
|
||||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||||
lowvram.send_everything_to_cpu()
|
lowvram.send_everything_to_cpu()
|
||||||
|
@ -142,15 +142,15 @@ def git_clone(url, dir, name, commithash=None):
|
|||||||
if commithash is None:
|
if commithash is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip()
|
current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}", live=False).strip()
|
||||||
if current_hash == commithash:
|
if current_hash == commithash:
|
||||||
return
|
return
|
||||||
|
|
||||||
run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
|
run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
|
||||||
run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
|
run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
|
||||||
return
|
return
|
||||||
|
|
||||||
run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}")
|
run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}", live=True)
|
||||||
|
|
||||||
if commithash is not None:
|
if commithash is not None:
|
||||||
run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
|
run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
|
||||||
|
@ -1,22 +1,41 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import platform
|
import platform
|
||||||
from modules.sd_hijack_utils import CondFunc
|
from modules.sd_hijack_utils import CondFunc
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
|
log = logging.getLogger()
|
||||||
|
|
||||||
# has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
|
|
||||||
# check `getattr` and try it for compatibility
|
# before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+,
|
||||||
|
# use check `getattr` and try it for compatibility.
|
||||||
|
# in torch version 1.13, backends.mps.is_available() and backends.mps.is_built() are introduced in to check mps availabilty,
|
||||||
|
# since torch 2.0.1+ nightly build, getattr(torch, 'has_mps', False) was deprecated, see https://github.com/pytorch/pytorch/pull/103279
|
||||||
def check_for_mps() -> bool:
|
def check_for_mps() -> bool:
|
||||||
if not getattr(torch, 'has_mps', False):
|
if version.parse(torch.__version__) <= version.parse("2.0.1"):
|
||||||
return False
|
if not getattr(torch, 'has_mps', False):
|
||||||
try:
|
return False
|
||||||
torch.zeros(1).to(torch.device("mps"))
|
try:
|
||||||
return True
|
torch.zeros(1).to(torch.device("mps"))
|
||||||
except Exception:
|
return True
|
||||||
return False
|
except Exception:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return torch.backends.mps.is_available() and torch.backends.mps.is_built()
|
||||||
|
|
||||||
|
|
||||||
has_mps = check_for_mps()
|
has_mps = check_for_mps()
|
||||||
|
|
||||||
|
|
||||||
|
def torch_mps_gc() -> None:
|
||||||
|
try:
|
||||||
|
from torch.mps import empty_cache
|
||||||
|
empty_cache()
|
||||||
|
except Exception:
|
||||||
|
log.warning("MPS garbage collection failed", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
|
# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
|
||||||
def cumsum_fix(input, cumsum_func, *args, **kwargs):
|
def cumsum_fix(input, cumsum_func, *args, **kwargs):
|
||||||
if input.device.type == 'mps':
|
if input.device.type == 'mps':
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import importlib
|
import importlib
|
||||||
@ -8,6 +10,29 @@ from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, Upscale
|
|||||||
from modules.paths import script_path, models_path
|
from modules.paths import script_path, models_path
|
||||||
|
|
||||||
|
|
||||||
|
def load_file_from_url(
|
||||||
|
url: str,
|
||||||
|
*,
|
||||||
|
model_dir: str,
|
||||||
|
progress: bool = True,
|
||||||
|
file_name: str | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Download a file from `url` into `model_dir`, using the file present if possible.
|
||||||
|
|
||||||
|
Returns the path to the downloaded file.
|
||||||
|
"""
|
||||||
|
os.makedirs(model_dir, exist_ok=True)
|
||||||
|
if not file_name:
|
||||||
|
parts = urlparse(url)
|
||||||
|
file_name = os.path.basename(parts.path)
|
||||||
|
cached_file = os.path.abspath(os.path.join(model_dir, file_name))
|
||||||
|
if not os.path.exists(cached_file):
|
||||||
|
print(f'Downloading: "{url}" to {cached_file}\n')
|
||||||
|
from torch.hub import download_url_to_file
|
||||||
|
download_url_to_file(url, cached_file, progress=progress)
|
||||||
|
return cached_file
|
||||||
|
|
||||||
|
|
||||||
def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list:
|
def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list:
|
||||||
"""
|
"""
|
||||||
A one-and done loader to try finding the desired models in specified directories.
|
A one-and done loader to try finding the desired models in specified directories.
|
||||||
@ -46,9 +71,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
|
|||||||
|
|
||||||
if model_url is not None and len(output) == 0:
|
if model_url is not None and len(output) == 0:
|
||||||
if download_name is not None:
|
if download_name is not None:
|
||||||
from basicsr.utils.download_util import load_file_from_url
|
output.append(load_file_from_url(model_url, model_dir=places[0], file_name=download_name))
|
||||||
dl = load_file_from_url(model_url, places[0], True, download_name)
|
|
||||||
output.append(dl)
|
|
||||||
else:
|
else:
|
||||||
output.append(model_url)
|
output.append(model_url)
|
||||||
|
|
||||||
@ -59,7 +82,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
|
|||||||
|
|
||||||
|
|
||||||
def friendly_name(file: str):
|
def friendly_name(file: str):
|
||||||
if "http" in file:
|
if file.startswith("http"):
|
||||||
file = urlparse(file).path
|
file = urlparse(file).path
|
||||||
|
|
||||||
file = os.path.basename(file)
|
file = os.path.basename(file)
|
||||||
|
@ -38,17 +38,3 @@ for d, must_exist, what, options in path_dirs:
|
|||||||
else:
|
else:
|
||||||
sys.path.append(d)
|
sys.path.append(d)
|
||||||
paths[what] = d
|
paths[what] = d
|
||||||
|
|
||||||
|
|
||||||
class Prioritize:
|
|
||||||
def __init__(self, name):
|
|
||||||
self.name = name
|
|
||||||
self.path = None
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
self.path = sys.path.copy()
|
|
||||||
sys.path = [paths[self.name]] + sys.path
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
||||||
sys.path = self.path
|
|
||||||
self.path = None
|
|
||||||
|
@ -9,8 +9,7 @@ from modules.shared import opts
|
|||||||
def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
|
def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
shared.state.begin()
|
shared.state.begin(job="extras")
|
||||||
shared.state.job = 'extras'
|
|
||||||
|
|
||||||
image_data = []
|
image_data = []
|
||||||
image_names = []
|
image_names = []
|
||||||
@ -54,7 +53,9 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
|
|||||||
for image, name in zip(image_data, image_names):
|
for image, name in zip(image_data, image_names):
|
||||||
shared.state.textinfo = name
|
shared.state.textinfo = name
|
||||||
|
|
||||||
existing_pnginfo = image.info or {}
|
parameters, existing_pnginfo = images.read_info_from_image(image)
|
||||||
|
if parameters:
|
||||||
|
existing_pnginfo["parameters"] = parameters
|
||||||
|
|
||||||
pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB"))
|
pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB"))
|
||||||
|
|
||||||
|
@ -184,6 +184,8 @@ class StableDiffusionProcessing:
|
|||||||
self.uc = None
|
self.uc = None
|
||||||
self.c = None
|
self.c = None
|
||||||
|
|
||||||
|
self.user = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def sd_model(self):
|
def sd_model(self):
|
||||||
return shared.sd_model
|
return shared.sd_model
|
||||||
@ -549,7 +551,7 @@ def program_version():
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0):
|
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False):
|
||||||
index = position_in_batch + iteration * p.batch_size
|
index = position_in_batch + iteration * p.batch_size
|
||||||
|
|
||||||
clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
|
clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
|
||||||
@ -573,7 +575,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
|||||||
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
|
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
|
||||||
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
|
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
|
||||||
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
|
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
|
||||||
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
|
"Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
|
||||||
"Denoising strength": getattr(p, 'denoising_strength', None),
|
"Denoising strength": getattr(p, 'denoising_strength', None),
|
||||||
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
|
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
|
||||||
"Clip skip": None if clip_skip <= 1 else clip_skip,
|
"Clip skip": None if clip_skip <= 1 else clip_skip,
|
||||||
@ -585,13 +587,15 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
|||||||
"NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
|
"NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
|
||||||
**p.extra_generation_params,
|
**p.extra_generation_params,
|
||||||
"Version": program_version() if opts.add_version_to_infotext else None,
|
"Version": program_version() if opts.add_version_to_infotext else None,
|
||||||
|
"User": p.user if opts.add_user_name_to_info else None,
|
||||||
}
|
}
|
||||||
|
|
||||||
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
|
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
|
||||||
|
|
||||||
|
prompt_text = p.prompt if use_main_prompt else all_prompts[index]
|
||||||
negative_prompt_text = f"\nNegative prompt: {p.all_negative_prompts[index]}" if p.all_negative_prompts[index] else ""
|
negative_prompt_text = f"\nNegative prompt: {p.all_negative_prompts[index]}" if p.all_negative_prompts[index] else ""
|
||||||
|
|
||||||
return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
|
return f"{prompt_text}{negative_prompt_text}\n{generation_params_text}".strip()
|
||||||
|
|
||||||
|
|
||||||
def process_images(p: StableDiffusionProcessing) -> Processed:
|
def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||||
@ -602,7 +606,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
|
# if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
|
||||||
if sd_models.checkpoint_alisases.get(p.override_settings.get('sd_model_checkpoint')) is None:
|
if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None:
|
||||||
p.override_settings.pop('sd_model_checkpoint', None)
|
p.override_settings.pop('sd_model_checkpoint', None)
|
||||||
sd_models.reload_model_weights()
|
sd_models.reload_model_weights()
|
||||||
|
|
||||||
@ -663,8 +667,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
else:
|
else:
|
||||||
p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
|
p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
|
||||||
|
|
||||||
def infotext(iteration=0, position_in_batch=0):
|
def infotext(iteration=0, position_in_batch=0, use_main_prompt=False):
|
||||||
return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch)
|
return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch, use_main_prompt)
|
||||||
|
|
||||||
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
|
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
|
||||||
model_hijack.embedding_db.load_textual_inversion_embeddings()
|
model_hijack.embedding_db.load_textual_inversion_embeddings()
|
||||||
@ -824,7 +828,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
grid = images.image_grid(output_images, p.batch_size)
|
grid = images.image_grid(output_images, p.batch_size)
|
||||||
|
|
||||||
if opts.return_grid:
|
if opts.return_grid:
|
||||||
text = infotext()
|
text = infotext(use_main_prompt=True)
|
||||||
infotexts.insert(0, text)
|
infotexts.insert(0, text)
|
||||||
if opts.enable_pnginfo:
|
if opts.enable_pnginfo:
|
||||||
grid.info["parameters"] = text
|
grid.info["parameters"] = text
|
||||||
@ -832,7 +836,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
index_of_first_image = 1
|
index_of_first_image = 1
|
||||||
|
|
||||||
if opts.grid_save:
|
if opts.grid_save:
|
||||||
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
|
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True)
|
||||||
|
|
||||||
if not p.disable_extra_networks and p.extra_network_data:
|
if not p.disable_extra_networks and p.extra_network_data:
|
||||||
extra_networks.deactivate(p, p.extra_network_data)
|
extra_networks.deactivate(p, p.extra_network_data)
|
||||||
@ -1074,6 +1078,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
|
|
||||||
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))
|
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))
|
||||||
|
|
||||||
|
if self.scripts is not None:
|
||||||
|
self.scripts.before_hr(self)
|
||||||
|
|
||||||
samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
|
samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
|
||||||
|
|
||||||
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
|
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
|
||||||
|
@ -2,7 +2,6 @@ import os
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from basicsr.utils.download_util import load_file_from_url
|
|
||||||
from realesrgan import RealESRGANer
|
from realesrgan import RealESRGANer
|
||||||
|
|
||||||
from modules.upscaler import Upscaler, UpscalerData
|
from modules.upscaler import Upscaler, UpscalerData
|
||||||
@ -43,9 +42,10 @@ class UpscalerRealESRGAN(Upscaler):
|
|||||||
if not self.enable:
|
if not self.enable:
|
||||||
return img
|
return img
|
||||||
|
|
||||||
info = self.load_model(path)
|
try:
|
||||||
if not os.path.exists(info.local_data_path):
|
info = self.load_model(path)
|
||||||
print(f"Unable to load RealESRGAN model: {info.name}")
|
except Exception:
|
||||||
|
errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True)
|
||||||
return img
|
return img
|
||||||
|
|
||||||
upsampler = RealESRGANer(
|
upsampler = RealESRGANer(
|
||||||
@ -63,20 +63,17 @@ class UpscalerRealESRGAN(Upscaler):
|
|||||||
return image
|
return image
|
||||||
|
|
||||||
def load_model(self, path):
|
def load_model(self, path):
|
||||||
try:
|
for scaler in self.scalers:
|
||||||
info = next(iter([scaler for scaler in self.scalers if scaler.data_path == path]), None)
|
if scaler.data_path == path:
|
||||||
|
if scaler.local_data_path.startswith("http"):
|
||||||
if info is None:
|
scaler.local_data_path = modelloader.load_file_from_url(
|
||||||
print(f"Unable to find model info: {path}")
|
scaler.data_path,
|
||||||
return None
|
model_dir=self.model_download_path,
|
||||||
|
)
|
||||||
if info.local_data_path.startswith("http"):
|
if not os.path.exists(scaler.local_data_path):
|
||||||
info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_download_path, progress=True)
|
raise FileNotFoundError(f"RealESRGAN data missing: {scaler.local_data_path}")
|
||||||
|
return scaler
|
||||||
return info
|
raise ValueError(f"Unable to find model info: {path}")
|
||||||
except Exception:
|
|
||||||
errors.report("Error making Real-ESRGAN models list", exc_info=True)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def load_models(self, _):
|
def load_models(self, _):
|
||||||
return get_realesrgan_models(self)
|
return get_realesrgan_models(self)
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
|
import inspect
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
@ -116,6 +117,21 @@ class Script:
|
|||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def after_extra_networks_activate(self, p, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Calledafter extra networks activation, before conds calculation
|
||||||
|
allow modification of the network after extra networks activation been applied
|
||||||
|
won't be call if p.disable_extra_networks
|
||||||
|
|
||||||
|
**kwargs will have those items:
|
||||||
|
- batch_number - index of current batch, from 0 to number of batches-1
|
||||||
|
- prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
|
||||||
|
- seeds - list of seeds for current batch
|
||||||
|
- subseeds - list of subseeds for current batch
|
||||||
|
- extra_network_data - list of ExtraNetworkParams for current stage
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
def process_batch(self, p, *args, **kwargs):
|
def process_batch(self, p, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
Same as process(), but called for every batch.
|
Same as process(), but called for every batch.
|
||||||
@ -186,6 +202,11 @@ class Script:
|
|||||||
|
|
||||||
return f'script_{tabname}{title}_{item_id}'
|
return f'script_{tabname}{title}_{item_id}'
|
||||||
|
|
||||||
|
def before_hr(self, p, *args):
|
||||||
|
"""
|
||||||
|
This function is called before hires fix start.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
current_basedir = paths.script_path
|
current_basedir = paths.script_path
|
||||||
|
|
||||||
@ -249,7 +270,7 @@ def load_scripts():
|
|||||||
|
|
||||||
def register_scripts_from_module(module):
|
def register_scripts_from_module(module):
|
||||||
for script_class in module.__dict__.values():
|
for script_class in module.__dict__.values():
|
||||||
if type(script_class) != type:
|
if not inspect.isclass(script_class):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if issubclass(script_class, Script):
|
if issubclass(script_class, Script):
|
||||||
@ -483,6 +504,14 @@ class ScriptRunner:
|
|||||||
except Exception:
|
except Exception:
|
||||||
errors.report(f"Error running before_process_batch: {script.filename}", exc_info=True)
|
errors.report(f"Error running before_process_batch: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
|
def after_extra_networks_activate(self, p, **kwargs):
|
||||||
|
for script in self.alwayson_scripts:
|
||||||
|
try:
|
||||||
|
script_args = p.script_args[script.args_from:script.args_to]
|
||||||
|
script.after_extra_networks_activate(p, *script_args, **kwargs)
|
||||||
|
except Exception:
|
||||||
|
errors.report(f"Error running after_extra_networks_activate: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
def process_batch(self, p, **kwargs):
|
def process_batch(self, p, **kwargs):
|
||||||
for script in self.alwayson_scripts:
|
for script in self.alwayson_scripts:
|
||||||
try:
|
try:
|
||||||
@ -548,6 +577,15 @@ class ScriptRunner:
|
|||||||
self.scripts[si].args_to = args_to
|
self.scripts[si].args_to = args_to
|
||||||
|
|
||||||
|
|
||||||
|
def before_hr(self, p):
|
||||||
|
for script in self.alwayson_scripts:
|
||||||
|
try:
|
||||||
|
script_args = p.script_args[script.args_from:script.args_to]
|
||||||
|
script.before_hr(p, *script_args)
|
||||||
|
except Exception:
|
||||||
|
errors.report(f"Error running before_hr: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
scripts_txt2img: ScriptRunner = None
|
scripts_txt2img: ScriptRunner = None
|
||||||
scripts_img2img: ScriptRunner = None
|
scripts_img2img: ScriptRunner = None
|
||||||
scripts_postproc: scripts_postprocessing.ScriptPostprocessingRunner = None
|
scripts_postproc: scripts_postprocessing.ScriptPostprocessingRunner = None
|
||||||
|
@ -23,7 +23,8 @@ model_dir = "Stable-diffusion"
|
|||||||
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
|
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
|
||||||
|
|
||||||
checkpoints_list = {}
|
checkpoints_list = {}
|
||||||
checkpoint_alisases = {}
|
checkpoint_aliases = {}
|
||||||
|
checkpoint_alisases = checkpoint_aliases # for compatibility with old name
|
||||||
checkpoints_loaded = collections.OrderedDict()
|
checkpoints_loaded = collections.OrderedDict()
|
||||||
|
|
||||||
|
|
||||||
@ -66,7 +67,7 @@ class CheckpointInfo:
|
|||||||
def register(self):
|
def register(self):
|
||||||
checkpoints_list[self.title] = self
|
checkpoints_list[self.title] = self
|
||||||
for id in self.ids:
|
for id in self.ids:
|
||||||
checkpoint_alisases[id] = self
|
checkpoint_aliases[id] = self
|
||||||
|
|
||||||
def calculate_shorthash(self):
|
def calculate_shorthash(self):
|
||||||
self.sha256 = hashes.sha256(self.filename, f"checkpoint/{self.name}")
|
self.sha256 = hashes.sha256(self.filename, f"checkpoint/{self.name}")
|
||||||
@ -112,7 +113,7 @@ def checkpoint_tiles():
|
|||||||
|
|
||||||
def list_models():
|
def list_models():
|
||||||
checkpoints_list.clear()
|
checkpoints_list.clear()
|
||||||
checkpoint_alisases.clear()
|
checkpoint_aliases.clear()
|
||||||
|
|
||||||
cmd_ckpt = shared.cmd_opts.ckpt
|
cmd_ckpt = shared.cmd_opts.ckpt
|
||||||
if shared.cmd_opts.no_download_sd_model or cmd_ckpt != shared.sd_model_file or os.path.exists(cmd_ckpt):
|
if shared.cmd_opts.no_download_sd_model or cmd_ckpt != shared.sd_model_file or os.path.exists(cmd_ckpt):
|
||||||
@ -136,7 +137,7 @@ def list_models():
|
|||||||
|
|
||||||
|
|
||||||
def get_closet_checkpoint_match(search_string):
|
def get_closet_checkpoint_match(search_string):
|
||||||
checkpoint_info = checkpoint_alisases.get(search_string, None)
|
checkpoint_info = checkpoint_aliases.get(search_string, None)
|
||||||
if checkpoint_info is not None:
|
if checkpoint_info is not None:
|
||||||
return checkpoint_info
|
return checkpoint_info
|
||||||
|
|
||||||
@ -166,7 +167,7 @@ def select_checkpoint():
|
|||||||
"""Raises `FileNotFoundError` if no checkpoints are found."""
|
"""Raises `FileNotFoundError` if no checkpoints are found."""
|
||||||
model_checkpoint = shared.opts.sd_model_checkpoint
|
model_checkpoint = shared.opts.sd_model_checkpoint
|
||||||
|
|
||||||
checkpoint_info = checkpoint_alisases.get(model_checkpoint, None)
|
checkpoint_info = checkpoint_aliases.get(model_checkpoint, None)
|
||||||
if checkpoint_info is not None:
|
if checkpoint_info is not None:
|
||||||
return checkpoint_info
|
return checkpoint_info
|
||||||
|
|
||||||
@ -247,7 +248,12 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None
|
|||||||
_, extension = os.path.splitext(checkpoint_file)
|
_, extension = os.path.splitext(checkpoint_file)
|
||||||
if extension.lower() == ".safetensors":
|
if extension.lower() == ".safetensors":
|
||||||
device = map_location or shared.weight_load_location or devices.get_optimal_device_name()
|
device = map_location or shared.weight_load_location or devices.get_optimal_device_name()
|
||||||
pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
|
|
||||||
|
if not shared.opts.disable_mmap_load_safetensors:
|
||||||
|
pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
|
||||||
|
else:
|
||||||
|
pl_sd = safetensors.torch.load(open(checkpoint_file, 'rb').read())
|
||||||
|
pl_sd = {k: v.to(device) for k, v in pl_sd.items()}
|
||||||
else:
|
else:
|
||||||
pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
|
pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
|
||||||
|
|
||||||
@ -585,7 +591,6 @@ def unload_model_weights(sd_model=None, info=None):
|
|||||||
sd_model = None
|
sd_model = None
|
||||||
gc.collect()
|
gc.collect()
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
print(f"Unloaded weights {timer.summary()}.")
|
print(f"Unloaded weights {timer.summary()}.")
|
||||||
|
|
||||||
|
@ -1,9 +1,11 @@
|
|||||||
import datetime
|
import datetime
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
import logging
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import torch
|
import torch
|
||||||
@ -18,6 +20,8 @@ from modules.paths_internal import models_path, script_path, data_path, sd_confi
|
|||||||
from ldm.models.diffusion.ddpm import LatentDiffusion
|
from ldm.models.diffusion.ddpm import LatentDiffusion
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
demo = None
|
demo = None
|
||||||
|
|
||||||
parser = cmd_args.parser
|
parser = cmd_args.parser
|
||||||
@ -144,12 +148,15 @@ class State:
|
|||||||
def request_restart(self) -> None:
|
def request_restart(self) -> None:
|
||||||
self.interrupt()
|
self.interrupt()
|
||||||
self.server_command = "restart"
|
self.server_command = "restart"
|
||||||
|
log.info("Received restart request")
|
||||||
|
|
||||||
def skip(self):
|
def skip(self):
|
||||||
self.skipped = True
|
self.skipped = True
|
||||||
|
log.info("Received skip request")
|
||||||
|
|
||||||
def interrupt(self):
|
def interrupt(self):
|
||||||
self.interrupted = True
|
self.interrupted = True
|
||||||
|
log.info("Received interrupt request")
|
||||||
|
|
||||||
def nextjob(self):
|
def nextjob(self):
|
||||||
if opts.live_previews_enable and opts.show_progress_every_n_steps == -1:
|
if opts.live_previews_enable and opts.show_progress_every_n_steps == -1:
|
||||||
@ -173,7 +180,7 @@ class State:
|
|||||||
|
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
def begin(self):
|
def begin(self, job: str = "(unknown)"):
|
||||||
self.sampling_step = 0
|
self.sampling_step = 0
|
||||||
self.job_count = -1
|
self.job_count = -1
|
||||||
self.processing_has_refined_job_count = False
|
self.processing_has_refined_job_count = False
|
||||||
@ -187,10 +194,13 @@ class State:
|
|||||||
self.interrupted = False
|
self.interrupted = False
|
||||||
self.textinfo = None
|
self.textinfo = None
|
||||||
self.time_start = time.time()
|
self.time_start = time.time()
|
||||||
|
self.job = job
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
log.info("Starting job %s", job)
|
||||||
|
|
||||||
def end(self):
|
def end(self):
|
||||||
|
duration = time.time() - self.time_start
|
||||||
|
log.info("Ending job %s (%.2f seconds)", self.job, duration)
|
||||||
self.job = ""
|
self.job = ""
|
||||||
self.job_count = 0
|
self.job_count = 0
|
||||||
|
|
||||||
@ -311,6 +321,10 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
|
|||||||
"grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"),
|
"grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"),
|
||||||
"grid_zip_filename_pattern": OptionInfo("", "Archive filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
|
"grid_zip_filename_pattern": OptionInfo("", "Archive filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
|
||||||
"n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
|
"n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
|
||||||
|
"font": OptionInfo("", "Font for image grids that have text"),
|
||||||
|
"grid_text_active_color": OptionInfo("#000000", "Text color for image grids", ui_components.FormColorPicker, {}),
|
||||||
|
"grid_text_inactive_color": OptionInfo("#999999", "Inactive text color for image grids", ui_components.FormColorPicker, {}),
|
||||||
|
"grid_background_color": OptionInfo("#ffffff", "Background color for image grids", ui_components.FormColorPicker, {}),
|
||||||
|
|
||||||
"enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
|
"enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
|
||||||
"save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
|
"save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
|
||||||
@ -376,6 +390,8 @@ options_templates.update(options_section(('system', "System"), {
|
|||||||
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
|
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
|
||||||
"print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
|
"print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
|
||||||
"list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
|
"list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
|
||||||
|
"disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
|
||||||
|
"github_proxy": OptionInfo("None", "Github proxy", ui_components.DropdownEditable, lambda: {"choices": ["None", "ghproxy.com", "hub.yzuu.cf", "hub.njuu.cf", "hub.nuaa.cf"]}).info("for custom inputs will just replace github.com with the input"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('training', "Training"), {
|
options_templates.update(options_section(('training', "Training"), {
|
||||||
@ -470,7 +486,6 @@ options_templates.update(options_section(('ui', "User interface"), {
|
|||||||
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
|
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
|
||||||
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
|
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
|
||||||
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
|
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
|
||||||
"font": OptionInfo("", "Font for image grids that have text"),
|
|
||||||
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
|
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
|
||||||
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
|
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
|
||||||
"js_modal_lightbox_gamepad": OptionInfo(False, "Navigate image viewer with gamepad"),
|
"js_modal_lightbox_gamepad": OptionInfo(False, "Navigate image viewer with gamepad"),
|
||||||
@ -481,6 +496,7 @@ options_templates.update(options_section(('ui', "User interface"), {
|
|||||||
"keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
"keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
||||||
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
||||||
"keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
|
"keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
|
||||||
|
"keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"),
|
||||||
"quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_restart(),
|
"quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_restart(),
|
||||||
"ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
|
"ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
|
||||||
"hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
|
"hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
|
||||||
@ -493,6 +509,7 @@ options_templates.update(options_section(('ui', "User interface"), {
|
|||||||
options_templates.update(options_section(('infotext', "Infotext"), {
|
options_templates.update(options_section(('infotext', "Infotext"), {
|
||||||
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
|
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
|
||||||
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
|
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
|
||||||
|
"add_user_name_to_info": OptionInfo(False, "Add user name to generation information when authenticated"),
|
||||||
"add_version_to_infotext": OptionInfo(True, "Add program version to generation information"),
|
"add_version_to_infotext": OptionInfo(True, "Add program version to generation information"),
|
||||||
"disable_weights_auto_swap": OptionInfo(True, "Disregard checkpoint information from pasted infotext").info("when reading generation parameters from text into UI"),
|
"disable_weights_auto_swap": OptionInfo(True, "Disregard checkpoint information from pasted infotext").info("when reading generation parameters from text into UI"),
|
||||||
"infotext_styles": OptionInfo("Apply if any", "Infer styles from prompts of pasted infotext", gr.Radio, {"choices": ["Ignore", "Apply", "Discard", "Apply if any"]}).info("when reading generation parameters from text into UI)").html("""<ul style='margin-left: 1.5em'>
|
"infotext_styles": OptionInfo("Apply if any", "Infer styles from prompts of pasted infotext", gr.Radio, {"choices": ["Ignore", "Apply", "Discard", "Apply if any"]}).info("when reading generation parameters from text into UI)").html("""<ul style='margin-left: 1.5em'>
|
||||||
@ -817,8 +834,12 @@ mem_mon = modules.memmon.MemUsageMonitor("MemMon", device, opts)
|
|||||||
mem_mon.start()
|
mem_mon.start()
|
||||||
|
|
||||||
|
|
||||||
|
def natural_sort_key(s, regex=re.compile('([0-9]+)')):
|
||||||
|
return [int(text) if text.isdigit() else text.lower() for text in regex.split(s)]
|
||||||
|
|
||||||
|
|
||||||
def listfiles(dirname):
|
def listfiles(dirname):
|
||||||
filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname), key=str.lower) if not x.startswith(".")]
|
filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname), key=natural_sort_key) if not x.startswith(".")]
|
||||||
return [file for file in filenames if os.path.isfile(file)]
|
return [file for file in filenames if os.path.isfile(file)]
|
||||||
|
|
||||||
|
|
||||||
@ -843,8 +864,11 @@ def walk_files(path, allowed_extensions=None):
|
|||||||
if allowed_extensions is not None:
|
if allowed_extensions is not None:
|
||||||
allowed_extensions = set(allowed_extensions)
|
allowed_extensions = set(allowed_extensions)
|
||||||
|
|
||||||
for root, _, files in os.walk(path, followlinks=True):
|
items = list(os.walk(path, followlinks=True))
|
||||||
for filename in files:
|
items = sorted(items, key=lambda x: natural_sort_key(x[0]))
|
||||||
|
|
||||||
|
for root, _, files in items:
|
||||||
|
for filename in sorted(files, key=natural_sort_key):
|
||||||
if allowed_extensions is not None:
|
if allowed_extensions is not None:
|
||||||
_, ext = os.path.splitext(filename)
|
_, ext = os.path.splitext(filename)
|
||||||
if ext not in allowed_extensions:
|
if ext not in allowed_extensions:
|
||||||
|
@ -2,11 +2,51 @@ import datetime
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
|
||||||
saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "clip_grad_mode", "clip_grad_value", "gradient_step", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file", "gradient_step", "latent_sampling_method"}
|
saved_params_shared = {
|
||||||
saved_params_ti = {"embedding_name", "num_vectors_per_token", "save_embedding_every", "save_image_with_stored_embedding"}
|
"batch_size",
|
||||||
saved_params_hypernet = {"hypernetwork_name", "layer_structure", "activation_func", "weight_init", "add_layer_norm", "use_dropout", "save_hypernetwork_every"}
|
"clip_grad_mode",
|
||||||
|
"clip_grad_value",
|
||||||
|
"create_image_every",
|
||||||
|
"data_root",
|
||||||
|
"gradient_step",
|
||||||
|
"initial_step",
|
||||||
|
"latent_sampling_method",
|
||||||
|
"learn_rate",
|
||||||
|
"log_directory",
|
||||||
|
"model_hash",
|
||||||
|
"model_name",
|
||||||
|
"num_of_dataset_images",
|
||||||
|
"steps",
|
||||||
|
"template_file",
|
||||||
|
"training_height",
|
||||||
|
"training_width",
|
||||||
|
}
|
||||||
|
saved_params_ti = {
|
||||||
|
"embedding_name",
|
||||||
|
"num_vectors_per_token",
|
||||||
|
"save_embedding_every",
|
||||||
|
"save_image_with_stored_embedding",
|
||||||
|
}
|
||||||
|
saved_params_hypernet = {
|
||||||
|
"activation_func",
|
||||||
|
"add_layer_norm",
|
||||||
|
"hypernetwork_name",
|
||||||
|
"layer_structure",
|
||||||
|
"save_hypernetwork_every",
|
||||||
|
"use_dropout",
|
||||||
|
"weight_init",
|
||||||
|
}
|
||||||
saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet
|
saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet
|
||||||
saved_params_previews = {"preview_prompt", "preview_negative_prompt", "preview_steps", "preview_sampler_index", "preview_cfg_scale", "preview_seed", "preview_width", "preview_height"}
|
saved_params_previews = {
|
||||||
|
"preview_cfg_scale",
|
||||||
|
"preview_height",
|
||||||
|
"preview_negative_prompt",
|
||||||
|
"preview_prompt",
|
||||||
|
"preview_sampler_index",
|
||||||
|
"preview_seed",
|
||||||
|
"preview_steps",
|
||||||
|
"preview_width",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def save_settings_to_file(log_directory, all_params):
|
def save_settings_to_file(log_directory, all_params):
|
||||||
|
@ -7,7 +7,7 @@ from modules import paths, shared, images, deepbooru
|
|||||||
from modules.textual_inversion import autocrop
|
from modules.textual_inversion import autocrop
|
||||||
|
|
||||||
|
|
||||||
def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
|
def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.15, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
|
||||||
try:
|
try:
|
||||||
if process_caption:
|
if process_caption:
|
||||||
shared.interrogator.load()
|
shared.interrogator.load()
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
from contextlib import closing
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import tqdm
|
import tqdm
|
||||||
@ -584,8 +585,9 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
|
|||||||
|
|
||||||
preview_text = p.prompt
|
preview_text = p.prompt
|
||||||
|
|
||||||
processed = processing.process_images(p)
|
with closing(p):
|
||||||
image = processed.images[0] if len(processed.images) > 0 else None
|
processed = processing.process_images(p)
|
||||||
|
image = processed.images[0] if len(processed.images) > 0 else None
|
||||||
|
|
||||||
if unload:
|
if unload:
|
||||||
shared.sd_model.first_stage_model.to(devices.cpu)
|
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||||
|
@ -1,13 +1,15 @@
|
|||||||
|
from contextlib import closing
|
||||||
|
|
||||||
import modules.scripts
|
import modules.scripts
|
||||||
from modules import sd_samplers, processing
|
from modules import sd_samplers, processing
|
||||||
from modules.generation_parameters_copypaste import create_override_settings_dict
|
from modules.generation_parameters_copypaste import create_override_settings_dict
|
||||||
from modules.shared import opts, cmd_opts
|
from modules.shared import opts, cmd_opts
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules.ui import plaintext_to_html
|
from modules.ui import plaintext_to_html
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
|
||||||
|
def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
|
||||||
def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args):
|
|
||||||
override_settings = create_override_settings_dict(override_settings_texts)
|
override_settings = create_override_settings_dict(override_settings_texts)
|
||||||
|
|
||||||
p = processing.StableDiffusionProcessingTxt2Img(
|
p = processing.StableDiffusionProcessingTxt2Img(
|
||||||
@ -48,15 +50,16 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
|
|||||||
p.scripts = modules.scripts.scripts_txt2img
|
p.scripts = modules.scripts.scripts_txt2img
|
||||||
p.script_args = args
|
p.script_args = args
|
||||||
|
|
||||||
|
p.user = request.username
|
||||||
|
|
||||||
if cmd_opts.enable_console_prompts:
|
if cmd_opts.enable_console_prompts:
|
||||||
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
|
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
|
||||||
|
|
||||||
processed = modules.scripts.scripts_txt2img.run(p, *args)
|
with closing(p):
|
||||||
|
processed = modules.scripts.scripts_txt2img.run(p, *args)
|
||||||
|
|
||||||
if processed is None:
|
if processed is None:
|
||||||
processed = processing.process_images(p)
|
processed = processing.process_images(p)
|
||||||
|
|
||||||
p.close()
|
|
||||||
|
|
||||||
shared.total_tqdm.clear()
|
shared.total_tqdm.clear()
|
||||||
|
|
||||||
|
@ -155,7 +155,7 @@ def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_di
|
|||||||
img = Image.open(image)
|
img = Image.open(image)
|
||||||
filename = os.path.basename(image)
|
filename = os.path.basename(image)
|
||||||
left, _ = os.path.splitext(filename)
|
left, _ = os.path.splitext(filename)
|
||||||
print(interrogation_function(img), file=open(os.path.join(ii_output_dir, f"{left}.txt"), 'a'))
|
print(interrogation_function(img), file=open(os.path.join(ii_output_dir, f"{left}.txt"), 'a', encoding='utf-8'))
|
||||||
|
|
||||||
return [gr.update(), None]
|
return [gr.update(), None]
|
||||||
|
|
||||||
@ -733,6 +733,10 @@ def create_ui():
|
|||||||
img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
|
img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
|
||||||
img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
|
img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
|
||||||
img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir")
|
img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir")
|
||||||
|
with gr.Accordion("PNG info", open=False):
|
||||||
|
img2img_batch_use_png_info = gr.Checkbox(label="Append png info to prompts", **shared.hide_dirs, elem_id="img2img_batch_use_png_info")
|
||||||
|
img2img_batch_png_info_dir = gr.Textbox(label="PNG info directory", **shared.hide_dirs, placeholder="Leave empty to use input directory", elem_id="img2img_batch_png_info_dir")
|
||||||
|
img2img_batch_png_info_props = gr.CheckboxGroup(["Prompt", "Negative prompt", "Seed", "CFG scale", "Sampler", "Steps"], label="Parameters to take from png info", info="Prompts from png info will be appended to prompts set in ui.")
|
||||||
|
|
||||||
img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]
|
img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]
|
||||||
|
|
||||||
@ -773,7 +777,7 @@ def create_ui():
|
|||||||
selected_scale_tab = gr.State(value=0)
|
selected_scale_tab = gr.State(value=0)
|
||||||
|
|
||||||
with gr.Tabs():
|
with gr.Tabs():
|
||||||
with gr.Tab(label="Resize to") as tab_scale_to:
|
with gr.Tab(label="Resize to", elem_id="img2img_tab_resize_to") as tab_scale_to:
|
||||||
with FormRow():
|
with FormRow():
|
||||||
with gr.Column(elem_id="img2img_column_size", scale=4):
|
with gr.Column(elem_id="img2img_column_size", scale=4):
|
||||||
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
|
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
|
||||||
@ -782,7 +786,7 @@ def create_ui():
|
|||||||
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
|
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
|
||||||
detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn")
|
detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn")
|
||||||
|
|
||||||
with gr.Tab(label="Resize by") as tab_scale_by:
|
with gr.Tab(label="Resize by", elem_id="img2img_tab_resize_by") as tab_scale_by:
|
||||||
scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale")
|
scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale")
|
||||||
|
|
||||||
with FormRow():
|
with FormRow():
|
||||||
@ -934,6 +938,9 @@ def create_ui():
|
|||||||
img2img_batch_output_dir,
|
img2img_batch_output_dir,
|
||||||
img2img_batch_inpaint_mask_dir,
|
img2img_batch_inpaint_mask_dir,
|
||||||
override_settings,
|
override_settings,
|
||||||
|
img2img_batch_use_png_info,
|
||||||
|
img2img_batch_png_info_props,
|
||||||
|
img2img_batch_png_info_dir,
|
||||||
] + custom_inputs,
|
] + custom_inputs,
|
||||||
outputs=[
|
outputs=[
|
||||||
img2img_gallery,
|
img2img_gallery,
|
||||||
|
@ -138,7 +138,10 @@ def extension_table():
|
|||||||
<table id="extensions">
|
<table id="extensions">
|
||||||
<thead>
|
<thead>
|
||||||
<tr>
|
<tr>
|
||||||
<th><abbr title="Use checkbox to enable the extension; it will be enabled or disabled when you click apply button">Extension</abbr></th>
|
<th>
|
||||||
|
<input class="gr-check-radio gr-checkbox all_extensions_toggle" type="checkbox" {'checked="checked"' if all(ext.enabled for ext in extensions.extensions) else ''} onchange="toggle_all_extensions(event)" />
|
||||||
|
<abbr title="Use checkbox to enable the extension; it will be enabled or disabled when you click apply button">Extension</abbr>
|
||||||
|
</th>
|
||||||
<th>URL</th>
|
<th>URL</th>
|
||||||
<th>Branch</th>
|
<th>Branch</th>
|
||||||
<th>Version</th>
|
<th>Version</th>
|
||||||
@ -170,7 +173,7 @@ def extension_table():
|
|||||||
|
|
||||||
code += f"""
|
code += f"""
|
||||||
<tr>
|
<tr>
|
||||||
<td><label{style}><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td>
|
<td><label{style}><input class="gr-check-radio gr-checkbox extension_toggle" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''} onchange="toggle_extension(event)" />{html.escape(ext.name)}</label></td>
|
||||||
<td>{remote}</td>
|
<td>{remote}</td>
|
||||||
<td>{ext.branch}</td>
|
<td>{ext.branch}</td>
|
||||||
<td>{version_link}</td>
|
<td>{version_link}</td>
|
||||||
@ -322,6 +325,17 @@ def normalize_git_url(url):
|
|||||||
return url
|
return url
|
||||||
|
|
||||||
|
|
||||||
|
def github_proxy(url):
|
||||||
|
proxy = shared.opts.github_proxy
|
||||||
|
|
||||||
|
if proxy == 'None':
|
||||||
|
return url
|
||||||
|
if proxy == 'ghproxy.com':
|
||||||
|
return "https://ghproxy.com/" + url
|
||||||
|
|
||||||
|
return url.replace('github.com', proxy)
|
||||||
|
|
||||||
|
|
||||||
def install_extension_from_url(dirname, url, branch_name=None):
|
def install_extension_from_url(dirname, url, branch_name=None):
|
||||||
check_access()
|
check_access()
|
||||||
|
|
||||||
@ -332,6 +346,8 @@ def install_extension_from_url(dirname, url, branch_name=None):
|
|||||||
|
|
||||||
assert url, 'No URL specified'
|
assert url, 'No URL specified'
|
||||||
|
|
||||||
|
url = github_proxy(url)
|
||||||
|
|
||||||
if dirname is None or dirname == "":
|
if dirname is None or dirname == "":
|
||||||
*parts, last_part = url.split('/')
|
*parts, last_part = url.split('/')
|
||||||
last_part = normalize_git_url(last_part)
|
last_part = normalize_git_url(last_part)
|
||||||
@ -351,12 +367,12 @@ def install_extension_from_url(dirname, url, branch_name=None):
|
|||||||
shutil.rmtree(tmpdir, True)
|
shutil.rmtree(tmpdir, True)
|
||||||
if not branch_name:
|
if not branch_name:
|
||||||
# if no branch is specified, use the default branch
|
# if no branch is specified, use the default branch
|
||||||
with git.Repo.clone_from(url, tmpdir, filter=['blob:none']) as repo:
|
with git.Repo.clone_from(url, tmpdir, filter=['blob:none'], verbose=False) as repo:
|
||||||
repo.remote().fetch()
|
repo.remote().fetch()
|
||||||
for submodule in repo.submodules:
|
for submodule in repo.submodules:
|
||||||
submodule.update()
|
submodule.update()
|
||||||
else:
|
else:
|
||||||
with git.Repo.clone_from(url, tmpdir, filter=['blob:none'], branch=branch_name) as repo:
|
with git.Repo.clone_from(url, tmpdir, filter=['blob:none'], branch=branch_name, verbose=False) as repo:
|
||||||
repo.remote().fetch()
|
repo.remote().fetch()
|
||||||
for submodule in repo.submodules:
|
for submodule in repo.submodules:
|
||||||
submodule.update()
|
submodule.update()
|
||||||
@ -421,9 +437,19 @@ sort_ordering = [
|
|||||||
(False, lambda x: x.get('name', 'z')),
|
(False, lambda x: x.get('name', 'z')),
|
||||||
(True, lambda x: x.get('name', 'z')),
|
(True, lambda x: x.get('name', 'z')),
|
||||||
(False, lambda x: 'z'),
|
(False, lambda x: 'z'),
|
||||||
|
(True, lambda x: x.get('commit_time', '')),
|
||||||
|
(True, lambda x: x.get('created_at', '')),
|
||||||
|
(True, lambda x: x.get('stars', 0)),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def get_date(info: dict, key):
|
||||||
|
try:
|
||||||
|
return datetime.strptime(info.get(key), "%Y-%m-%dT%H:%M:%SZ").strftime("%Y-%m-%d")
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
return ''
|
||||||
|
|
||||||
|
|
||||||
def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=""):
|
def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=""):
|
||||||
extlist = available_extensions["extensions"]
|
extlist = available_extensions["extensions"]
|
||||||
installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions}
|
installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions}
|
||||||
@ -448,7 +474,10 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text="
|
|||||||
|
|
||||||
for ext in sorted(extlist, key=sort_function, reverse=sort_reverse):
|
for ext in sorted(extlist, key=sort_function, reverse=sort_reverse):
|
||||||
name = ext.get("name", "noname")
|
name = ext.get("name", "noname")
|
||||||
|
stars = int(ext.get("stars", 0))
|
||||||
added = ext.get('added', 'unknown')
|
added = ext.get('added', 'unknown')
|
||||||
|
update_time = get_date(ext, 'commit_time')
|
||||||
|
create_time = get_date(ext, 'created_at')
|
||||||
url = ext.get("url", None)
|
url = ext.get("url", None)
|
||||||
description = ext.get("description", "")
|
description = ext.get("description", "")
|
||||||
extension_tags = ext.get("tags", [])
|
extension_tags = ext.get("tags", [])
|
||||||
@ -475,7 +504,8 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text="
|
|||||||
code += f"""
|
code += f"""
|
||||||
<tr>
|
<tr>
|
||||||
<td><a href="{html.escape(url)}" target="_blank">{html.escape(name)}</a><br />{tags_text}</td>
|
<td><a href="{html.escape(url)}" target="_blank">{html.escape(name)}</a><br />{tags_text}</td>
|
||||||
<td>{html.escape(description)}<p class="info"><span class="date_added">Added: {html.escape(added)}</span></p></td>
|
<td>{html.escape(description)}<p class="info">
|
||||||
|
<span class="date_added">Update: {html.escape(update_time)} Added: {html.escape(added)} Created: {html.escape(create_time)}</span><span class="star_count">stars: <b>{stars}</b></a></p></td>
|
||||||
<td>{install_code}</td>
|
<td>{install_code}</td>
|
||||||
</tr>
|
</tr>
|
||||||
|
|
||||||
@ -559,7 +589,7 @@ def create_ui():
|
|||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"])
|
hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"])
|
||||||
sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order", ], type="index")
|
sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order",'update time', 'create time', "stars"], type="index")
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
search_extensions_text = gr.Text(label="Search").style(container=False)
|
search_extensions_text = gr.Text(label="Search").style(container=False)
|
||||||
@ -568,9 +598,9 @@ def create_ui():
|
|||||||
available_extensions_table = gr.HTML()
|
available_extensions_table = gr.HTML()
|
||||||
|
|
||||||
refresh_available_extensions_button.click(
|
refresh_available_extensions_button.click(
|
||||||
fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update()]),
|
fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update(), gr.update()]),
|
||||||
inputs=[available_extensions_index, hide_tags, sort_column],
|
inputs=[available_extensions_index, hide_tags, sort_column],
|
||||||
outputs=[available_extensions_index, available_extensions_table, hide_tags, install_result, search_extensions_text],
|
outputs=[available_extensions_index, available_extensions_table, hide_tags, search_extensions_text, install_result],
|
||||||
)
|
)
|
||||||
|
|
||||||
install_extension_button.click(
|
install_extension_button.click(
|
||||||
|
@ -30,8 +30,8 @@ def fetch_file(filename: str = ""):
|
|||||||
raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
|
raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
|
||||||
|
|
||||||
ext = os.path.splitext(filename)[1].lower()
|
ext = os.path.splitext(filename)[1].lower()
|
||||||
if ext not in (".png", ".jpg", ".jpeg", ".webp"):
|
if ext not in (".png", ".jpg", ".jpeg", ".webp", ".gif"):
|
||||||
raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg and webp.")
|
raise ValueError(f"File cannot be fetched: {filename}. Only png, jpg, webp, and gif.")
|
||||||
|
|
||||||
# would profit from returning 304
|
# would profit from returning 304
|
||||||
return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
|
return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
|
||||||
@ -90,8 +90,8 @@ class ExtraNetworksPage:
|
|||||||
|
|
||||||
subdirs = {}
|
subdirs = {}
|
||||||
for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]:
|
for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]:
|
||||||
for root, dirs, _ in os.walk(parentdir, followlinks=True):
|
for root, dirs, _ in sorted(os.walk(parentdir, followlinks=True), key=lambda x: shared.natural_sort_key(x[0])):
|
||||||
for dirname in dirs:
|
for dirname in sorted(dirs, key=shared.natural_sort_key):
|
||||||
x = os.path.join(root, dirname)
|
x = os.path.join(root, dirname)
|
||||||
|
|
||||||
if not os.path.isdir(x):
|
if not os.path.isdir(x):
|
||||||
|
17
style.css
17
style.css
@ -704,11 +704,24 @@ table.popup-table .link{
|
|||||||
margin: 0;
|
margin: 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
#available_extensions .date_added{
|
#available_extensions .info{
|
||||||
opacity: 0.85;
|
margin: 0.5em 0;
|
||||||
|
display: flex;
|
||||||
|
margin-top: auto;
|
||||||
|
opacity: 0.80;
|
||||||
font-size: 90%;
|
font-size: 90%;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#available_extensions .date_added{
|
||||||
|
margin-right: auto;
|
||||||
|
display: inline-block;
|
||||||
|
}
|
||||||
|
|
||||||
|
#available_extensions .star_count{
|
||||||
|
margin-left: auto;
|
||||||
|
display: inline-block;
|
||||||
|
}
|
||||||
|
|
||||||
/* replace original footer with ours */
|
/* replace original footer with ours */
|
||||||
|
|
||||||
footer {
|
footer {
|
||||||
|
34
webui.py
34
webui.py
@ -11,13 +11,24 @@ import json
|
|||||||
from threading import Thread
|
from threading import Thread
|
||||||
from typing import Iterable
|
from typing import Iterable
|
||||||
|
|
||||||
from fastapi import FastAPI, Response
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.middleware.gzip import GZipMiddleware
|
from fastapi.middleware.gzip import GZipMiddleware
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
# We can't use cmd_opts for this because it will not have been initialized at this point.
|
||||||
|
log_level = os.environ.get("SD_WEBUI_LOG_LEVEL")
|
||||||
|
if log_level:
|
||||||
|
log_level = getattr(logging, log_level.upper(), None) or logging.INFO
|
||||||
|
logging.basicConfig(
|
||||||
|
level=log_level,
|
||||||
|
format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
|
||||||
|
datefmt='%Y-%m-%d %H:%M:%S',
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh...
|
||||||
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
||||||
|
|
||||||
from modules import paths, timer, import_hook, errors, devices # noqa: F401
|
from modules import paths, timer, import_hook, errors, devices # noqa: F401
|
||||||
@ -32,7 +43,7 @@ warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvisi
|
|||||||
|
|
||||||
startup_timer.record("import torch")
|
startup_timer.record("import torch")
|
||||||
|
|
||||||
import gradio
|
import gradio # noqa: F401
|
||||||
startup_timer.record("import gradio")
|
startup_timer.record("import gradio")
|
||||||
|
|
||||||
import ldm.modules.encoders.modules # noqa: F401
|
import ldm.modules.encoders.modules # noqa: F401
|
||||||
@ -359,12 +370,11 @@ def api_only():
|
|||||||
modules.script_callbacks.app_started_callback(None, app)
|
modules.script_callbacks.app_started_callback(None, app)
|
||||||
|
|
||||||
print(f"Startup time: {startup_timer.summary()}.")
|
print(f"Startup time: {startup_timer.summary()}.")
|
||||||
api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
|
api.launch(
|
||||||
|
server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1",
|
||||||
|
port=cmd_opts.port if cmd_opts.port else 7861,
|
||||||
def stop_route(request):
|
root_path = f"/{cmd_opts.subpath}"
|
||||||
shared.state.server_command = "stop"
|
)
|
||||||
return Response("Stopping.")
|
|
||||||
|
|
||||||
|
|
||||||
def webui():
|
def webui():
|
||||||
@ -403,9 +413,8 @@ def webui():
|
|||||||
"docs_url": "/docs",
|
"docs_url": "/docs",
|
||||||
"redoc_url": "/redoc",
|
"redoc_url": "/redoc",
|
||||||
},
|
},
|
||||||
|
root_path=f"/{cmd_opts.subpath}" if cmd_opts.subpath else "",
|
||||||
)
|
)
|
||||||
if cmd_opts.add_stop_route:
|
|
||||||
app.add_route("/_stop", stop_route, methods=["POST"])
|
|
||||||
|
|
||||||
# after initial launch, disable --autolaunch for subsequent restarts
|
# after initial launch, disable --autolaunch for subsequent restarts
|
||||||
cmd_opts.autolaunch = False
|
cmd_opts.autolaunch = False
|
||||||
@ -436,11 +445,6 @@ def webui():
|
|||||||
timer.startup_record = startup_timer.dump()
|
timer.startup_record = startup_timer.dump()
|
||||||
print(f"Startup time: {startup_timer.summary()}.")
|
print(f"Startup time: {startup_timer.summary()}.")
|
||||||
|
|
||||||
if cmd_opts.subpath:
|
|
||||||
redirector = FastAPI()
|
|
||||||
redirector.get("/")
|
|
||||||
gradio.mount_gradio_app(redirector, shared.demo, path=f"/{cmd_opts.subpath}")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
server_command = shared.state.wait_for_server_command(timeout=5)
|
server_command = shared.state.wait_for_server_command(timeout=5)
|
||||||
|
16
webui.sh
16
webui.sh
@ -4,26 +4,28 @@
|
|||||||
# change the variables in webui-user.sh instead #
|
# change the variables in webui-user.sh instead #
|
||||||
#################################################
|
#################################################
|
||||||
|
|
||||||
|
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
|
||||||
|
|
||||||
# If run from macOS, load defaults from webui-macos-env.sh
|
# If run from macOS, load defaults from webui-macos-env.sh
|
||||||
if [[ "$OSTYPE" == "darwin"* ]]; then
|
if [[ "$OSTYPE" == "darwin"* ]]; then
|
||||||
if [[ -f webui-macos-env.sh ]]
|
if [[ -f "$SCRIPT_DIR"/webui-macos-env.sh ]]
|
||||||
then
|
then
|
||||||
source ./webui-macos-env.sh
|
source "$SCRIPT_DIR"/webui-macos-env.sh
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Read variables from webui-user.sh
|
# Read variables from webui-user.sh
|
||||||
# shellcheck source=/dev/null
|
# shellcheck source=/dev/null
|
||||||
if [[ -f webui-user.sh ]]
|
if [[ -f "$SCRIPT_DIR"/webui-user.sh ]]
|
||||||
then
|
then
|
||||||
source ./webui-user.sh
|
source "$SCRIPT_DIR"/webui-user.sh
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Set defaults
|
# Set defaults
|
||||||
# Install directory without trailing slash
|
# Install directory without trailing slash
|
||||||
if [[ -z "${install_dir}" ]]
|
if [[ -z "${install_dir}" ]]
|
||||||
then
|
then
|
||||||
install_dir="$(pwd)"
|
install_dir="$SCRIPT_DIR"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Name of the subdirectory (defaults to stable-diffusion-webui)
|
# Name of the subdirectory (defaults to stable-diffusion-webui)
|
||||||
@ -131,6 +133,10 @@ case "$gpu_info" in
|
|||||||
;;
|
;;
|
||||||
*"Navi 2"*) export HSA_OVERRIDE_GFX_VERSION=10.3.0
|
*"Navi 2"*) export HSA_OVERRIDE_GFX_VERSION=10.3.0
|
||||||
;;
|
;;
|
||||||
|
*"Navi 3"*) [[ -z "${TORCH_COMMAND}" ]] && \
|
||||||
|
export TORCH_COMMAND="pip install --pre torch==2.1.0.dev-20230614+rocm5.5 torchvision==0.16.0.dev-20230614+rocm5.5 --index-url https://download.pytorch.org/whl/nightly/rocm5.5"
|
||||||
|
# Navi 3 needs at least 5.5 which is only on the nightly chain
|
||||||
|
;;
|
||||||
*"Renoir"*) export HSA_OVERRIDE_GFX_VERSION=9.0.0
|
*"Renoir"*) export HSA_OVERRIDE_GFX_VERSION=9.0.0
|
||||||
printf "\n%s\n" "${delimiter}"
|
printf "\n%s\n" "${delimiter}"
|
||||||
printf "Experimental support for Renoir: make sure to have at least 4GB of VRAM and 10GB of RAM or enable cpu mode: --use-cpu all --no-half"
|
printf "Experimental support for Renoir: make sure to have at least 4GB of VRAM and 10GB of RAM or enable cpu mode: --use-cpu all --no-half"
|
||||||
|
Loading…
Reference in New Issue
Block a user