Merge pull request #9256 from papuSpartan/tomesd
Integrate optional speed and memory improvements by token merging (via dbolya/tomesd)
This commit is contained in:
commit
7f6ef764b9
@ -308,8 +308,10 @@ infotext_to_setting_name_mapping = [
|
|||||||
('UniPC skip type', 'uni_pc_skip_type'),
|
('UniPC skip type', 'uni_pc_skip_type'),
|
||||||
('UniPC order', 'uni_pc_order'),
|
('UniPC order', 'uni_pc_order'),
|
||||||
('UniPC lower order final', 'uni_pc_lower_order_final'),
|
('UniPC lower order final', 'uni_pc_lower_order_final'),
|
||||||
|
('Token merging ratio', 'token_merging_ratio'),
|
||||||
|
('Token merging ratio hr', 'token_merging_ratio_hr'),
|
||||||
('RNG', 'randn_source'),
|
('RNG', 'randn_source'),
|
||||||
('NGMS', 's_min_uncond'),
|
('NGMS', 's_min_uncond')
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -29,6 +29,13 @@ from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
|
|||||||
|
|
||||||
from einops import repeat, rearrange
|
from einops import repeat, rearrange
|
||||||
from blendmodes.blend import blendLayers, BlendType
|
from blendmodes.blend import blendLayers, BlendType
|
||||||
|
import tomesd
|
||||||
|
|
||||||
|
# add a logger for the processing module
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
# manually set output level here since there is no option to do so yet through launch options
|
||||||
|
# logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(name)s %(message)s')
|
||||||
|
|
||||||
|
|
||||||
# some of those options should not be changed at all because they would break the model, so I removed them from options.
|
# some of those options should not be changed at all because they would break the model, so I removed them from options.
|
||||||
opt_C = 4
|
opt_C = 4
|
||||||
@ -471,6 +478,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
|||||||
index = position_in_batch + iteration * p.batch_size
|
index = position_in_batch + iteration * p.batch_size
|
||||||
|
|
||||||
clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
|
clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
|
||||||
|
enable_hr = getattr(p, 'enable_hr', False)
|
||||||
|
|
||||||
generation_params = {
|
generation_params = {
|
||||||
"Steps": p.steps,
|
"Steps": p.steps,
|
||||||
@ -489,6 +497,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
|||||||
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
|
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
|
||||||
"Clip skip": None if clip_skip <= 1 else clip_skip,
|
"Clip skip": None if clip_skip <= 1 else clip_skip,
|
||||||
"ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
|
"ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
|
||||||
|
"Token merging ratio": None if opts.token_merging_ratio == 0 else opts.token_merging_ratio,
|
||||||
|
"Token merging ratio hr": None if not enable_hr or opts.token_merging_ratio_hr == 0 else opts.token_merging_ratio_hr,
|
||||||
"Init image hash": getattr(p, 'init_img_hash', None),
|
"Init image hash": getattr(p, 'init_img_hash', None),
|
||||||
"RNG": opts.randn_source if opts.randn_source != "GPU" else None,
|
"RNG": opts.randn_source if opts.randn_source != "GPU" else None,
|
||||||
"NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
|
"NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
|
||||||
@ -522,9 +532,18 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||||||
if k == 'sd_vae':
|
if k == 'sd_vae':
|
||||||
sd_vae.reload_vae_weights()
|
sd_vae.reload_vae_weights()
|
||||||
|
|
||||||
|
if opts.token_merging_ratio > 0:
|
||||||
|
sd_models.apply_token_merging(sd_model=p.sd_model, hr=False)
|
||||||
|
logger.debug(f"Token merging applied to first pass. Ratio: '{opts.token_merging_ratio}'")
|
||||||
|
|
||||||
res = process_images_inner(p)
|
res = process_images_inner(p)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
|
# undo model optimizations made by tomesd
|
||||||
|
if opts.token_merging_ratio > 0:
|
||||||
|
tomesd.remove_patch(p.sd_model)
|
||||||
|
logger.debug('Token merging model optimizations removed')
|
||||||
|
|
||||||
# restore opts to original state
|
# restore opts to original state
|
||||||
if p.override_settings_restore_afterwards:
|
if p.override_settings_restore_afterwards:
|
||||||
for k, v in stored_opts.items():
|
for k, v in stored_opts.items():
|
||||||
@ -977,8 +996,22 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
x = None
|
x = None
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
|
# apply token merging optimizations from tomesd for high-res pass
|
||||||
|
if opts.token_merging_ratio_hr > 0:
|
||||||
|
# in case the user has used separate merge ratios
|
||||||
|
if opts.token_merging_ratio > 0:
|
||||||
|
tomesd.remove_patch(self.sd_model)
|
||||||
|
logger.debug('Adjusting token merging ratio for high-res pass')
|
||||||
|
|
||||||
|
sd_models.apply_token_merging(sd_model=self.sd_model, hr=True)
|
||||||
|
logger.debug(f"Applied token merging for high-res pass. Ratio: '{opts.token_merging_ratio_hr}'")
|
||||||
|
|
||||||
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
|
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
|
||||||
|
|
||||||
|
if opts.token_merging_ratio_hr > 0 or opts.token_merging_ratio > 0:
|
||||||
|
tomesd.remove_patch(self.sd_model)
|
||||||
|
logger.debug('Removed token merging optimizations from model')
|
||||||
|
|
||||||
self.is_hr_pass = False
|
self.is_hr_pass = False
|
||||||
|
|
||||||
return samples
|
return samples
|
||||||
|
@ -17,6 +17,7 @@ from ldm.util import instantiate_from_config
|
|||||||
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
|
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
|
||||||
from modules.sd_hijack_inpainting import do_inpainting_hijack
|
from modules.sd_hijack_inpainting import do_inpainting_hijack
|
||||||
from modules.timer import Timer
|
from modules.timer import Timer
|
||||||
|
import tomesd
|
||||||
|
|
||||||
model_dir = "Stable-diffusion"
|
model_dir = "Stable-diffusion"
|
||||||
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
|
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
|
||||||
@ -578,3 +579,25 @@ def unload_model_weights(sd_model=None, info=None):
|
|||||||
print(f"Unloaded weights {timer.summary()}.")
|
print(f"Unloaded weights {timer.summary()}.")
|
||||||
|
|
||||||
return sd_model
|
return sd_model
|
||||||
|
|
||||||
|
|
||||||
|
def apply_token_merging(sd_model, hr: bool):
|
||||||
|
"""
|
||||||
|
Applies speed and memory optimizations from tomesd.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hr (bool): True if called in the context of a high-res pass
|
||||||
|
"""
|
||||||
|
|
||||||
|
ratio = shared.opts.token_merging_ratio
|
||||||
|
if hr:
|
||||||
|
ratio = shared.opts.token_merging_ratio_hr
|
||||||
|
|
||||||
|
tomesd.apply_patch(
|
||||||
|
sd_model,
|
||||||
|
ratio=ratio,
|
||||||
|
use_rand=False, # can cause issues with some samplers
|
||||||
|
merge_attn=True,
|
||||||
|
merge_crossattn=False,
|
||||||
|
merge_mlp=False
|
||||||
|
)
|
||||||
|
@ -350,6 +350,8 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
|||||||
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
|
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
|
||||||
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
|
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
|
||||||
"randn_source": OptionInfo("GPU", "Random number generator source. Changes seeds drastically. Use CPU to produce the same picture across different vidocard vendors.", gr.Radio, {"choices": ["GPU", "CPU"]}),
|
"randn_source": OptionInfo("GPU", "Random number generator source. Changes seeds drastically. Use CPU to produce the same picture across different vidocard vendors.", gr.Radio, {"choices": ["GPU", "CPU"]}),
|
||||||
|
"token_merging_ratio_hr": OptionInfo(0, "Merging Ratio (high-res pass)", gr.Slider, {"minimum": 0, "maximum": 0.9, "step": 0.1}),
|
||||||
|
"token_merging_ratio": OptionInfo(0, "Merging Ratio", gr.Slider, {"minimum": 0, "maximum": 0.9, "step": 0.1})
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('compatibility', "Compatibility"), {
|
options_templates.update(options_section(('compatibility', "Compatibility"), {
|
||||||
@ -458,6 +460,7 @@ options_templates.update(options_section((None, "Hidden options"), {
|
|||||||
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
|
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
|
||||||
options_templates.update()
|
options_templates.update()
|
||||||
|
|
||||||
|
|
||||||
|
@ -26,3 +26,4 @@ torchsde==0.2.5
|
|||||||
safetensors==0.3.1
|
safetensors==0.3.1
|
||||||
httpcore<=0.15
|
httpcore<=0.15
|
||||||
fastapi==0.94.0
|
fastapi==0.94.0
|
||||||
|
tomesd>=0.1.2
|
Loading…
Reference in New Issue
Block a user