make it possible for scripts to add cross attention optimizations
add UI selection for cross attention optimization
This commit is contained in:
parent
2e006fa500
commit
2582a0fd3b
@ -53,16 +53,16 @@ parser.add_argument("--xformers", action='store_true', help="enable xformers for
|
||||
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
|
||||
parser.add_argument("--xformers-flash-attention", action='store_true', help="enable xformers with Flash Attention to improve reproducibility (supported for SD2.x or variant only)")
|
||||
parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
|
||||
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
|
||||
parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization")
|
||||
parser.add_argument("--opt-split-attention", action='store_true', help="prefer Doggettx's cross-attention layer optimization for automatic choice of optimization")
|
||||
parser.add_argument("--opt-sub-quad-attention", action='store_true', help="prefer memory efficient sub-quadratic cross-attention layer optimization for automatic choice of optimization")
|
||||
parser.add_argument("--sub-quad-q-chunk-size", type=int, help="query chunk size for the sub-quadratic cross-attention layer optimization to use", default=1024)
|
||||
parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=None)
|
||||
parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None)
|
||||
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
|
||||
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
|
||||
parser.add_argument("--opt-sdp-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization; requires PyTorch 2.*")
|
||||
parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization without memory efficient attention, makes image generation deterministic; requires PyTorch 2.*")
|
||||
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
|
||||
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="prefer InvokeAI's cross-attention layer optimization for automatic choice of optimization")
|
||||
parser.add_argument("--opt-split-attention-v1", action='store_true', help="prefer older version of split attention optimization for automatic choice of optimization")
|
||||
parser.add_argument("--opt-sdp-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization for automatic choice of optimization; requires PyTorch 2.*")
|
||||
parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization without memory efficient attention for automatic choice of optimization, makes image generation deterministic; requires PyTorch 2.*")
|
||||
parser.add_argument("--disable-opt-split-attention", action='store_true', help="does not do anything")
|
||||
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
|
||||
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
|
||||
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
|
||||
|
@ -110,6 +110,7 @@ callback_map = dict(
|
||||
callbacks_script_unloaded=[],
|
||||
callbacks_before_ui=[],
|
||||
callbacks_on_reload=[],
|
||||
callbacks_list_optimizers=[],
|
||||
)
|
||||
|
||||
|
||||
@ -258,6 +259,18 @@ def before_ui_callback():
|
||||
report_exception(c, 'before_ui')
|
||||
|
||||
|
||||
def list_optimizers_callback():
|
||||
res = []
|
||||
|
||||
for c in callback_map['callbacks_list_optimizers']:
|
||||
try:
|
||||
c.callback(res)
|
||||
except Exception:
|
||||
report_exception(c, 'list_optimizers')
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def add_callback(callbacks, fun):
|
||||
stack = [x for x in inspect.stack() if x.filename != __file__]
|
||||
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
|
||||
@ -409,3 +422,11 @@ def on_before_ui(callback):
|
||||
"""register a function to be called before the UI is created."""
|
||||
|
||||
add_callback(callback_map['callbacks_before_ui'], callback)
|
||||
|
||||
|
||||
def on_list_optimizers(callback):
|
||||
"""register a function to be called when UI is making a list of cross attention optimization options.
|
||||
The function will be called with one argument, a list, and shall add objects of type modules.sd_hijack_optimizations.SdOptimization
|
||||
to it."""
|
||||
|
||||
add_callback(callback_map['callbacks_list_optimizers'], callback)
|
||||
|
@ -3,8 +3,9 @@ from torch.nn.functional import silu
|
||||
from types import MethodType
|
||||
|
||||
import modules.textual_inversion.textual_inversion
|
||||
from modules import devices, sd_hijack_optimizations, shared
|
||||
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors
|
||||
from modules.hypernetworks import hypernetwork
|
||||
from modules.sd_hijack_optimizations import diffusionmodules_model_AttnBlock_forward
|
||||
from modules.shared import cmd_opts
|
||||
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
|
||||
|
||||
@ -28,57 +29,56 @@ ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"]
|
||||
ldm.modules.attention.print = lambda *args: None
|
||||
ldm.modules.diffusionmodules.model.print = lambda *args: None
|
||||
|
||||
optimizers = []
|
||||
current_optimizer: sd_hijack_optimizations.SdOptimization = None
|
||||
|
||||
|
||||
def list_optimizers():
|
||||
new_optimizers = script_callbacks.list_optimizers_callback()
|
||||
|
||||
new_optimizers = [x for x in new_optimizers if x.is_available()]
|
||||
|
||||
new_optimizers = sorted(new_optimizers, key=lambda x: x.priority(), reverse=True)
|
||||
|
||||
optimizers.clear()
|
||||
optimizers.extend(new_optimizers)
|
||||
|
||||
|
||||
def apply_optimizations():
|
||||
global current_optimizer
|
||||
|
||||
undo_optimizations()
|
||||
|
||||
ldm.modules.diffusionmodules.model.nonlinearity = silu
|
||||
ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
|
||||
|
||||
optimization_method = None
|
||||
if current_optimizer is not None:
|
||||
current_optimizer.undo()
|
||||
current_optimizer = None
|
||||
|
||||
can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention) # not everyone has torch 2.x to use sdp
|
||||
selection = shared.opts.cross_attention_optimization
|
||||
if selection == "Automatic" and len(optimizers) > 0:
|
||||
matching_optimizer = next(iter([x for x in optimizers if x.cmd_opt and getattr(shared.cmd_opts, x.cmd_opt, False)]), optimizers[0])
|
||||
else:
|
||||
matching_optimizer = next(iter([x for x in optimizers if x.title() == selection]), None)
|
||||
|
||||
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
|
||||
print("Applying xformers cross attention optimization.")
|
||||
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
|
||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
|
||||
optimization_method = 'xformers'
|
||||
elif cmd_opts.opt_sdp_no_mem_attention and can_use_sdp:
|
||||
print("Applying scaled dot product cross attention optimization (without memory efficient attention).")
|
||||
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_no_mem_attention_forward
|
||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_no_mem_attnblock_forward
|
||||
optimization_method = 'sdp-no-mem'
|
||||
elif cmd_opts.opt_sdp_attention and can_use_sdp:
|
||||
print("Applying scaled dot product cross attention optimization.")
|
||||
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_attention_forward
|
||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_attnblock_forward
|
||||
optimization_method = 'sdp'
|
||||
elif cmd_opts.opt_sub_quad_attention:
|
||||
print("Applying sub-quadratic cross attention optimization.")
|
||||
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward
|
||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sub_quad_attnblock_forward
|
||||
optimization_method = 'sub-quadratic'
|
||||
elif cmd_opts.opt_split_attention_v1:
|
||||
print("Applying v1 cross attention optimization.")
|
||||
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
|
||||
optimization_method = 'V1'
|
||||
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not cmd_opts.opt_split_attention and not torch.cuda.is_available()):
|
||||
print("Applying cross attention optimization (InvokeAI).")
|
||||
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
|
||||
optimization_method = 'InvokeAI'
|
||||
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
|
||||
print("Applying cross attention optimization (Doggettx).")
|
||||
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
|
||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
|
||||
optimization_method = 'Doggettx'
|
||||
if selection == "None":
|
||||
matching_optimizer = None
|
||||
elif matching_optimizer is None:
|
||||
matching_optimizer = optimizers[0]
|
||||
|
||||
return optimization_method
|
||||
if matching_optimizer is not None:
|
||||
print(f"Applying optimization: {matching_optimizer.name}")
|
||||
matching_optimizer.apply()
|
||||
current_optimizer = matching_optimizer
|
||||
return current_optimizer.name
|
||||
else:
|
||||
return ''
|
||||
|
||||
|
||||
def undo_optimizations():
|
||||
ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
|
||||
ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
|
||||
ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
|
||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
|
||||
|
||||
|
||||
@ -169,7 +169,11 @@ class StableDiffusionModelHijack:
|
||||
if m.cond_stage_key == "edit":
|
||||
sd_hijack_unet.hijack_ddpm_edit()
|
||||
|
||||
try:
|
||||
self.optimization_method = apply_optimizations()
|
||||
except Exception as e:
|
||||
errors.display(e, "applying cross attention optimization")
|
||||
undo_optimizations()
|
||||
|
||||
self.clip = m.cond_stage_model
|
||||
|
||||
@ -223,6 +227,10 @@ class StableDiffusionModelHijack:
|
||||
|
||||
return token_count, self.clip.get_target_prompt_token_count(token_count)
|
||||
|
||||
def redo_hijack(self, m):
|
||||
self.undo_hijack(m)
|
||||
self.hijack(m)
|
||||
|
||||
|
||||
class EmbeddingsWithFixes(torch.nn.Module):
|
||||
def __init__(self, wrapped, embeddings):
|
||||
|
@ -9,10 +9,139 @@ from torch import einsum
|
||||
from ldm.util import default
|
||||
from einops import rearrange
|
||||
|
||||
from modules import shared, errors, devices
|
||||
from modules import shared, errors, devices, sub_quadratic_attention, script_callbacks
|
||||
from modules.hypernetworks import hypernetwork
|
||||
|
||||
from .sub_quadratic_attention import efficient_dot_product_attention
|
||||
import ldm.modules.attention
|
||||
import ldm.modules.diffusionmodules.model
|
||||
|
||||
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
|
||||
|
||||
|
||||
class SdOptimization:
|
||||
def __init__(self, name, label=None, cmd_opt=None):
|
||||
self.name = name
|
||||
self.label = label
|
||||
self.cmd_opt = cmd_opt
|
||||
|
||||
def title(self):
|
||||
if self.label is None:
|
||||
return self.name
|
||||
|
||||
return f"{self.name} - {self.label}"
|
||||
|
||||
def is_available(self):
|
||||
return True
|
||||
|
||||
def priority(self):
|
||||
return 0
|
||||
|
||||
def apply(self):
|
||||
pass
|
||||
|
||||
def undo(self):
|
||||
ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
|
||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
|
||||
|
||||
|
||||
class SdOptimizationXformers(SdOptimization):
|
||||
def __init__(self):
|
||||
super().__init__("xformers", cmd_opt="xformers")
|
||||
|
||||
def is_available(self):
|
||||
return shared.cmd_opts.force_enable_xformers or (shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0))
|
||||
|
||||
def priority(self):
|
||||
return 100
|
||||
|
||||
def apply(self):
|
||||
ldm.modules.attention.CrossAttention.forward = xformers_attention_forward
|
||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward
|
||||
|
||||
|
||||
class SdOptimizationSdpNoMem(SdOptimization):
|
||||
def __init__(self, name="sdp-no-mem", label="scaled dot product without memory efficient attention", cmd_opt="opt_sdp_no_mem_attention"):
|
||||
super().__init__(name, label, cmd_opt)
|
||||
|
||||
def is_available(self):
|
||||
return hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention)
|
||||
|
||||
def priority(self):
|
||||
return 90
|
||||
|
||||
def apply(self):
|
||||
ldm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
|
||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward
|
||||
|
||||
|
||||
class SdOptimizationSdp(SdOptimizationSdpNoMem):
|
||||
def __init__(self):
|
||||
super().__init__("sdp", "scaled dot product", cmd_opt="opt_sdp_attention")
|
||||
|
||||
def priority(self):
|
||||
return 80
|
||||
|
||||
def apply(self):
|
||||
ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
|
||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward
|
||||
|
||||
|
||||
class SdOptimizationSubQuad(SdOptimization):
|
||||
def __init__(self):
|
||||
super().__init__("sub-quadratic", cmd_opt="opt_sub_quad_attention")
|
||||
|
||||
def priority(self):
|
||||
return 10
|
||||
|
||||
def apply(self):
|
||||
ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
|
||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward
|
||||
|
||||
|
||||
class SdOptimizationV1(SdOptimization):
|
||||
def __init__(self):
|
||||
super().__init__("V1", "original v1", cmd_opt="opt_split_attention_v1")
|
||||
|
||||
def priority(self):
|
||||
return 10
|
||||
|
||||
def apply(self):
|
||||
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
|
||||
|
||||
|
||||
class SdOptimizationInvokeAI(SdOptimization):
|
||||
def __init__(self):
|
||||
super().__init__("InvokeAI", cmd_opt="opt_split_attention_invokeai")
|
||||
|
||||
def priority(self):
|
||||
return 1000 if not torch.cuda.is_available() else 10
|
||||
|
||||
def apply(self):
|
||||
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
|
||||
|
||||
|
||||
class SdOptimizationDoggettx(SdOptimization):
|
||||
def __init__(self):
|
||||
super().__init__("Doggettx", cmd_opt="opt_split_attention")
|
||||
|
||||
def priority(self):
|
||||
return 20
|
||||
|
||||
def apply(self):
|
||||
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
|
||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
|
||||
|
||||
|
||||
def list_optimizers(res):
|
||||
res.extend([
|
||||
SdOptimizationXformers(),
|
||||
SdOptimizationSdpNoMem(),
|
||||
SdOptimizationSdp(),
|
||||
SdOptimizationSubQuad(),
|
||||
SdOptimizationV1(),
|
||||
SdOptimizationInvokeAI(),
|
||||
SdOptimizationDoggettx(),
|
||||
])
|
||||
|
||||
|
||||
if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
|
||||
@ -299,7 +428,7 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
|
||||
kv_chunk_size = k_tokens
|
||||
|
||||
with devices.without_autocast(disable=q.dtype == v.dtype):
|
||||
return efficient_dot_product_attention(
|
||||
return sub_quadratic_attention.efficient_dot_product_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
|
@ -417,6 +417,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('optimizations', "Optimizations"), {
|
||||
"cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
|
||||
"s_min_uncond": OptionInfo(0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
|
||||
"token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
|
||||
"token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
|
||||
|
@ -21,3 +21,11 @@ def refresh_vae_list():
|
||||
import modules.sd_vae
|
||||
|
||||
modules.sd_vae.refresh_vae_list()
|
||||
|
||||
|
||||
def cross_attention_optimizations():
|
||||
import modules.sd_hijack
|
||||
|
||||
return ["Automatic"] + [x.title() for x in modules.sd_hijack.optimizers] + ["None"]
|
||||
|
||||
|
||||
|
10
webui.py
10
webui.py
@ -52,6 +52,7 @@ import modules.img2img
|
||||
import modules.lowvram
|
||||
import modules.scripts
|
||||
import modules.sd_hijack
|
||||
import modules.sd_hijack_optimizations
|
||||
import modules.sd_models
|
||||
import modules.sd_vae
|
||||
import modules.txt2img
|
||||
@ -200,6 +201,10 @@ def initialize():
|
||||
modules.textual_inversion.textual_inversion.list_textual_inversion_templates()
|
||||
startup_timer.record("refresh textual inversion templates")
|
||||
|
||||
modules.script_callbacks.on_list_optimizers(modules.sd_hijack_optimizations.list_optimizers)
|
||||
modules.sd_hijack.list_optimizers()
|
||||
startup_timer.record("scripts list_optimizers")
|
||||
|
||||
# load model in parallel to other startup stuff
|
||||
Thread(target=lambda: shared.sd_model).start()
|
||||
|
||||
@ -208,6 +213,7 @@ def initialize():
|
||||
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
|
||||
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
|
||||
shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
|
||||
shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: modules.sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
|
||||
startup_timer.record("opts onchange")
|
||||
|
||||
shared.reload_hypernetworks()
|
||||
@ -428,6 +434,10 @@ def webui():
|
||||
extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
|
||||
startup_timer.record("initialize extra networks")
|
||||
|
||||
modules.script_callbacks.on_list_optimizers(modules.sd_hijack_optimizations.list_optimizers)
|
||||
modules.sd_hijack.list_optimizers()
|
||||
startup_timer.record("scripts list_optimizers")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if cmd_opts.nowebui:
|
||||
|
Loading…
Reference in New Issue
Block a user