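"""Node helper utilities: a wildcard input type, weight-dtype selection,
and compatibility patches for torch.compile internals (a descriptive summary
of this module, added as a docstring)."""
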
import contextlib
import unittest.mock
import torch
# Wildcard trick is taken from pythongossss's: a str that never compares
# unequal, so it matches any node input/output type.
class AnyType(str):
def __ne__(self, __value: object) -> bool:
return False
any_typ = AnyType("*")
def get_weight_dtype_inputs():
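    """Return the node input spec listing the selectable weight dtypes."""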
return {
"weight_dtype": (
[
"default",
"float32",
"float64",
"bfloat16",
"float16",
"fp8_e4m3fn",
"fp8_e4m3fn_fast",
"fp8_e5m2",
],
),
}
def parse_weight_dtype(model_options, weight_dtype):
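    """Map the selected ``weight_dtype`` name to a torch dtype in ``model_options``.

    "fp8_e4m3fn_fast" uses the same dtype as "fp8_e4m3fn" but additionally
    enables fp8 optimizations.
    """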
dtype = {
"float32": torch.float32,
"float64": torch.float64,
"bfloat16": torch.bfloat16,
"float16": torch.float16,
"fp8_e4m3fn": torch.float8_e4m3fn,
"fp8_e4m3fn_fast": torch.float8_e4m3fn,
"fp8_e5m2": torch.float8_e5m2,
}.get(weight_dtype, None)
if dtype is not None:
model_options["dtype"] = dtype
if weight_dtype == "fp8_e4m3fn_fast":
model_options["fp8_optimizations"] = True
return model_options
@contextlib.contextmanager
def disable_load_models_gpu():
def foo(*args, **kwargs):
pass
from modules.Device import Device
with unittest.mock.patch.object(Device, "load_models_gpu", foo):
yield
def patch_optimized_module():
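    """Patch torch.compile's OptimizedModule so attribute access, attribute
    deletion, and instance checks are forwarded to the wrapped original module."""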
try:
from torch._dynamo.eval_frame import OptimizedModule
except ImportError:
return
if getattr(OptimizedModule, "_patched", False):
return
def __getattribute__(self, name):
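        # _orig_mod is registered as a submodule, so fetch it from _modules;
        # forward the common nn.Module accessors to it so the compiled wrapper
        # behaves like the module it wraps.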
if name == "_orig_mod":
return object.__getattribute__(self, "_modules")[name]
if name in (
"__class__",
"_modules",
"state_dict",
"load_state_dict",
"parameters",
"named_parameters",
"buffers",
"named_buffers",
"children",
"named_children",
"modules",
"named_modules",
):
return getattr(object.__getattribute__(self, "_orig_mod"), name)
return object.__getattribute__(self, name)
def __delattr__(self, name):
# unload_lora_weights() wants to del peft_config
return delattr(self._orig_mod, name)
@classmethod
def __instancecheck__(cls, instance):
return isinstance(instance, OptimizedModule) or issubclass(
object.__getattribute__(instance, "__class__"), cls
)
OptimizedModule.__getattribute__ = __getattribute__
OptimizedModule.__delattr__ = __delattr__
OptimizedModule.__instancecheck__ = __instancecheck__
OptimizedModule._patched = True
def patch_same_meta():
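    """Wrap torch._inductor's post_grad.same_meta so comparison errors return False."""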
try:
from torch._inductor.fx_passes import post_grad
except ImportError:
return
same_meta = getattr(post_grad, "same_meta", None)
if same_meta is None:
return
if getattr(same_meta, "_patched", False):
return
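    # Treat any exception raised by the original comparison as "not the same".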
def new_same_meta(a, b):
try:
return same_meta(a, b)
except Exception:
return False
post_grad.same_meta = new_same_meta
new_same_meta._patched = True