import contextlib
import unittest.mock

import torch


# wildcard trick is taken from pythongossss's
class AnyType(str):
    """A str subclass whose "!=" is always False, so it acts as a wildcard type string."""

    def __ne__(self, __value: object) -> bool:
        return False


any_typ = AnyType("*")
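

def _example_any_typ():
    # Hedged usage sketch (illustrative, not called anywhere): the wildcard only
    # overrides "!=", so any_typ never tests as unequal to another type string,
    # which is the comparison type-validation code usually performs.
    assert (any_typ != "IMAGE") is False
    assert (any_typ != "MODEL") is False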


def get_weight_dtype_inputs():
    """Return the shared weight_dtype dropdown definition."""
    return {
        "weight_dtype": (
            [
                "default",
                "float32",
                "float64",
                "bfloat16",
                "float16",
                "fp8_e4m3fn",
                "fp8_e4m3fn_fast",
                "fp8_e5m2",
            ],
        ),
    }
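

def _example_weight_dtype_inputs():
    # Hedged usage sketch: the dropdown definition above is meant to be merged into a
    # node's input dict. The surrounding keys ("ckpt_name", the ComfyUI-style
    # "required" layout) are assumptions for illustration, not taken from this file.
    required = {
        "ckpt_name": ("STRING", {"default": ""}),
        **get_weight_dtype_inputs(),
    }
    return {"required": required}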


def parse_weight_dtype(model_options, weight_dtype):
    """Map the weight_dtype dropdown value onto model_options and return it."""
    dtype = {
        "float32": torch.float32,
        "float64": torch.float64,
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "fp8_e4m3fn": torch.float8_e4m3fn,
        "fp8_e4m3fn_fast": torch.float8_e4m3fn,
        "fp8_e5m2": torch.float8_e5m2,
    }.get(weight_dtype, None)
    if dtype is not None:
        model_options["dtype"] = dtype
    # "fp8_e4m3fn_fast" stores the same fp8 dtype but additionally enables fp8 optimizations.
    if weight_dtype == "fp8_e4m3fn_fast":
        model_options["fp8_optimizations"] = True
    return model_options
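

def _example_parse_weight_dtype():
    # Hedged usage sketch: "fp8_e4m3fn_fast" stores the fp8 dtype and also flips the
    # fp8_optimizations flag, while "default" leaves model_options untouched.
    opts = parse_weight_dtype({}, "fp8_e4m3fn_fast")
    assert opts["dtype"] == torch.float8_e4m3fn
    assert opts["fp8_optimizations"] is True
    assert parse_weight_dtype({}, "default") == {}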


@contextlib.contextmanager
def disable_load_models_gpu():
    """Temporarily make Device.load_models_gpu a no-op while the context is active."""

    def _noop(*args, **kwargs):
        pass

    from modules.Device import Device

    with unittest.mock.patch.object(Device, "load_models_gpu", _noop):
        yield
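

def _example_disable_load_models_gpu():
    # Hedged usage sketch: inside the context every call to Device.load_models_gpu is
    # a no-op, so model setup code can run without moving weights onto the GPU.
    # Whatever the caller runs inside the block is their own code; nothing specific
    # is assumed here.
    with disable_load_models_gpu():
        pass  # e.g. build or wrap models here without triggering GPU loads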


def patch_optimized_module():
    """Patch torch._dynamo's OptimizedModule so compiled modules behave like the module they wrap."""
    try:
        from torch._dynamo.eval_frame import OptimizedModule
    except ImportError:
        return

    # Only patch once per process.
    if getattr(OptimizedModule, "_patched", False):
        return

    def __getattribute__(self, name):
        if name == "_orig_mod":
            # The wrapped module is registered in _modules by nn.Module.__setattr__.
            return object.__getattribute__(self, "_modules")[name]
        if name in (
            "__class__",
            "_modules",
            "state_dict",
            "load_state_dict",
            "parameters",
            "named_parameters",
            "buffers",
            "named_buffers",
            "children",
            "named_children",
            "modules",
            "named_modules",
        ):
            # Forward these lookups to the wrapped module so callers see its real class,
            # parameters and state_dict instead of the compile wrapper's.
            return getattr(object.__getattribute__(self, "_orig_mod"), name)
        return object.__getattribute__(self, name)

    def __delattr__(self, name):
        # unload_lora_weights() wants to del peft_config
        return delattr(self._orig_mod, name)

    def __instancecheck__(cls, instance):
        return isinstance(instance, OptimizedModule) or issubclass(
            object.__getattribute__(instance, "__class__"), cls
        )

    OptimizedModule.__getattribute__ = __getattribute__
    OptimizedModule.__delattr__ = __delattr__
    OptimizedModule.__instancecheck__ = __instancecheck__
    OptimizedModule._patched = True
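

def _example_patch_optimized_module():
    # Hedged usage sketch: apply the patch before torch.compile so the compiled wrapper
    # forwards class, parameter and state_dict lookups to the wrapped module (for
    # example, state_dict keys come out without an "_orig_mod." prefix).
    patch_optimized_module()
    model = torch.nn.Linear(4, 4)
    compiled = torch.compile(model)
    return compiled.state_dict().keys()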


def patch_same_meta():
    """Wrap torch._inductor's post-grad same_meta check so it never raises during compilation."""
    try:
        from torch._inductor.fx_passes import post_grad
    except ImportError:
        return

    same_meta = getattr(post_grad, "same_meta", None)
    if same_meta is None:
        return

    # Only patch once per process.
    if getattr(same_meta, "_patched", False):
        return

    def new_same_meta(a, b):
        # Treat any failure to compare metadata as "not the same" instead of crashing.
        try:
            return same_meta(a, b)
        except Exception:
            return False

    post_grad.same_meta = new_same_meta
    new_same_meta._patched = True
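

def _example_patch_same_meta():
    # Hedged usage sketch: both patches guard themselves with a _patched flag, so they
    # can simply be applied together once, early, before any torch.compile work happens.
    # patch_same_meta only softens inductor's post-grad same_meta check so an exception
    # there means "not the same" instead of a crash.
    patch_optimized_module()
    patch_same_meta()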