make it possible for scripts to add cross attention optimizations

add UI selection for cross attention optimization
This commit is contained in:
AUTOMATIC 2023-05-18 22:48:28 +03:00
parent 2e006fa500
commit 2582a0fd3b
7 changed files with 226 additions and 49 deletions

View File

@ -53,16 +53,16 @@ parser.add_argument("--xformers", action='store_true', help="enable xformers for
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
parser.add_argument("--xformers-flash-attention", action='store_true', help="enable xformers with Flash Attention to improve reproducibility (supported for SD2.x or variant only)")
parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization")
parser.add_argument("--opt-split-attention", action='store_true', help="prefer Doggettx's cross-attention layer optimization for automatic choice of optimization")
parser.add_argument("--opt-sub-quad-attention", action='store_true', help="prefer memory efficient sub-quadratic cross-attention layer optimization for automatic choice of optimization")
parser.add_argument("--sub-quad-q-chunk-size", type=int, help="query chunk size for the sub-quadratic cross-attention layer optimization to use", default=1024)
parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=None)
parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None)
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
parser.add_argument("--opt-sdp-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization; requires PyTorch 2.*")
parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization without memory efficient attention, makes image generation deterministic; requires PyTorch 2.*")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="prefer InvokeAI's cross-attention layer optimization for automatic choice of optimization")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="prefer older version of split attention optimization for automatic choice of optimization")
parser.add_argument("--opt-sdp-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization for automatic choice of optimization; requires PyTorch 2.*")
parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization without memory efficient attention for automatic choice of optimization, makes image generation deterministic; requires PyTorch 2.*")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="does not do anything")
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")

View File

@ -110,6 +110,7 @@ callback_map = dict(
callbacks_script_unloaded=[],
callbacks_before_ui=[],
callbacks_on_reload=[],
callbacks_list_optimizers=[],
)
@ -258,6 +259,18 @@ def before_ui_callback():
report_exception(c, 'before_ui')
def list_optimizers_callback():
res = []
for c in callback_map['callbacks_list_optimizers']:
try:
c.callback(res)
except Exception:
report_exception(c, 'list_optimizers')
return res
def add_callback(callbacks, fun):
stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
@ -409,3 +422,11 @@ def on_before_ui(callback):
"""register a function to be called before the UI is created."""
add_callback(callback_map['callbacks_before_ui'], callback)
def on_list_optimizers(callback):
"""register a function to be called when UI is making a list of cross attention optimization options.
The function will be called with one argument, a list, and shall add objects of type modules.sd_hijack_optimizations.SdOptimization
to it."""
add_callback(callback_map['callbacks_list_optimizers'], callback)

View File

@ -3,8 +3,9 @@ from torch.nn.functional import silu
from types import MethodType
import modules.textual_inversion.textual_inversion
from modules import devices, sd_hijack_optimizations, shared
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors
from modules.hypernetworks import hypernetwork
from modules.sd_hijack_optimizations import diffusionmodules_model_AttnBlock_forward
from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
@ -28,57 +29,56 @@ ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"]
ldm.modules.attention.print = lambda *args: None
ldm.modules.diffusionmodules.model.print = lambda *args: None
optimizers = []
current_optimizer: sd_hijack_optimizations.SdOptimization = None
def list_optimizers():
new_optimizers = script_callbacks.list_optimizers_callback()
new_optimizers = [x for x in new_optimizers if x.is_available()]
new_optimizers = sorted(new_optimizers, key=lambda x: x.priority(), reverse=True)
optimizers.clear()
optimizers.extend(new_optimizers)
def apply_optimizations():
global current_optimizer
undo_optimizations()
ldm.modules.diffusionmodules.model.nonlinearity = silu
ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
optimization_method = None
if current_optimizer is not None:
current_optimizer.undo()
current_optimizer = None
can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention) # not everyone has torch 2.x to use sdp
selection = shared.opts.cross_attention_optimization
if selection == "Automatic" and len(optimizers) > 0:
matching_optimizer = next(iter([x for x in optimizers if x.cmd_opt and getattr(shared.cmd_opts, x.cmd_opt, False)]), optimizers[0])
else:
matching_optimizer = next(iter([x for x in optimizers if x.title() == selection]), None)
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
print("Applying xformers cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
optimization_method = 'xformers'
elif cmd_opts.opt_sdp_no_mem_attention and can_use_sdp:
print("Applying scaled dot product cross attention optimization (without memory efficient attention).")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_no_mem_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_no_mem_attnblock_forward
optimization_method = 'sdp-no-mem'
elif cmd_opts.opt_sdp_attention and can_use_sdp:
print("Applying scaled dot product cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_attnblock_forward
optimization_method = 'sdp'
elif cmd_opts.opt_sub_quad_attention:
print("Applying sub-quadratic cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sub_quad_attnblock_forward
optimization_method = 'sub-quadratic'
elif cmd_opts.opt_split_attention_v1:
print("Applying v1 cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
optimization_method = 'V1'
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not cmd_opts.opt_split_attention and not torch.cuda.is_available()):
print("Applying cross attention optimization (InvokeAI).")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
optimization_method = 'InvokeAI'
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
print("Applying cross attention optimization (Doggettx).")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
optimization_method = 'Doggettx'
if selection == "None":
matching_optimizer = None
elif matching_optimizer is None:
matching_optimizer = optimizers[0]
return optimization_method
if matching_optimizer is not None:
print(f"Applying optimization: {matching_optimizer.name}")
matching_optimizer.apply()
current_optimizer = matching_optimizer
return current_optimizer.name
else:
return ''
def undo_optimizations():
ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
@ -169,7 +169,11 @@ class StableDiffusionModelHijack:
if m.cond_stage_key == "edit":
sd_hijack_unet.hijack_ddpm_edit()
try:
self.optimization_method = apply_optimizations()
except Exception as e:
errors.display(e, "applying cross attention optimization")
undo_optimizations()
self.clip = m.cond_stage_model
@ -223,6 +227,10 @@ class StableDiffusionModelHijack:
return token_count, self.clip.get_target_prompt_token_count(token_count)
def redo_hijack(self, m):
self.undo_hijack(m)
self.hijack(m)
class EmbeddingsWithFixes(torch.nn.Module):
def __init__(self, wrapped, embeddings):

View File

@ -9,10 +9,139 @@ from torch import einsum
from ldm.util import default
from einops import rearrange
from modules import shared, errors, devices
from modules import shared, errors, devices, sub_quadratic_attention, script_callbacks
from modules.hypernetworks import hypernetwork
from .sub_quadratic_attention import efficient_dot_product_attention
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
class SdOptimization:
def __init__(self, name, label=None, cmd_opt=None):
self.name = name
self.label = label
self.cmd_opt = cmd_opt
def title(self):
if self.label is None:
return self.name
return f"{self.name} - {self.label}"
def is_available(self):
return True
def priority(self):
return 0
def apply(self):
pass
def undo(self):
ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
class SdOptimizationXformers(SdOptimization):
def __init__(self):
super().__init__("xformers", cmd_opt="xformers")
def is_available(self):
return shared.cmd_opts.force_enable_xformers or (shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0))
def priority(self):
return 100
def apply(self):
ldm.modules.attention.CrossAttention.forward = xformers_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward
class SdOptimizationSdpNoMem(SdOptimization):
def __init__(self, name="sdp-no-mem", label="scaled dot product without memory efficient attention", cmd_opt="opt_sdp_no_mem_attention"):
super().__init__(name, label, cmd_opt)
def is_available(self):
return hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention)
def priority(self):
return 90
def apply(self):
ldm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward
class SdOptimizationSdp(SdOptimizationSdpNoMem):
def __init__(self):
super().__init__("sdp", "scaled dot product", cmd_opt="opt_sdp_attention")
def priority(self):
return 80
def apply(self):
ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward
class SdOptimizationSubQuad(SdOptimization):
def __init__(self):
super().__init__("sub-quadratic", cmd_opt="opt_sub_quad_attention")
def priority(self):
return 10
def apply(self):
ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward
class SdOptimizationV1(SdOptimization):
def __init__(self):
super().__init__("V1", "original v1", cmd_opt="opt_split_attention_v1")
def priority(self):
return 10
def apply(self):
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
class SdOptimizationInvokeAI(SdOptimization):
def __init__(self):
super().__init__("InvokeAI", cmd_opt="opt_split_attention_invokeai")
def priority(self):
return 1000 if not torch.cuda.is_available() else 10
def apply(self):
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
class SdOptimizationDoggettx(SdOptimization):
def __init__(self):
super().__init__("Doggettx", cmd_opt="opt_split_attention")
def priority(self):
return 20
def apply(self):
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
def list_optimizers(res):
res.extend([
SdOptimizationXformers(),
SdOptimizationSdpNoMem(),
SdOptimizationSdp(),
SdOptimizationSubQuad(),
SdOptimizationV1(),
SdOptimizationInvokeAI(),
SdOptimizationDoggettx(),
])
if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
@ -299,7 +428,7 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
kv_chunk_size = k_tokens
with devices.without_autocast(disable=q.dtype == v.dtype):
return efficient_dot_product_attention(
return sub_quadratic_attention.efficient_dot_product_attention(
q,
k,
v,

View File

@ -417,6 +417,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
}))
options_templates.update(options_section(('optimizations', "Optimizations"), {
"cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
"s_min_uncond": OptionInfo(0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
"token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
"token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),

View File

@ -21,3 +21,11 @@ def refresh_vae_list():
import modules.sd_vae
modules.sd_vae.refresh_vae_list()
def cross_attention_optimizations():
import modules.sd_hijack
return ["Automatic"] + [x.title() for x in modules.sd_hijack.optimizers] + ["None"]

View File

@ -52,6 +52,7 @@ import modules.img2img
import modules.lowvram
import modules.scripts
import modules.sd_hijack
import modules.sd_hijack_optimizations
import modules.sd_models
import modules.sd_vae
import modules.txt2img
@ -200,6 +201,10 @@ def initialize():
modules.textual_inversion.textual_inversion.list_textual_inversion_templates()
startup_timer.record("refresh textual inversion templates")
modules.script_callbacks.on_list_optimizers(modules.sd_hijack_optimizations.list_optimizers)
modules.sd_hijack.list_optimizers()
startup_timer.record("scripts list_optimizers")
# load model in parallel to other startup stuff
Thread(target=lambda: shared.sd_model).start()
@ -208,6 +213,7 @@ def initialize():
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: modules.sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
startup_timer.record("opts onchange")
shared.reload_hypernetworks()
@ -428,6 +434,10 @@ def webui():
extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
startup_timer.record("initialize extra networks")
modules.script_callbacks.on_list_optimizers(modules.sd_hijack_optimizations.list_optimizers)
modules.sd_hijack.list_optimizers()
startup_timer.record("scripts list_optimizers")
if __name__ == "__main__":
if cmd_opts.nowebui: