From b85fc7187d953828340d4e3af34af46d9fc70b9e Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 10 Jul 2023 21:18:34 +0300 Subject: [PATCH 1/4] Fix MPS cache cleanup Importing torch does not import torch.mps so the call failed. --- modules/devices.py | 5 +++-- modules/mac_specific.py | 14 ++++++++++++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index c5ad950f..57e51da3 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -54,8 +54,9 @@ def torch_gc(): with torch.cuda.device(get_cuda_device_string()): torch.cuda.empty_cache() torch.cuda.ipc_collect() - elif has_mps() and hasattr(torch.mps, 'empty_cache'): - torch.mps.empty_cache() + + if has_mps(): + mac_specific.torch_mps_gc() def enable_tf32(): diff --git a/modules/mac_specific.py b/modules/mac_specific.py index 735847f5..2c2f15ca 100644 --- a/modules/mac_specific.py +++ b/modules/mac_specific.py @@ -1,8 +1,12 @@ +import logging + import torch import platform from modules.sd_hijack_utils import CondFunc from packaging import version +log = logging.getLogger() + # before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+, # use check `getattr` and try it for compatibility. @@ -19,9 +23,19 @@ def check_for_mps() -> bool: return False else: return torch.backends.mps.is_available() and torch.backends.mps.is_built() + + has_mps = check_for_mps() +def torch_mps_gc() -> None: + try: + from torch.mps import empty_cache + empty_cache() + except Exception: + log.warning("MPS garbage collection failed", exc_info=True) + + # MPS workaround for https://github.com/pytorch/pytorch/issues/89784 def cumsum_fix(input, cumsum_func, *args, **kwargs): if input.device.type == 'mps': From 3636c2c6eda6ea25db95a5e3e77fe1ac347f0081 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Tue, 11 Jul 2023 15:05:20 +0300 Subject: [PATCH 2/4] Allow using alt in the prompt fields again --- javascript/edit-order.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/javascript/edit-order.js b/javascript/edit-order.js index e6e73937..ed4ef9ac 100644 --- a/javascript/edit-order.js +++ b/javascript/edit-order.js @@ -6,11 +6,11 @@ function keyupEditOrder(event) { let target = event.originalTarget || event.composedPath()[0]; if (!target.matches("*:is([id*='_toprow'] [id*='_prompt'], .prompt) textarea")) return; if (!event.altKey) return; - event.preventDefault(); let isLeft = event.key == "ArrowLeft"; let isRight = event.key == "ArrowRight"; if (!isLeft && !isRight) return; + event.preventDefault(); let selectionStart = target.selectionStart; let selectionEnd = target.selectionEnd; From 8f6b24ce5922174d96eb9776126488cb28694ff8 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 12 Jul 2023 15:16:42 +0300 Subject: [PATCH 3/4] Add correct logger name --- modules/mac_specific.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/mac_specific.py b/modules/mac_specific.py index 2c2f15ca..328b5973 100644 --- a/modules/mac_specific.py +++ b/modules/mac_specific.py @@ -5,7 +5,7 @@ import platform from modules.sd_hijack_utils import CondFunc from packaging import version -log = logging.getLogger() +log = logging.getLogger(__name__) # before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+, From 3d524fd3f1bdb17946bf6fa8a3cdf7b10859c495 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 12 Jul 2023 15:17:13 +0300 Subject: [PATCH 4/4] Don't do MPS GC when there's a latent that could still be sampled --- modules/mac_specific.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/modules/mac_specific.py b/modules/mac_specific.py index 328b5973..9ceb43ba 100644 --- a/modules/mac_specific.py +++ b/modules/mac_specific.py @@ -30,6 +30,10 @@ has_mps = check_for_mps() def torch_mps_gc() -> None: try: + from modules.shared import state + if state.current_latent is not None: + log.debug("`current_latent` is set, skipping MPS garbage collection") + return from torch.mps import empty_cache empty_cache() except Exception: