manual fixes for ruff
This commit is contained in:
parent
762265eab5
commit
96d6ca4199
@ -243,7 +243,7 @@ def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize
|
|||||||
x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
|
x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
|
||||||
log["sample_noquant"] = x_sample_noquant
|
log["sample_noquant"] = x_sample_noquant
|
||||||
log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
|
log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
log["sample"] = x_sample
|
log["sample"] = x_sample
|
||||||
|
@ -7,7 +7,8 @@ from basicsr.utils.download_util import load_file_from_url
|
|||||||
from modules.upscaler import Upscaler, UpscalerData
|
from modules.upscaler import Upscaler, UpscalerData
|
||||||
from ldsr_model_arch import LDSR
|
from ldsr_model_arch import LDSR
|
||||||
from modules import shared, script_callbacks
|
from modules import shared, script_callbacks
|
||||||
import sd_hijack_autoencoder, sd_hijack_ddpm_v1
|
import sd_hijack_autoencoder
|
||||||
|
import sd_hijack_ddpm_v1
|
||||||
|
|
||||||
|
|
||||||
class UpscalerLDSR(Upscaler):
|
class UpscalerLDSR(Upscaler):
|
||||||
|
@ -1,16 +1,21 @@
|
|||||||
# The content of this file comes from the ldm/models/autoencoder.py file of the compvis/stable-diffusion repo
|
# The content of this file comes from the ldm/models/autoencoder.py file of the compvis/stable-diffusion repo
|
||||||
# The VQModel & VQModelInterface were subsequently removed from ldm/models/autoencoder.py when we moved to the stability-ai/stablediffusion repo
|
# The VQModel & VQModelInterface were subsequently removed from ldm/models/autoencoder.py when we moved to the stability-ai/stablediffusion repo
|
||||||
# As the LDSR upscaler relies on VQModel & VQModelInterface, the hijack aims to put them back into the ldm.models.autoencoder
|
# As the LDSR upscaler relies on VQModel & VQModelInterface, the hijack aims to put them back into the ldm.models.autoencoder
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
from torch.optim.lr_scheduler import LambdaLR
|
||||||
|
|
||||||
|
from ldm.modules.ema import LitEma
|
||||||
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
|
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
|
||||||
from ldm.modules.diffusionmodules.model import Encoder, Decoder
|
from ldm.modules.diffusionmodules.model import Encoder, Decoder
|
||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
|
|
||||||
import ldm.models.autoencoder
|
import ldm.models.autoencoder
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
class VQModel(pl.LightningModule):
|
class VQModel(pl.LightningModule):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@ -249,7 +254,8 @@ class VQModel(pl.LightningModule):
|
|||||||
if plot_ema:
|
if plot_ema:
|
||||||
with self.ema_scope():
|
with self.ema_scope():
|
||||||
xrec_ema, _ = self(x)
|
xrec_ema, _ = self(x)
|
||||||
if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
|
if x.shape[1] > 3:
|
||||||
|
xrec_ema = self.to_rgb(xrec_ema)
|
||||||
log["reconstructions_ema"] = xrec_ema
|
log["reconstructions_ema"] = xrec_ema
|
||||||
return log
|
return log
|
||||||
|
|
||||||
|
@ -450,7 +450,7 @@ class LatentDiffusionV1(DDPMV1):
|
|||||||
self.cond_stage_key = cond_stage_key
|
self.cond_stage_key = cond_stage_key
|
||||||
try:
|
try:
|
||||||
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
|
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
|
||||||
except:
|
except Exception:
|
||||||
self.num_downs = 0
|
self.num_downs = 0
|
||||||
if not scale_by_std:
|
if not scale_by_std:
|
||||||
self.scale_factor = scale_factor
|
self.scale_factor = scale_factor
|
||||||
@ -877,16 +877,6 @@ class LatentDiffusionV1(DDPMV1):
|
|||||||
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
|
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
|
||||||
return self.p_losses(x, c, t, *args, **kwargs)
|
return self.p_losses(x, c, t, *args, **kwargs)
|
||||||
|
|
||||||
def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
|
|
||||||
def rescale_bbox(bbox):
|
|
||||||
x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
|
|
||||||
y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
|
|
||||||
w = min(bbox[2] / crop_coordinates[2], 1 - x0)
|
|
||||||
h = min(bbox[3] / crop_coordinates[3], 1 - y0)
|
|
||||||
return x0, y0, w, h
|
|
||||||
|
|
||||||
return [rescale_bbox(b) for b in bboxes]
|
|
||||||
|
|
||||||
def apply_model(self, x_noisy, t, cond, return_ids=False):
|
def apply_model(self, x_noisy, t, cond, return_ids=False):
|
||||||
|
|
||||||
if isinstance(cond, dict):
|
if isinstance(cond, dict):
|
||||||
@ -1157,8 +1147,10 @@ class LatentDiffusionV1(DDPMV1):
|
|||||||
|
|
||||||
if i % log_every_t == 0 or i == timesteps - 1:
|
if i % log_every_t == 0 or i == timesteps - 1:
|
||||||
intermediates.append(x0_partial)
|
intermediates.append(x0_partial)
|
||||||
if callback: callback(i)
|
if callback:
|
||||||
if img_callback: img_callback(img, i)
|
callback(i)
|
||||||
|
if img_callback:
|
||||||
|
img_callback(img, i)
|
||||||
return img, intermediates
|
return img, intermediates
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@ -1205,8 +1197,10 @@ class LatentDiffusionV1(DDPMV1):
|
|||||||
|
|
||||||
if i % log_every_t == 0 or i == timesteps - 1:
|
if i % log_every_t == 0 or i == timesteps - 1:
|
||||||
intermediates.append(img)
|
intermediates.append(img)
|
||||||
if callback: callback(i)
|
if callback:
|
||||||
if img_callback: img_callback(img, i)
|
callback(i)
|
||||||
|
if img_callback:
|
||||||
|
img_callback(img, i)
|
||||||
|
|
||||||
if return_intermediates:
|
if return_intermediates:
|
||||||
return img, intermediates
|
return img, intermediates
|
||||||
@ -1322,7 +1316,7 @@ class LatentDiffusionV1(DDPMV1):
|
|||||||
|
|
||||||
if inpaint:
|
if inpaint:
|
||||||
# make a simple center square
|
# make a simple center square
|
||||||
b, h, w = z.shape[0], z.shape[2], z.shape[3]
|
h, w = z.shape[2], z.shape[3]
|
||||||
mask = torch.ones(N, h, w).to(self.device)
|
mask = torch.ones(N, h, w).to(self.device)
|
||||||
# zeros will be filled in
|
# zeros will be filled in
|
||||||
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
|
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
|
||||||
|
@ -61,7 +61,9 @@ class WMSA(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
output: tensor shape [b h w c]
|
output: tensor shape [b h w c]
|
||||||
"""
|
"""
|
||||||
if self.type != 'W': x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2))
|
if self.type != 'W':
|
||||||
|
x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2))
|
||||||
|
|
||||||
x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size)
|
x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size)
|
||||||
h_windows = x.size(1)
|
h_windows = x.size(1)
|
||||||
w_windows = x.size(2)
|
w_windows = x.size(2)
|
||||||
@ -85,8 +87,9 @@ class WMSA(nn.Module):
|
|||||||
output = self.linear(output)
|
output = self.linear(output)
|
||||||
output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size)
|
output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size)
|
||||||
|
|
||||||
if self.type != 'W': output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2),
|
if self.type != 'W':
|
||||||
dims=(1, 2))
|
output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2), dims=(1, 2))
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def relative_embedding(self):
|
def relative_embedding(self):
|
||||||
|
@ -45,7 +45,7 @@ class UpscalerSwinIR(Upscaler):
|
|||||||
img = upscale(img, model)
|
img = upscale(img, model)
|
||||||
try:
|
try:
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
return img
|
return img
|
||||||
|
|
||||||
|
@ -15,7 +15,8 @@ from secrets import compare_digest
|
|||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing
|
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing
|
||||||
from modules.api.models import *
|
from modules.api import models
|
||||||
|
from modules.shared import opts
|
||||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
||||||
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
|
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
|
||||||
from modules.textual_inversion.preprocess import preprocess
|
from modules.textual_inversion.preprocess import preprocess
|
||||||
@ -25,20 +26,21 @@ from modules.sd_models import checkpoints_list, unload_model_weights, reload_mod
|
|||||||
from modules.sd_models_config import find_checkpoint_config_near_filename
|
from modules.sd_models_config import find_checkpoint_config_near_filename
|
||||||
from modules.realesrgan_model import get_realesrgan_models
|
from modules.realesrgan_model import get_realesrgan_models
|
||||||
from modules import devices
|
from modules import devices
|
||||||
from typing import List
|
from typing import Dict, List, Any
|
||||||
import piexif
|
import piexif
|
||||||
import piexif.helper
|
import piexif.helper
|
||||||
|
|
||||||
|
|
||||||
def upscaler_to_index(name: str):
|
def upscaler_to_index(name: str):
|
||||||
try:
|
try:
|
||||||
return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
|
return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
|
||||||
except:
|
except Exception:
|
||||||
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in sd_upscalers])}")
|
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in shared.sd_upscalers])}")
|
||||||
|
|
||||||
def script_name_to_index(name, scripts):
|
def script_name_to_index(name, scripts):
|
||||||
try:
|
try:
|
||||||
return [script.title().lower() for script in scripts].index(name.lower())
|
return [script.title().lower() for script in scripts].index(name.lower())
|
||||||
except:
|
except Exception:
|
||||||
raise HTTPException(status_code=422, detail=f"Script '{name}' not found")
|
raise HTTPException(status_code=422, detail=f"Script '{name}' not found")
|
||||||
|
|
||||||
def validate_sampler_name(name):
|
def validate_sampler_name(name):
|
||||||
@ -99,7 +101,7 @@ def api_middleware(app: FastAPI):
|
|||||||
import starlette # importing just so it can be placed on silent list
|
import starlette # importing just so it can be placed on silent list
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
console = Console()
|
console = Console()
|
||||||
except:
|
except Exception:
|
||||||
import traceback
|
import traceback
|
||||||
rich_available = False
|
rich_available = False
|
||||||
|
|
||||||
@ -166,36 +168,36 @@ class Api:
|
|||||||
self.app = app
|
self.app = app
|
||||||
self.queue_lock = queue_lock
|
self.queue_lock = queue_lock
|
||||||
api_middleware(self.app)
|
api_middleware(self.app)
|
||||||
self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
|
self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=models.TextToImageResponse)
|
||||||
self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
|
self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=models.ImageToImageResponse)
|
||||||
self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
|
self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ExtrasSingleImageResponse)
|
||||||
self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
|
self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=models.ExtrasBatchImagesResponse)
|
||||||
self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
|
self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=models.PNGInfoResponse)
|
||||||
self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
|
self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=models.ProgressResponse)
|
||||||
self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
|
self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
|
||||||
self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
|
self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
|
||||||
self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
|
self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
|
||||||
self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
|
self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
|
||||||
self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
|
self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
|
||||||
self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel)
|
self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
|
||||||
self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem])
|
self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
|
||||||
self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem])
|
self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
|
||||||
self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem])
|
self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
|
||||||
self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem])
|
self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
|
||||||
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
|
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem])
|
||||||
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
|
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem])
|
||||||
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
|
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem])
|
||||||
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse)
|
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
|
||||||
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
|
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
|
||||||
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
|
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
|
||||||
self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=CreateResponse)
|
self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
|
||||||
self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=PreprocessResponse)
|
self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
|
||||||
self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
|
self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
|
||||||
self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
|
self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
|
||||||
self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=MemoryResponse)
|
self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
|
||||||
self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
|
self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
|
||||||
self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
|
self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
|
||||||
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=ScriptsList)
|
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
|
||||||
|
|
||||||
self.default_script_arg_txt2img = []
|
self.default_script_arg_txt2img = []
|
||||||
self.default_script_arg_img2img = []
|
self.default_script_arg_img2img = []
|
||||||
@ -224,7 +226,7 @@ class Api:
|
|||||||
t2ilist = [str(title.lower()) for title in scripts.scripts_txt2img.titles]
|
t2ilist = [str(title.lower()) for title in scripts.scripts_txt2img.titles]
|
||||||
i2ilist = [str(title.lower()) for title in scripts.scripts_img2img.titles]
|
i2ilist = [str(title.lower()) for title in scripts.scripts_img2img.titles]
|
||||||
|
|
||||||
return ScriptsList(txt2img = t2ilist, img2img = i2ilist)
|
return models.ScriptsList(txt2img=t2ilist, img2img=i2ilist)
|
||||||
|
|
||||||
def get_script(self, script_name, script_runner):
|
def get_script(self, script_name, script_runner):
|
||||||
if script_name is None or script_name == "":
|
if script_name is None or script_name == "":
|
||||||
@ -276,7 +278,7 @@ class Api:
|
|||||||
script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx]
|
script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx]
|
||||||
return script_args
|
return script_args
|
||||||
|
|
||||||
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
|
def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
|
||||||
script_runner = scripts.scripts_txt2img
|
script_runner = scripts.scripts_txt2img
|
||||||
if not script_runner.scripts:
|
if not script_runner.scripts:
|
||||||
script_runner.initialize_scripts(False)
|
script_runner.initialize_scripts(False)
|
||||||
@ -320,9 +322,9 @@ class Api:
|
|||||||
|
|
||||||
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
|
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
|
||||||
|
|
||||||
return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
|
return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
|
||||||
|
|
||||||
def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
|
def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI):
|
||||||
init_images = img2imgreq.init_images
|
init_images = img2imgreq.init_images
|
||||||
if init_images is None:
|
if init_images is None:
|
||||||
raise HTTPException(status_code=404, detail="Init image not found")
|
raise HTTPException(status_code=404, detail="Init image not found")
|
||||||
@ -381,9 +383,9 @@ class Api:
|
|||||||
img2imgreq.init_images = None
|
img2imgreq.init_images = None
|
||||||
img2imgreq.mask = None
|
img2imgreq.mask = None
|
||||||
|
|
||||||
return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
|
return models.ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
|
||||||
|
|
||||||
def extras_single_image_api(self, req: ExtrasSingleImageRequest):
|
def extras_single_image_api(self, req: models.ExtrasSingleImageRequest):
|
||||||
reqDict = setUpscalers(req)
|
reqDict = setUpscalers(req)
|
||||||
|
|
||||||
reqDict['image'] = decode_base64_to_image(reqDict['image'])
|
reqDict['image'] = decode_base64_to_image(reqDict['image'])
|
||||||
@ -391,9 +393,9 @@ class Api:
|
|||||||
with self.queue_lock:
|
with self.queue_lock:
|
||||||
result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
|
result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
|
||||||
|
|
||||||
return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
|
return models.ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
|
||||||
|
|
||||||
def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
|
def extras_batch_images_api(self, req: models.ExtrasBatchImagesRequest):
|
||||||
reqDict = setUpscalers(req)
|
reqDict = setUpscalers(req)
|
||||||
|
|
||||||
image_list = reqDict.pop('imageList', [])
|
image_list = reqDict.pop('imageList', [])
|
||||||
@ -402,15 +404,15 @@ class Api:
|
|||||||
with self.queue_lock:
|
with self.queue_lock:
|
||||||
result = postprocessing.run_extras(extras_mode=1, image_folder=image_folder, image="", input_dir="", output_dir="", save_output=False, **reqDict)
|
result = postprocessing.run_extras(extras_mode=1, image_folder=image_folder, image="", input_dir="", output_dir="", save_output=False, **reqDict)
|
||||||
|
|
||||||
return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
|
return models.ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
|
||||||
|
|
||||||
def pnginfoapi(self, req: PNGInfoRequest):
|
def pnginfoapi(self, req: models.PNGInfoRequest):
|
||||||
if(not req.image.strip()):
|
if(not req.image.strip()):
|
||||||
return PNGInfoResponse(info="")
|
return models.PNGInfoResponse(info="")
|
||||||
|
|
||||||
image = decode_base64_to_image(req.image.strip())
|
image = decode_base64_to_image(req.image.strip())
|
||||||
if image is None:
|
if image is None:
|
||||||
return PNGInfoResponse(info="")
|
return models.PNGInfoResponse(info="")
|
||||||
|
|
||||||
geninfo, items = images.read_info_from_image(image)
|
geninfo, items = images.read_info_from_image(image)
|
||||||
if geninfo is None:
|
if geninfo is None:
|
||||||
@ -418,13 +420,13 @@ class Api:
|
|||||||
|
|
||||||
items = {**{'parameters': geninfo}, **items}
|
items = {**{'parameters': geninfo}, **items}
|
||||||
|
|
||||||
return PNGInfoResponse(info=geninfo, items=items)
|
return models.PNGInfoResponse(info=geninfo, items=items)
|
||||||
|
|
||||||
def progressapi(self, req: ProgressRequest = Depends()):
|
def progressapi(self, req: models.ProgressRequest = Depends()):
|
||||||
# copy from check_progress_call of ui.py
|
# copy from check_progress_call of ui.py
|
||||||
|
|
||||||
if shared.state.job_count == 0:
|
if shared.state.job_count == 0:
|
||||||
return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)
|
return models.ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)
|
||||||
|
|
||||||
# avoid dividing zero
|
# avoid dividing zero
|
||||||
progress = 0.01
|
progress = 0.01
|
||||||
@ -446,9 +448,9 @@ class Api:
|
|||||||
if shared.state.current_image and not req.skip_current_image:
|
if shared.state.current_image and not req.skip_current_image:
|
||||||
current_image = encode_pil_to_base64(shared.state.current_image)
|
current_image = encode_pil_to_base64(shared.state.current_image)
|
||||||
|
|
||||||
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)
|
return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)
|
||||||
|
|
||||||
def interrogateapi(self, interrogatereq: InterrogateRequest):
|
def interrogateapi(self, interrogatereq: models.InterrogateRequest):
|
||||||
image_b64 = interrogatereq.image
|
image_b64 = interrogatereq.image
|
||||||
if image_b64 is None:
|
if image_b64 is None:
|
||||||
raise HTTPException(status_code=404, detail="Image not found")
|
raise HTTPException(status_code=404, detail="Image not found")
|
||||||
@ -465,7 +467,7 @@ class Api:
|
|||||||
else:
|
else:
|
||||||
raise HTTPException(status_code=404, detail="Model not found")
|
raise HTTPException(status_code=404, detail="Model not found")
|
||||||
|
|
||||||
return InterrogateResponse(caption=processed)
|
return models.InterrogateResponse(caption=processed)
|
||||||
|
|
||||||
def interruptapi(self):
|
def interruptapi(self):
|
||||||
shared.state.interrupt()
|
shared.state.interrupt()
|
||||||
@ -570,36 +572,36 @@ class Api:
|
|||||||
filename = create_embedding(**args) # create empty embedding
|
filename = create_embedding(**args) # create empty embedding
|
||||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
|
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return CreateResponse(info=f"create embedding filename: {filename}")
|
return models.CreateResponse(info=f"create embedding filename: {filename}")
|
||||||
except AssertionError as e:
|
except AssertionError as e:
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return TrainResponse(info=f"create embedding error: {e}")
|
return models.TrainResponse(info=f"create embedding error: {e}")
|
||||||
|
|
||||||
def create_hypernetwork(self, args: dict):
|
def create_hypernetwork(self, args: dict):
|
||||||
try:
|
try:
|
||||||
shared.state.begin()
|
shared.state.begin()
|
||||||
filename = create_hypernetwork(**args) # create empty embedding
|
filename = create_hypernetwork(**args) # create empty embedding
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return CreateResponse(info=f"create hypernetwork filename: {filename}")
|
return models.CreateResponse(info=f"create hypernetwork filename: {filename}")
|
||||||
except AssertionError as e:
|
except AssertionError as e:
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return TrainResponse(info=f"create hypernetwork error: {e}")
|
return models.TrainResponse(info=f"create hypernetwork error: {e}")
|
||||||
|
|
||||||
def preprocess(self, args: dict):
|
def preprocess(self, args: dict):
|
||||||
try:
|
try:
|
||||||
shared.state.begin()
|
shared.state.begin()
|
||||||
preprocess(**args) # quick operation unless blip/booru interrogation is enabled
|
preprocess(**args) # quick operation unless blip/booru interrogation is enabled
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return PreprocessResponse(info = 'preprocess complete')
|
return models.PreprocessResponse(info = 'preprocess complete')
|
||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return PreprocessResponse(info=f"preprocess error: invalid token: {e}")
|
return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}")
|
||||||
except AssertionError as e:
|
except AssertionError as e:
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return PreprocessResponse(info=f"preprocess error: {e}")
|
return models.PreprocessResponse(info=f"preprocess error: {e}")
|
||||||
except FileNotFoundError as e:
|
except FileNotFoundError as e:
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return PreprocessResponse(info=f'preprocess error: {e}')
|
return models.PreprocessResponse(info=f'preprocess error: {e}')
|
||||||
|
|
||||||
def train_embedding(self, args: dict):
|
def train_embedding(self, args: dict):
|
||||||
try:
|
try:
|
||||||
@ -617,10 +619,10 @@ class Api:
|
|||||||
if not apply_optimizations:
|
if not apply_optimizations:
|
||||||
sd_hijack.apply_optimizations()
|
sd_hijack.apply_optimizations()
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
|
return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
|
||||||
except AssertionError as msg:
|
except AssertionError as msg:
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return TrainResponse(info=f"train embedding error: {msg}")
|
return models.TrainResponse(info=f"train embedding error: {msg}")
|
||||||
|
|
||||||
def train_hypernetwork(self, args: dict):
|
def train_hypernetwork(self, args: dict):
|
||||||
try:
|
try:
|
||||||
@ -641,14 +643,15 @@ class Api:
|
|||||||
if not apply_optimizations:
|
if not apply_optimizations:
|
||||||
sd_hijack.apply_optimizations()
|
sd_hijack.apply_optimizations()
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
|
return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
|
||||||
except AssertionError:
|
except AssertionError:
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return TrainResponse(info=f"train embedding error: {error}")
|
return models.TrainResponse(info=f"train embedding error: {error}")
|
||||||
|
|
||||||
def get_memory(self):
|
def get_memory(self):
|
||||||
try:
|
try:
|
||||||
import os, psutil
|
import os
|
||||||
|
import psutil
|
||||||
process = psutil.Process(os.getpid())
|
process = psutil.Process(os.getpid())
|
||||||
res = process.memory_info() # only rss is cross-platform guaranteed so we dont rely on other values
|
res = process.memory_info() # only rss is cross-platform guaranteed so we dont rely on other values
|
||||||
ram_total = 100 * res.rss / process.memory_percent() # and total memory is calculated as actual value is not cross-platform safe
|
ram_total = 100 * res.rss / process.memory_percent() # and total memory is calculated as actual value is not cross-platform safe
|
||||||
@ -675,10 +678,10 @@ class Api:
|
|||||||
'events': warnings,
|
'events': warnings,
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
cuda = { 'error': 'unavailable' }
|
cuda = {'error': 'unavailable'}
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
cuda = { 'error': f'{err}' }
|
cuda = {'error': f'{err}'}
|
||||||
return MemoryResponse(ram = ram, cuda = cuda)
|
return models.MemoryResponse(ram=ram, cuda=cuda)
|
||||||
|
|
||||||
def launch(self, server_name, port):
|
def launch(self, server_name, port):
|
||||||
self.app.include_router(self.router)
|
self.app.include_router(self.router)
|
||||||
|
@ -223,8 +223,9 @@ for key in _options:
|
|||||||
if(_options[key].dest != 'help'):
|
if(_options[key].dest != 'help'):
|
||||||
flag = _options[key]
|
flag = _options[key]
|
||||||
_type = str
|
_type = str
|
||||||
if _options[key].default is not None: _type = type(_options[key].default)
|
if _options[key].default is not None:
|
||||||
flags.update({flag.dest: (_type,Field(default=flag.default, description=flag.help))})
|
_type = type(_options[key].default)
|
||||||
|
flags.update({flag.dest: (_type, Field(default=flag.default, description=flag.help))})
|
||||||
|
|
||||||
FlagsModel = create_model("Flags", **flags)
|
FlagsModel = create_model("Flags", **flags)
|
||||||
|
|
||||||
|
@ -7,7 +7,7 @@ from torch import nn, Tensor
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
|
|
||||||
from modules.codeformer.vqgan_arch import *
|
from modules.codeformer.vqgan_arch import VQAutoEncoder, ResBlock
|
||||||
from basicsr.utils import get_root_logger
|
from basicsr.utils import get_root_logger
|
||||||
from basicsr.utils.registry import ARCH_REGISTRY
|
from basicsr.utils.registry import ARCH_REGISTRY
|
||||||
|
|
||||||
|
@ -438,9 +438,11 @@ def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=
|
|||||||
padding = padding if pad_type == 'zero' else 0
|
padding = padding if pad_type == 'zero' else 0
|
||||||
|
|
||||||
if convtype=='PartialConv2D':
|
if convtype=='PartialConv2D':
|
||||||
|
from torchvision.ops import PartialConv2d # this is definitely not going to work, but PartialConv2d doesn't work anyway and this shuts up static analyzer
|
||||||
c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
||||||
dilation=dilation, bias=bias, groups=groups)
|
dilation=dilation, bias=bias, groups=groups)
|
||||||
elif convtype=='DeformConv2D':
|
elif convtype=='DeformConv2D':
|
||||||
|
from torchvision.ops import DeformConv2d # not tested
|
||||||
c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
||||||
dilation=dilation, bias=bias, groups=groups)
|
dilation=dilation, bias=bias, groups=groups)
|
||||||
elif convtype=='Conv3D':
|
elif convtype=='Conv3D':
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from modules import extra_networks, shared, extra_networks
|
from modules import extra_networks, shared
|
||||||
from modules.hypernetworks import hypernetwork
|
from modules.hypernetworks import hypernetwork
|
||||||
|
|
||||||
|
|
||||||
|
@ -472,9 +472,9 @@ def get_next_sequence_number(path, basename):
|
|||||||
prefix_length = len(basename)
|
prefix_length = len(basename)
|
||||||
for p in os.listdir(path):
|
for p in os.listdir(path):
|
||||||
if p.startswith(basename):
|
if p.startswith(basename):
|
||||||
l = os.path.splitext(p[prefix_length:])[0].split('-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
|
parts = os.path.splitext(p[prefix_length:])[0].split('-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
|
||||||
try:
|
try:
|
||||||
result = max(int(l[0]), result)
|
result = max(int(parts[0]), result)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -13,7 +13,6 @@ from modules.shared import opts, state
|
|||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
import modules.processing as processing
|
import modules.processing as processing
|
||||||
from modules.ui import plaintext_to_html
|
from modules.ui import plaintext_to_html
|
||||||
import modules.images as images
|
|
||||||
import modules.scripts
|
import modules.scripts
|
||||||
|
|
||||||
|
|
||||||
|
@ -11,7 +11,6 @@ import torch.hub
|
|||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from torchvision.transforms.functional import InterpolationMode
|
from torchvision.transforms.functional import InterpolationMode
|
||||||
|
|
||||||
import modules.shared as shared
|
|
||||||
from modules import devices, paths, shared, lowvram, modelloader, errors
|
from modules import devices, paths, shared, lowvram, modelloader, errors
|
||||||
|
|
||||||
blip_image_eval_size = 384
|
blip_image_eval_size = 384
|
||||||
|
@ -108,12 +108,12 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
|
|||||||
print(f"Moving {file} from {src_path} to {dest_path}.")
|
print(f"Moving {file} from {src_path} to {dest_path}.")
|
||||||
try:
|
try:
|
||||||
shutil.move(fullpath, dest_path)
|
shutil.move(fullpath, dest_path)
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
if len(os.listdir(src_path)) == 0:
|
if len(os.listdir(src_path)) == 0:
|
||||||
print(f"Removing empty folder: {src_path}")
|
print(f"Removing empty folder: {src_path}")
|
||||||
shutil.rmtree(src_path, True)
|
shutil.rmtree(src_path, True)
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@ -141,7 +141,7 @@ def load_upscalers():
|
|||||||
full_model = f"modules.{model_name}_model"
|
full_model = f"modules.{model_name}_model"
|
||||||
try:
|
try:
|
||||||
importlib.import_module(full_model)
|
importlib.import_module(full_model)
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
datas = []
|
datas = []
|
||||||
|
@ -479,7 +479,7 @@ class LatentDiffusion(DDPM):
|
|||||||
self.cond_stage_key = cond_stage_key
|
self.cond_stage_key = cond_stage_key
|
||||||
try:
|
try:
|
||||||
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
|
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
|
||||||
except:
|
except Exception:
|
||||||
self.num_downs = 0
|
self.num_downs = 0
|
||||||
if not scale_by_std:
|
if not scale_by_std:
|
||||||
self.scale_factor = scale_factor
|
self.scale_factor = scale_factor
|
||||||
@ -891,16 +891,6 @@ class LatentDiffusion(DDPM):
|
|||||||
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
|
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
|
||||||
return self.p_losses(x, c, t, *args, **kwargs)
|
return self.p_losses(x, c, t, *args, **kwargs)
|
||||||
|
|
||||||
def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
|
|
||||||
def rescale_bbox(bbox):
|
|
||||||
x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
|
|
||||||
y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
|
|
||||||
w = min(bbox[2] / crop_coordinates[2], 1 - x0)
|
|
||||||
h = min(bbox[3] / crop_coordinates[3], 1 - y0)
|
|
||||||
return x0, y0, w, h
|
|
||||||
|
|
||||||
return [rescale_bbox(b) for b in bboxes]
|
|
||||||
|
|
||||||
def apply_model(self, x_noisy, t, cond, return_ids=False):
|
def apply_model(self, x_noisy, t, cond, return_ids=False):
|
||||||
|
|
||||||
if isinstance(cond, dict):
|
if isinstance(cond, dict):
|
||||||
@ -1171,8 +1161,10 @@ class LatentDiffusion(DDPM):
|
|||||||
|
|
||||||
if i % log_every_t == 0 or i == timesteps - 1:
|
if i % log_every_t == 0 or i == timesteps - 1:
|
||||||
intermediates.append(x0_partial)
|
intermediates.append(x0_partial)
|
||||||
if callback: callback(i)
|
if callback:
|
||||||
if img_callback: img_callback(img, i)
|
callback(i)
|
||||||
|
if img_callback:
|
||||||
|
img_callback(img, i)
|
||||||
return img, intermediates
|
return img, intermediates
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@ -1219,8 +1211,10 @@ class LatentDiffusion(DDPM):
|
|||||||
|
|
||||||
if i % log_every_t == 0 or i == timesteps - 1:
|
if i % log_every_t == 0 or i == timesteps - 1:
|
||||||
intermediates.append(img)
|
intermediates.append(img)
|
||||||
if callback: callback(i)
|
if callback:
|
||||||
if img_callback: img_callback(img, i)
|
callback(i)
|
||||||
|
if img_callback:
|
||||||
|
img_callback(img, i)
|
||||||
|
|
||||||
if return_intermediates:
|
if return_intermediates:
|
||||||
return img, intermediates
|
return img, intermediates
|
||||||
@ -1337,7 +1331,7 @@ class LatentDiffusion(DDPM):
|
|||||||
|
|
||||||
if inpaint:
|
if inpaint:
|
||||||
# make a simple center square
|
# make a simple center square
|
||||||
b, h, w = z.shape[0], z.shape[2], z.shape[3]
|
h, w = z.shape[2], z.shape[3]
|
||||||
mask = torch.ones(N, h, w).to(self.device)
|
mask = torch.ones(N, h, w).to(self.device)
|
||||||
# zeros will be filled in
|
# zeros will be filled in
|
||||||
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
|
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
|
||||||
|
@ -54,7 +54,8 @@ class UniPCSampler(object):
|
|||||||
if conditioning is not None:
|
if conditioning is not None:
|
||||||
if isinstance(conditioning, dict):
|
if isinstance(conditioning, dict):
|
||||||
ctmp = conditioning[list(conditioning.keys())[0]]
|
ctmp = conditioning[list(conditioning.keys())[0]]
|
||||||
while isinstance(ctmp, list): ctmp = ctmp[0]
|
while isinstance(ctmp, list):
|
||||||
|
ctmp = ctmp[0]
|
||||||
cbs = ctmp.shape[0]
|
cbs = ctmp.shape[0]
|
||||||
if cbs != batch_size:
|
if cbs != batch_size:
|
||||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||||
|
@ -664,7 +664,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
if not shared.opts.dont_fix_second_order_samplers_schedule:
|
if not shared.opts.dont_fix_second_order_samplers_schedule:
|
||||||
try:
|
try:
|
||||||
step_multiplier = 2 if sd_samplers.all_samplers_map.get(p.sampler_name).aliases[0] in ['k_dpmpp_2s_a', 'k_dpmpp_2s_a_ka', 'k_dpmpp_sde', 'k_dpmpp_sde_ka', 'k_dpm_2', 'k_dpm_2_a', 'k_heun'] else 1
|
step_multiplier = 2 if sd_samplers.all_samplers_map.get(p.sampler_name).aliases[0] in ['k_dpmpp_2s_a', 'k_dpmpp_2s_a_ka', 'k_dpmpp_sde', 'k_dpmpp_sde_ka', 'k_dpm_2', 'k_dpm_2_a', 'k_heun'] else 1
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps * step_multiplier, cached_uc)
|
uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps * step_multiplier, cached_uc)
|
||||||
c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps * step_multiplier, cached_c)
|
c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps * step_multiplier, cached_c)
|
||||||
|
@ -54,18 +54,21 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def collect_steps(steps, tree):
|
def collect_steps(steps, tree):
|
||||||
l = [steps]
|
res = [steps]
|
||||||
|
|
||||||
class CollectSteps(lark.Visitor):
|
class CollectSteps(lark.Visitor):
|
||||||
def scheduled(self, tree):
|
def scheduled(self, tree):
|
||||||
tree.children[-1] = float(tree.children[-1])
|
tree.children[-1] = float(tree.children[-1])
|
||||||
if tree.children[-1] < 1:
|
if tree.children[-1] < 1:
|
||||||
tree.children[-1] *= steps
|
tree.children[-1] *= steps
|
||||||
tree.children[-1] = min(steps, int(tree.children[-1]))
|
tree.children[-1] = min(steps, int(tree.children[-1]))
|
||||||
l.append(tree.children[-1])
|
res.append(tree.children[-1])
|
||||||
|
|
||||||
def alternate(self, tree):
|
def alternate(self, tree):
|
||||||
l.extend(range(1, steps+1))
|
res.extend(range(1, steps+1))
|
||||||
|
|
||||||
CollectSteps().visit(tree)
|
CollectSteps().visit(tree)
|
||||||
return sorted(set(l))
|
return sorted(set(res))
|
||||||
|
|
||||||
def at_step(step, tree):
|
def at_step(step, tree):
|
||||||
class AtStep(lark.Transformer):
|
class AtStep(lark.Transformer):
|
||||||
|
@ -185,7 +185,7 @@ def image_face_points(im, settings):
|
|||||||
try:
|
try:
|
||||||
faces = classifier.detectMultiScale(gray, scaleFactor=1.1,
|
faces = classifier.detectMultiScale(gray, scaleFactor=1.1,
|
||||||
minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE)
|
minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE)
|
||||||
except:
|
except Exception:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if len(faces) > 0:
|
if len(faces) > 0:
|
||||||
|
@ -1,15 +1,9 @@
|
|||||||
import html
|
|
||||||
import json
|
import json
|
||||||
import math
|
|
||||||
import mimetypes
|
import mimetypes
|
||||||
import os
|
import os
|
||||||
import platform
|
|
||||||
import random
|
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
|
||||||
import time
|
|
||||||
import traceback
|
import traceback
|
||||||
from functools import partial, reduce
|
from functools import reduce
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
@ -45,7 +45,7 @@ class Upscaler:
|
|||||||
try:
|
try:
|
||||||
import cv2
|
import cv2
|
||||||
self.can_tile = True
|
self.can_tile = True
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
Loading…
Reference in New Issue
Block a user