import contextlib
import unittest.mock

import torch


# Wildcard trick is taken from pythongossss's: a str subclass whose __ne__
# always returns False, so it compares as "equal" to any socket/type string.
class AnyType(str):
    def __ne__(self, __value: object) -> bool:
        return False


any_typ = AnyType("*")
def get_weight_dtype_inputs():
    """Return the shared ``weight_dtype`` dropdown definition for node inputs."""
    return {
        "weight_dtype": (
            [
                "default",
                "float32",
                "float64",
                "bfloat16",
                "float16",
                "fp8_e4m3fn",
                "fp8_e4m3fn_fast",
                "fp8_e5m2",
            ],
        ),
    }
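

# Illustrative sketch (not part of the original module): how the dropdown
# definition above could be merged into a node's input spec. "ExampleLoader"
# and its "required" layout are assumptions for demonstration only.
#
#     class ExampleLoader:
#         @classmethod
#         def INPUT_TYPES(cls):
#             return {
#                 "required": {
#                     **get_weight_dtype_inputs(),
#                 }
#             }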
def parse_weight_dtype(model_options, weight_dtype):
    """Map a ``weight_dtype`` string onto ``model_options`` and return it.

    Unknown names (including ``"default"``) leave ``model_options`` untouched;
    the fp8 "fast" variant additionally enables fp8 optimizations.
    """
    dtype = {
        "float32": torch.float32,
        "float64": torch.float64,
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "fp8_e4m3fn": torch.float8_e4m3fn,
        "fp8_e4m3fn_fast": torch.float8_e4m3fn,
        "fp8_e5m2": torch.float8_e5m2,
    }.get(weight_dtype, None)
    if dtype is not None:
        model_options["dtype"] = dtype
    if weight_dtype == "fp8_e4m3fn_fast":
        model_options["fp8_optimizations"] = True
    return model_options
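

# Illustrative sketch (assumption, not original code):
#
#     opts = parse_weight_dtype({}, "fp8_e4m3fn_fast")
#     # opts == {"dtype": torch.float8_e4m3fn, "fp8_optimizations": True}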
@contextlib.contextmanager
def disable_load_models_gpu():
    """Temporarily replace ``Device.load_models_gpu`` with a no-op."""

    def foo(*args, **kwargs):
        pass

    from modules.Device import Device

    with unittest.mock.patch.object(Device, "load_models_gpu", foo):
        yield
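

# Illustrative sketch (assumption, not original code): wrap a code path that
# would otherwise push models to the GPU.
#
#     with disable_load_models_gpu():
#         ...  # load or patch models without triggering Device.load_models_gpu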
def patch_optimized_module():
    """Make ``torch.compile``'s ``OptimizedModule`` transparent to callers."""
    try:
        from torch._dynamo.eval_frame import OptimizedModule
    except ImportError:
        return

    if getattr(OptimizedModule, "_patched", False):
        return

    def __getattribute__(self, name):
        # Forward common nn.Module APIs to the wrapped original module so the
        # compiled wrapper behaves like the module it wraps.
        if name == "_orig_mod":
            return object.__getattribute__(self, "_modules")[name]
        if name in (
            "__class__",
            "_modules",
            "state_dict",
            "load_state_dict",
            "parameters",
            "named_parameters",
            "buffers",
            "named_buffers",
            "children",
            "named_children",
            "modules",
            "named_modules",
        ):
            return getattr(object.__getattribute__(self, "_orig_mod"), name)
        return object.__getattribute__(self, name)

    def __delattr__(self, name):
        # unload_lora_weights() wants to del peft_config
        return delattr(self._orig_mod, name)

    def __instancecheck__(cls, instance):
        return isinstance(instance, OptimizedModule) or issubclass(
            object.__getattribute__(instance, "__class__"), cls
        )

    OptimizedModule.__getattribute__ = __getattribute__
    OptimizedModule.__delattr__ = __delattr__
    OptimizedModule.__instancecheck__ = __instancecheck__
    OptimizedModule._patched = True
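

# Illustrative sketch (assumption, not original code): apply the patch once
# before compiling, so downstream code can treat the compiled wrapper much
# like the original module.
#
#     patch_optimized_module()
#     model = torch.compile(model)
#     sd = model.state_dict()  # forwarded to the wrapped module by the patch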
def patch_same_meta():
    """Wrap inductor's ``same_meta`` so metadata-check failures fail soft."""
    try:
        from torch._inductor.fx_passes import post_grad
    except ImportError:
        return

    same_meta = getattr(post_grad, "same_meta", None)
    if same_meta is None:
        return

    if getattr(same_meta, "_patched", False):
        return

    def new_same_meta(a, b):
        # Treat any exception from the original check as "not the same"
        # instead of crashing the post-grad fx pass.
        try:
            return same_meta(a, b)
        except Exception:
            return False

    post_grad.same_meta = new_same_meta
    new_same_meta._patched = True
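

# Illustrative sketch (assumption, not original code): both patches are
# guarded by a ``_patched`` flag, so they can simply be called once at
# startup before any torch.compile usage.
#
#     patch_optimized_module()
#     patch_same_meta()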
