Spaces:
Runtime error
Runtime error
import sys, os, shlex | |
import contextlib | |
import torch | |
from modules import errors | |
from packaging import version | |
# has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
# check `getattr` and try it for compatibility
def has_mps() -> bool:
    """Report whether a tensor can actually be placed on the "mps" device.

    Returns False right away when ``torch`` has no truthy ``has_mps``
    attribute. Otherwise a one-element tensor is moved to the "mps" device
    as a probe; any exception raised by that probe is treated as "no MPS".
    """
    if not getattr(torch, 'has_mps', False):
        return False

    mps_device = torch.device("mps")
    try:
        torch.zeros(1).to(mps_device)
    except Exception:
        # The flag was set but the backend is unusable on this machine.
        return False
    else:
        return True
def extract_device_id(args, name):
    """Return the argument that follows the first entry of *args* containing *name*.

    Args:
        args: Sequence of argument strings (for example a command line).
        name: Text to look for; an entry matches when *name* is a substring
            of it.

    Returns:
        The entry immediately after the first match, or ``None`` when no
        entry matches or when the match is the final entry and therefore
        has nothing following it.
    """
    last_index = len(args) - 1
    for index, arg in enumerate(args):
        if name in arg:
            # A flag in the final position has no value after it; report
            # "not found" instead of indexing past the end of the sequence.
            return args[index + 1] if index < last_index else None

    return None
def get_cuda_device_string():
    """Build the CUDA device string from ``shared.cmd_opts.device_id``.

    Returns ``"cuda:<device_id>"`` when a device id is set, otherwise the
    plain string ``"cuda"``.
    """
    # Imported at call time rather than at module load
    # (presumably to avoid an import cycle with modules.shared - TODO confirm).
    from modules import shared

    selected_id = shared.cmd_opts.device_id
    if selected_id is None:
        return "cuda"

    return f"cuda:{selected_id}"
def get_optimal_device():
    """Pick the torch device to use, preferring CUDA, then MPS, then CPU.

    Returns a ``torch.device`` built from ``get_cuda_device_string()`` when
    CUDA is available, the "mps" device when ``has_mps()`` is true, and the
    module-level ``cpu`` device otherwise.
    """
    if torch.cuda.is_available():
        cuda_name = get_cuda_device_string()
        return torch.device(cuda_name)

    return torch.device("mps") if has_mps() else cpu
def torch_gc():
    """Release cached CUDA memory on the configured CUDA device.

    Does nothing when CUDA is not available. Otherwise, with the device named
    by ``get_cuda_device_string()`` selected, calls
    ``torch.cuda.empty_cache()`` followed by ``torch.cuda.ipc_collect()``.
    """
    if not torch.cuda.is_available():
        return

    with torch.cuda.device(get_cuda_device_string()):
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
def enable_tf32():
    """Turn on the TF32 flags for CUDA matmul and cuDNN when CUDA is available.

    Sets ``torch.backends.cuda.matmul.allow_tf32`` and
    ``torch.backends.cudnn.allow_tf32`` to True; without CUDA nothing is
    changed.
    """
    if not torch.cuda.is_available():
        return

    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
# Run the TF32 setup through errors.run, labelled "Enabling TF32".
errors.run(enable_tf32, "Enabling TF32")

# Shared handle for the CPU device, used by the helpers in this module.
cpu = torch.device("cpu")

# Device slots; all start out as None
# (presumably assigned elsewhere during startup - TODO confirm).
device = None
device_interrogate = None
device_gfpgan = None
device_swinir = None
device_esrgan = None
device_scunet = None
device_codeformer = None

# Initial dtypes; both start as half precision.
dtype = torch.float16
dtype_vae = torch.float16
def randn(seed, shape):
    """Return seeded normal noise of *shape* on the module-level ``device``.

    Args:
        seed: Seed applied before sampling.
        shape: Shape passed to ``torch.randn``.

    Returns:
        A tensor on ``device``. On non-MPS devices the global torch RNG is
        seeded and sampled directly on ``device``; on MPS the noise comes
        from a seeded CPU generator and is then moved to ``device``.
    """
    if device.type != 'mps':
        torch.manual_seed(seed)
        return torch.randn(shape, device=device)

    # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
    cpu_rng = torch.Generator(device=cpu)
    cpu_rng.manual_seed(seed)
    cpu_noise = torch.randn(shape, generator=cpu_rng, device=cpu)
    return cpu_noise.to(device)
def randn_without_seed(shape):
    """Return normal noise of *shape* on the module-level ``device`` without reseeding.

    Args:
        shape: Shape passed to ``torch.randn``.

    Returns:
        A tensor on ``device``. On MPS the noise is sampled on the CPU and
        then moved to ``device``; on other devices it is sampled directly on
        ``device``.
    """
    # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
    if device.type == 'mps':
        # Sample from the global CPU RNG. A freshly constructed
        # torch.Generator always starts from the same default seed, so
        # creating an unseeded one per call would yield identical noise on
        # every call instead of continuing the current random stream.
        return torch.randn(shape, device=cpu).to(device)

    return torch.randn(shape, device=device)
def autocast(disable=False):
    """Return the context manager to wrap inference code in.

    Args:
        disable: When true, always return a no-op context.

    Returns:
        ``contextlib.nullcontext()`` when *disable* is true, when the
        module-level ``dtype`` is ``torch.float32``, or when
        ``shared.cmd_opts.precision`` equals "full"; otherwise
        ``torch.autocast("cuda")``.
    """
    from modules import shared

    # Conditions are checked in order, stopping at the first true one.
    if disable or dtype == torch.float32 or shared.cmd_opts.precision == "full":
        return contextlib.nullcontext()

    return torch.autocast("cuda")
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
orig_tensor_to = torch.Tensor.to


def tensor_to_fix(self, *args, **kwargs):
    """Replacement for ``torch.Tensor.to`` that pre-contiguates MPS-bound tensors.

    When the tensor is not already on MPS and the target, given either as the
    first positional argument or as the ``device`` keyword, is a
    ``torch.device`` of type "mps", the tensor is made contiguous first. The
    call is then forwarded unchanged to the original ``torch.Tensor.to``.
    """
    if self.device.type == 'mps':
        return orig_tensor_to(self, *args, **kwargs)

    positional_target = args[0] if args else None
    keyword_target = kwargs.get('device')
    moving_to_mps = any(
        isinstance(target, torch.device) and target.type == 'mps'
        for target in (positional_target, keyword_target)
    )

    source = self.contiguous() if moving_to_mps else self
    return orig_tensor_to(source, *args, **kwargs)
# MPS workaround for https://github.com/pytorch/pytorch/issues/80800
orig_layer_norm = torch.nn.functional.layer_norm


def layer_norm_fix(*args, **kwargs):
    """Replacement for ``layer_norm`` that makes an MPS input tensor contiguous.

    If the first positional argument is a tensor on an "mps" device it is
    replaced by its contiguous form; all arguments are then forwarded to the
    original ``torch.nn.functional.layer_norm``.
    """
    first = args[0] if args else None
    if isinstance(first, torch.Tensor) and first.device.type == 'mps':
        args = (first.contiguous(), *args[1:])

    return orig_layer_norm(*args, **kwargs)
# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
# The version comparison is only evaluated when MPS is usable (short-circuit).
needs_mps_fixes = has_mps() and version.parse(torch.__version__) < version.parse("1.13")
if needs_mps_fixes:
    # Install both workarounds in place of the stock torch functions.
    torch.Tensor.to = tensor_to_fix
    torch.nn.functional.layer_norm = layer_norm_fix