Allow different merge ratios to be used for each pass. Make toggle cmd flag work again. Remove ratio flag. Remove warning about controlnet being incompatible
This commit is contained in:
parent
c707b7df95
commit
5c8e53d5e9
@ -103,5 +103,4 @@ parser.add_argument("--no-hashing", action='store_true', help="disable sha256 ha
|
||||
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
|
||||
|
||||
# token merging / tomesd
|
||||
parser.add_argument("--token-merging", action='store_true', help="Provides generation speedup by merging redundant tokens. (compatible with --xformers)", default=False)
|
||||
parser.add_argument("--token-merging-ratio", type=float, help="Adjusts ratio of merged to untouched tokens. Range: (0.0-1.0], Defaults to 0.5", default=0.5)
|
||||
parser.add_argument("--token-merging", action='store_true', help="Provides speed and memory improvements by merging redundant tokens. This has a more pronounced effect on higher resolutions.", default=False)
|
||||
|
@ -501,26 +501,16 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||
if k == 'sd_vae':
|
||||
sd_vae.reload_vae_weights()
|
||||
|
||||
if opts.token_merging and not opts.token_merging_hr_only:
|
||||
print("applying token merging to all passes")
|
||||
tomesd.apply_patch(
|
||||
p.sd_model,
|
||||
ratio=opts.token_merging_ratio,
|
||||
max_downsample=opts.token_merging_maximum_down_sampling,
|
||||
sx=opts.token_merging_stride_x,
|
||||
sy=opts.token_merging_stride_y,
|
||||
use_rand=opts.token_merging_random,
|
||||
merge_attn=opts.token_merging_merge_attention,
|
||||
merge_crossattn=opts.token_merging_merge_cross_attention,
|
||||
merge_mlp=opts.token_merging_merge_mlp
|
||||
)
|
||||
if (opts.token_merging or cmd_opts.token_merging) and not opts.token_merging_hr_only:
|
||||
print("\nApplying token merging\n")
|
||||
sd_models.apply_token_merging(sd_model=p.sd_model, hr=False)
|
||||
|
||||
res = process_images_inner(p)
|
||||
|
||||
finally:
|
||||
# undo model optimizations made by tomesd
|
||||
if opts.token_merging:
|
||||
print('removing token merging model optimizations')
|
||||
if opts.token_merging or cmd_opts.token_merging:
|
||||
print('\nRemoving token merging model optimizations\n')
|
||||
tomesd.remove_patch(p.sd_model)
|
||||
|
||||
# restore opts to original state
|
||||
@ -959,20 +949,16 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
devices.torch_gc()
|
||||
|
||||
# apply token merging optimizations from tomesd for high-res pass
|
||||
# check if hr_only so we don't redundantly apply patch
|
||||
if opts.token_merging and opts.token_merging_hr_only:
|
||||
print("applying token merging for high-res pass")
|
||||
tomesd.apply_patch(
|
||||
self.sd_model,
|
||||
ratio=opts.token_merging_ratio,
|
||||
max_downsample=opts.token_merging_maximum_down_sampling,
|
||||
sx=opts.token_merging_stride_x,
|
||||
sy=opts.token_merging_stride_y,
|
||||
use_rand=opts.token_merging_random,
|
||||
merge_attn=opts.token_merging_merge_attention,
|
||||
merge_crossattn=opts.token_merging_merge_cross_attention,
|
||||
merge_mlp=opts.token_merging_merge_mlp
|
||||
)
|
||||
# check if hr_only so we are not redundantly patching
|
||||
if (cmd_opts.token_merging or opts.token_merging) and (opts.token_merging_hr_only or opts.token_merging_ratio_hr != opts.token_merging_ratio):
|
||||
# case where user wants to use separate merge ratios
|
||||
if not opts.token_merging_hr_only:
|
||||
# clean patch done by first pass. (clobbering the first patch might be fine? this might be excessive)
|
||||
print('Temporarily reverting token merging optimizations in preparation for next pass')
|
||||
tomesd.remove_patch(self.sd_model)
|
||||
|
||||
print("\nApplying token merging for high-res pass\n")
|
||||
sd_models.apply_token_merging(sd_model=self.sd_model, hr=True)
|
||||
|
||||
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
|
||||
|
||||
|
@ -16,6 +16,7 @@ from modules import paths, shared, modelloader, devices, script_callbacks, sd_va
|
||||
from modules.paths import models_path
|
||||
from modules.sd_hijack_inpainting import do_inpainting_hijack
|
||||
from modules.timer import Timer
|
||||
import tomesd
|
||||
|
||||
model_dir = "Stable-diffusion"
|
||||
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
|
||||
@ -546,3 +547,29 @@ def unload_model_weights(sd_model=None, info=None):
|
||||
print(f"Unloaded weights {timer.summary()}.")
|
||||
|
||||
return sd_model
|
||||
|
||||
|
||||
def apply_token_merging(sd_model, hr: bool):
|
||||
"""
|
||||
Applies speed and memory optimizations from tomesd.
|
||||
|
||||
Args:
|
||||
hr (bool): True if called in the context of a high-res pass
|
||||
"""
|
||||
|
||||
ratio = shared.opts.token_merging_ratio
|
||||
if hr:
|
||||
ratio = shared.opts.token_merging_ratio_hr
|
||||
print("effective hr pass merge ratio is "+str(ratio))
|
||||
|
||||
tomesd.apply_patch(
|
||||
sd_model,
|
||||
ratio=ratio,
|
||||
max_downsample=shared.opts.token_merging_maximum_down_sampling,
|
||||
sx=shared.opts.token_merging_stride_x,
|
||||
sy=shared.opts.token_merging_stride_y,
|
||||
use_rand=shared.opts.token_merging_random,
|
||||
merge_attn=shared.opts.token_merging_merge_attention,
|
||||
merge_crossattn=shared.opts.token_merging_merge_cross_attention,
|
||||
merge_mlp=shared.opts.token_merging_merge_mlp
|
||||
)
|
||||
|
@ -429,7 +429,7 @@ options_templates.update(options_section((None, "Hidden options"), {
|
||||
|
||||
options_templates.update(options_section(('token_merging', 'Token Merging'), {
|
||||
"token_merging": OptionInfo(
|
||||
False, "Enable redundant token merging via tomesd. (currently incompatible with controlnet extension)",
|
||||
0.5, "Enable redundant token merging via tomesd. This can provide significant speed and memory improvements.",
|
||||
gr.Checkbox
|
||||
),
|
||||
"token_merging_ratio": OptionInfo(
|
||||
@ -440,6 +440,10 @@ options_templates.update(options_section(('token_merging', 'Token Merging'), {
|
||||
True, "Apply only to high-res fix pass. Disabling can yield a ~20-35% speedup on contemporary resolutions.",
|
||||
gr.Checkbox
|
||||
),
|
||||
"token_merging_ratio_hr": OptionInfo(
|
||||
0.5, "Merging Ratio (high-res pass) - If 'Apply only to high-res' is enabled, this will always be the ratio used.",
|
||||
gr.Slider, {"minimum": 0, "maximum": 0.9, "step": 0.1}
|
||||
),
|
||||
# More advanced/niche settings:
|
||||
"token_merging_random": OptionInfo(
|
||||
True, "Use random perturbations - Disabling might help with certain samplers",
|
||||
|
Loading…
x
Reference in New Issue
Block a user