Spaces:
Runtime error
Runtime error
| import sys, os, shlex | |
| import contextlib | |
| import torch | |
| from modules import errors | |
| from packaging import version | |
# MPS support is only present in newer PyTorch builds (nightly when this was
# written) and additionally requires macOS 12.3+, so probe for it defensively.
def has_mps() -> bool:
    """Return True if the Apple Metal (MPS) backend is usable in this process.

    ``torch.backends.mps.is_available()`` is preferred where it exists, because
    the top-level ``torch.has_mps`` flag is deprecated in newer PyTorch
    releases; older builds fall back to ``torch.has_mps`` (treated as False
    when absent). As a final check a one-element tensor is moved to the MPS
    device, and any failure there is treated as "not available".
    """
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and hasattr(mps_backend, "is_available"):
        available = mps_backend.is_available()
    else:
        available = getattr(torch, "has_mps", False)
    if not available:
        return False
    try:
        torch.zeros(1).to(torch.device("mps"))
        return True
    except Exception:
        # Best-effort probe: any error means MPS cannot actually be used.
        return False
def extract_device_id(args, name):
    """Return the argument that follows the first argument containing *name*.

    Args:
        args: Sequence of command-line argument strings.
        name: Flag text to look for; it is matched as a substring of each
            argument (kept for backward compatibility).

    Returns:
        The argument immediately after the first match. Returns ``None`` if no
        argument contains *name*, or if the matching argument is the last one
        (this used to raise ``IndexError``).
    """
    for index, arg in enumerate(args):
        if name in arg:
            # A trailing flag has no value after it; report "not given".
            return args[index + 1] if index + 1 < len(args) else None
    return None
def get_cuda_device_string():
    """Build the torch device string used for CUDA.

    Returns ``"cuda:<id>"`` when ``shared.cmd_opts.device_id`` is set, and
    plain ``"cuda"`` otherwise.
    """
    # Deferred import (presumably to avoid an import cycle with `shared`).
    from modules import shared

    device_id = shared.cmd_opts.device_id
    return "cuda" if device_id is None else f"cuda:{device_id}"
def get_optimal_device():
    """Return the preferred torch device: CUDA if available, else MPS, else CPU.

    The CPU case returns the module-level ``cpu`` device object.
    """
    if torch.cuda.is_available():
        return torch.device(get_cuda_device_string())
    return torch.device("mps") if has_mps() else cpu
def torch_gc():
    """Free cached CUDA memory and collect CUDA IPC handles; no-op without CUDA."""
    if not torch.cuda.is_available():
        return
    # Scope the cleanup to the configured CUDA device.
    with torch.cuda.device(get_cuda_device_string()):
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
def enable_tf32():
    """Allow TensorFloat-32 math for CUDA matmuls and cuDNN when CUDA is present."""
    if not torch.cuda.is_available():
        return
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True


# Executed once at import time through the project's `errors.run` wrapper,
# which presumably reports failures under this label instead of raising —
# TODO confirm.
errors.run(enable_tf32, "Enabling TF32")
# Shared CPU device handle (also the staging device for RNG on MPS).
cpu = torch.device("cpu")

# Target devices for the main model and auxiliary networks. All start as None
# and are presumably assigned elsewhere at startup — TODO confirm.
device = None
device_interrogate = None
device_gfpgan = None
device_swinir = None
device_esrgan = None
device_scunet = None
device_codeformer = None

# Default floating-point precision for the model and for the VAE.
dtype = torch.float16
dtype_vae = torch.float16
def randn(seed, shape):
    """Draw standard-normal noise of *shape* on ``device``, reproducibly from *seed*.

    On MPS the noise is sampled on the CPU with a dedicated, seeded generator
    and then moved to the device, because PyTorch does not handle seeded
    randomness correctly on the Metal backend. On other devices the global
    RNG is reseeded with *seed* before sampling.
    """
    if device.type != 'mps':
        torch.manual_seed(seed)
        return torch.randn(shape, device=device)

    cpu_generator = torch.Generator(device=cpu)
    cpu_generator.manual_seed(seed)
    return torch.randn(shape, generator=cpu_generator, device=cpu).to(device)
def randn_without_seed(shape):
    """Draw standard-normal noise of *shape* on ``device`` without reseeding.

    The noise comes from PyTorch's global RNG state, so successive calls
    return different values.

    On MPS the noise is sampled on the CPU and then moved to the device,
    because PyTorch does not handle randomness correctly on the Metal backend.
    The previous MPS path built a fresh, never-seeded ``torch.Generator`` on
    every call. A new generator always starts from the same default seed, so
    that path returned identical noise on every call. Using the global CPU RNG
    fixes this.
    """
    if device.type == 'mps':
        return torch.randn(shape, device=cpu).to(device)
    return torch.randn(shape, device=device)
def autocast(disable=False):
    """Return the context manager that model code should run under.

    A no-op ``contextlib.nullcontext()`` is returned in three cases: *disable*
    is true, the module-level ``dtype`` is float32, or
    ``shared.cmd_opts.precision`` is ``"full"``. Otherwise CUDA autocast is
    returned.
    """
    from modules import shared

    # Short-circuit order matters: cmd_opts is only consulted when needed.
    if disable or dtype == torch.float32 or shared.cmd_opts.precision == "full":
        return contextlib.nullcontext()
    return torch.autocast("cuda")
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
orig_tensor_to = torch.Tensor.to


def _is_mps_target(target):
    """Return True if *target* is a ``torch.device`` or device string naming MPS."""
    if isinstance(target, torch.device):
        return target.type == 'mps'
    if isinstance(target, str):
        # Device strings look like "mps" or "mps:0"; compare only the type part.
        return target.split(':', 1)[0] == 'mps'
    return False


def tensor_to_fix(self, *args, **kwargs):
    """Drop-in for ``torch.Tensor.to`` that makes the tensor contiguous before moving it to MPS.

    When a tensor not already on MPS is sent there, it is made contiguous
    first; all arguments are then forwarded to the original ``Tensor.to``.
    The target may be a positional first argument or the ``device`` keyword.
    Both ``torch.device`` objects and device strings are recognised;
    previously only ``torch.device`` objects were, so ``.to("mps")`` skipped
    the workaround.
    """
    if self.device.type != 'mps' and (
        (len(args) > 0 and _is_mps_target(args[0])) or _is_mps_target(kwargs.get('device'))
    ):
        # contiguous() never changes values, so applying it broadly is safe.
        self = self.contiguous()
    return orig_tensor_to(self, *args, **kwargs)
# MPS workaround for https://github.com/pytorch/pytorch/issues/80800
orig_layer_norm = torch.nn.functional.layer_norm


def layer_norm_fix(*args, **kwargs):
    """Call the original ``layer_norm``, making its input contiguous when it lives on MPS.

    Only a positional first argument is inspected; every argument is
    forwarded unchanged otherwise.
    """
    if args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps':
        head, *tail = args
        args = (head.contiguous(), *tail)
    return orig_layer_norm(*args, **kwargs)
# PyTorch 1.13 no longer needs these MPS fixes, so they are installed only on
# older versions. Some users stay on older versions because 1.13 is slower and
# has regressions that prevent training from working.
# The cheap version comparison runs first, so the has_mps() probe (which tries
# to allocate a tensor on the MPS device) is skipped entirely on newer PyTorch.
if version.parse(torch.__version__) < version.parse("1.13") and has_mps():
    torch.Tensor.to = tensor_to_fix
    torch.nn.functional.layer_norm = layer_norm_fix