Add support for PyTorch nightly and local builds
This commit is contained in:
parent
3bd737767b
commit
8111b5569d
@ -133,8 +133,26 @@ def numpy_fix(self, *args, **kwargs):
|
|||||||
return orig_tensor_numpy(self, *args, **kwargs)
|
return orig_tensor_numpy(self, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
|
# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
|
||||||
if has_mps() and version.parse(torch.__version__) < version.parse("1.13"):
|
orig_cumsum = torch.cumsum
|
||||||
torch.Tensor.to = tensor_to_fix
|
orig_Tensor_cumsum = torch.Tensor.cumsum
|
||||||
torch.nn.functional.layer_norm = layer_norm_fix
|
def cumsum_fix(input, cumsum_func, *args, **kwargs):
|
||||||
torch.Tensor.numpy = numpy_fix
|
if input.device.type == 'mps':
|
||||||
|
output_dtype = kwargs.get('dtype', input.dtype)
|
||||||
|
if any(output_dtype == broken_dtype for broken_dtype in [torch.bool, torch.int8, torch.int16, torch.int64]):
|
||||||
|
return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
|
||||||
|
return cumsum_func(input, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
if has_mps():
|
||||||
|
if version.parse(torch.__version__) < version.parse("1.13"):
|
||||||
|
# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
|
||||||
|
torch.Tensor.to = tensor_to_fix
|
||||||
|
torch.nn.functional.layer_norm = layer_norm_fix
|
||||||
|
torch.Tensor.numpy = numpy_fix
|
||||||
|
elif version.parse(torch.__version__) > version.parse("1.13.1"):
|
||||||
|
if not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.Tensor([1,1]).to(torch.device("mps")).cumsum(0, dtype=torch.int16)):
|
||||||
|
torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) )
|
||||||
|
torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) )
|
||||||
|
orig_narrow = torch.narrow
|
||||||
|
torch.narrow = lambda *args, **kwargs: ( orig_narrow(*args, **kwargs).clone() )
|
||||||
|
7
webui.py
7
webui.py
@ -4,7 +4,7 @@ import threading
|
|||||||
import time
|
import time
|
||||||
import importlib
|
import importlib
|
||||||
import signal
|
import signal
|
||||||
import threading
|
import re
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.middleware.gzip import GZipMiddleware
|
from fastapi.middleware.gzip import GZipMiddleware
|
||||||
@ -13,6 +13,11 @@ from modules import import_hook, errors
|
|||||||
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
|
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
|
||||||
from modules.paths import script_path
|
from modules.paths import script_path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
|
||||||
|
if ".dev" in torch.__version__ or "+git" in torch.__version__:
|
||||||
|
torch.__version__ = re.search(r'[\d.]+', torch.__version__).group(0)
|
||||||
|
|
||||||
from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir
|
from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir
|
||||||
import modules.codeformer_model as codeformer
|
import modules.codeformer_model as codeformer
|
||||||
import modules.extras
|
import modules.extras
|
||||||
|
Loading…
Reference in New Issue
Block a user