Refactor Mac specific code to a separate file
Move most Mac related code to a separate file, don't even load it unless web UI is run under macOS.
This commit is contained in:
parent
226d840e84
commit
1b8af15f13
@ -1,22 +1,17 @@
|
|||||||
import sys, os, shlex
|
import sys
|
||||||
import contextlib
|
import contextlib
|
||||||
import torch
|
import torch
|
||||||
from modules import errors
|
from modules import errors
|
||||||
from modules.sd_hijack_utils import CondFunc
|
|
||||||
from packaging import version
|
if sys.platform == "darwin":
|
||||||
|
from modules import mac_specific
|
||||||
|
|
||||||
|
|
||||||
# has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
|
|
||||||
# check `getattr` and try it for compatibility
|
|
||||||
def has_mps() -> bool:
|
def has_mps() -> bool:
|
||||||
if not getattr(torch, 'has_mps', False):
|
if sys.platform != "darwin":
|
||||||
return False
|
return False
|
||||||
try:
|
else:
|
||||||
torch.zeros(1).to(torch.device("mps"))
|
return mac_specific.has_mps
|
||||||
return True
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def extract_device_id(args, name):
|
def extract_device_id(args, name):
|
||||||
for x in range(len(args)):
|
for x in range(len(args)):
|
||||||
@ -155,36 +150,3 @@ def test_for_nans(x, where):
|
|||||||
message += " Use --disable-nan-check commandline argument to disable this check."
|
message += " Use --disable-nan-check commandline argument to disable this check."
|
||||||
|
|
||||||
raise NansException(message)
|
raise NansException(message)
|
||||||
|
|
||||||
|
|
||||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
|
|
||||||
def cumsum_fix(input, cumsum_func, *args, **kwargs):
|
|
||||||
if input.device.type == 'mps':
|
|
||||||
output_dtype = kwargs.get('dtype', input.dtype)
|
|
||||||
if output_dtype == torch.int64:
|
|
||||||
return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
|
|
||||||
elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
|
|
||||||
return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
|
|
||||||
return cumsum_func(input, *args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
if has_mps():
|
|
||||||
if version.parse(torch.__version__) < version.parse("1.13"):
|
|
||||||
# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
|
|
||||||
|
|
||||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
|
|
||||||
CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
|
|
||||||
lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
|
|
||||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/80800
|
|
||||||
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
|
|
||||||
lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
|
|
||||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/90532
|
|
||||||
CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
|
|
||||||
elif version.parse(torch.__version__) > version.parse("1.13.1"):
|
|
||||||
cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
|
|
||||||
cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
|
|
||||||
cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
|
|
||||||
CondFunc('torch.cumsum', cumsum_fix_func, None)
|
|
||||||
CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
|
|
||||||
CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
|
|
||||||
|
|
||||||
|
56
modules/mac_specific.py
Normal file
56
modules/mac_specific.py
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
import torch
|
||||||
|
from modules import paths
|
||||||
|
from modules.sd_hijack_utils import CondFunc
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
|
|
||||||
|
device = None
|
||||||
|
|
||||||
|
|
||||||
|
# has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
|
||||||
|
# check `getattr` and try it for compatibility
|
||||||
|
def check_for_mps() -> bool:
|
||||||
|
if not getattr(torch, 'has_mps', False):
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
torch.zeros(1).to(torch.device("mps"))
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
has_mps = check_for_mps()
|
||||||
|
|
||||||
|
|
||||||
|
# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
|
||||||
|
def cumsum_fix(input, cumsum_func, *args, **kwargs):
|
||||||
|
if input.device.type == 'mps':
|
||||||
|
output_dtype = kwargs.get('dtype', input.dtype)
|
||||||
|
if output_dtype == torch.int64:
|
||||||
|
return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
|
||||||
|
elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
|
||||||
|
return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
|
||||||
|
return cumsum_func(input, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
if has_mps:
|
||||||
|
# MPS fix for randn in torchsde
|
||||||
|
CondFunc('torchsde._brownian.brownian_interval._randn', lambda _, size, dtype, device, seed: torch.randn(size, dtype=dtype, device=torch.device("cpu"), generator=torch.Generator(torch.device("cpu")).manual_seed(int(seed))).to(device), lambda _, size, dtype, device, seed: device.type == 'mps')
|
||||||
|
|
||||||
|
if version.parse(torch.__version__) < version.parse("1.13"):
|
||||||
|
# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
|
||||||
|
|
||||||
|
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
|
||||||
|
CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
|
||||||
|
lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
|
||||||
|
# MPS workaround for https://github.com/pytorch/pytorch/issues/80800
|
||||||
|
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
|
||||||
|
lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
|
||||||
|
# MPS workaround for https://github.com/pytorch/pytorch/issues/90532
|
||||||
|
CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
|
||||||
|
elif version.parse(torch.__version__) > version.parse("1.13.1"):
|
||||||
|
cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
|
||||||
|
cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
|
||||||
|
cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
|
||||||
|
CondFunc('torch.cumsum', cumsum_fix_func, None)
|
||||||
|
CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
|
||||||
|
CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
|
||||||
|
|
@ -2,7 +2,6 @@ from collections import namedtuple
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import torchsde._brownian.brownian_interval
|
|
||||||
from modules import devices, processing, images, sd_vae_approx
|
from modules import devices, processing, images, sd_vae_approx
|
||||||
|
|
||||||
from modules.shared import opts, state
|
from modules.shared import opts, state
|
||||||
@ -61,18 +60,3 @@ def store_latent(decoded):
|
|||||||
|
|
||||||
class InterruptedException(BaseException):
|
class InterruptedException(BaseException):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
# MPS fix for randn in torchsde
|
|
||||||
# XXX move this to separate file for MPS
|
|
||||||
def torchsde_randn(size, dtype, device, seed):
|
|
||||||
if device.type == 'mps':
|
|
||||||
generator = torch.Generator(devices.cpu).manual_seed(int(seed))
|
|
||||||
return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
|
|
||||||
else:
|
|
||||||
generator = torch.Generator(device).manual_seed(int(seed))
|
|
||||||
return torch.randn(size, dtype=dtype, device=device, generator=generator)
|
|
||||||
|
|
||||||
|
|
||||||
torchsde._brownian.brownian_interval._randn = torchsde_randn
|
|
||||||
|
|
||||||
|
@ -145,6 +145,9 @@ devices.device, devices.device_interrogate, devices.device_gfpgan, devices.devic
|
|||||||
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer'])
|
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer'])
|
||||||
|
|
||||||
device = devices.device
|
device = devices.device
|
||||||
|
if sys.platform == "darwin":
|
||||||
|
from modules import mac_specific
|
||||||
|
mac_specific.device = device
|
||||||
weight_load_location = None if cmd_opts.lowram else "cpu"
|
weight_load_location = None if cmd_opts.lowram else "cpu"
|
||||||
|
|
||||||
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
|
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user