import importlib
import json

from . import utils
class EnhancedCompileModel:
    def patch(
        self,
        model,
        is_patcher,
        object_to_patch,
        compiler,
        fullgraph,
        dynamic,
        mode,
        options,
        disable,
        backend,
    ):
        utils.patch_optimized_module()
        utils.patch_same_meta()

        # Resolve the compile function (e.g. "torch.compile") from its dotted path.
        import_path, function_name = compiler.rsplit(".", 1)
        module = importlib.import_module(import_path)
        compile_function = getattr(module, function_name)

        # Normalize optional arguments: empty strings become None / parsed JSON.
        mode = mode if mode else None
        options = json.loads(options) if options else None

        if compiler == "torch.compile" and backend == "inductor" and dynamic:
            # TODO: Fix this
            # File "pytorch/torch/_inductor/fx_passes/post_grad.py", line 643, in same_meta
            #   and statically_known_true(sym_eq(val1.size(), val2.size()))
            # AttributeError: 'SymInt' object has no attribute 'size'
            pass

        # Work on a clone so the original model/patcher is left untouched.
        if is_patcher:
            patcher = model[0].clone()
        else:
            patcher = model.patcher
            patcher = patcher.clone()

        # Swap the target sub-module for its compiled counterpart.
        patcher.add_object_patch(
            object_to_patch,
            compile_function(
                patcher.get_model_object(object_to_patch),
                fullgraph=fullgraph,
                dynamic=dynamic,
                mode=mode,
                options=options,
                disable=disable,
                backend=backend,
            ),
        )

        if is_patcher:
            return (patcher,)
        else:
            model.patcher = patcher
            return (model,)
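

# Illustrative sketch only: the helper below shows one way EnhancedCompileModel.patch
# might be driven. The `model_patcher` argument is assumed to be a ComfyUI-style
# ModelPatcher, and the argument values (object path, compiler, backend, mode) are
# hypothetical defaults for the example, not values fixed by this module.
def _example_compile(model_patcher):
    node = EnhancedCompileModel()
    (compiled,) = node.patch(
        (model_patcher,),  # wrapped in a tuple because is_patcher=True indexes model[0]
        is_patcher=True,
        object_to_patch="diffusion_model",
        compiler="torch.compile",
        fullgraph=False,
        dynamic=False,
        mode="default",
        options="",  # JSON string of extra compiler options, or "" for none
        disable=False,
        backend="inductor",
    )
    return compiled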