File size: 1,723 Bytes
d9a2e19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import importlib
import json

from . import utils


class EnhancedCompileModel:

    def patch(

        self,

        model,

        is_patcher,

        object_to_patch,

        compiler,

        fullgraph,

        dynamic,

        mode,

        options,

        disable,

        backend,

    ):
        utils.patch_optimized_module()
        utils.patch_same_meta()

        import_path, function_name = compiler.rsplit(".", 1)
        module = importlib.import_module(import_path)
        compile_function = getattr(module, function_name)

        mode = mode if mode else None
        options = json.loads(options) if options else None

        if compiler == "torch.compile" and backend == "inductor" and dynamic:
            # TODO: Fix this
            # File "pytorch/torch/_inductor/fx_passes/post_grad.py", line 643, in same_meta
            #   and statically_known_true(sym_eq(val1.size(), val2.size()))
            #   AttributeError: 'SymInt' object has no attribute 'size'
            pass

        if is_patcher:
            patcher = model[0].clone()
        else:
            patcher = model.patcher
            patcher = patcher.clone()

        patcher.add_object_patch(
            object_to_patch,
            compile_function(
                patcher.get_model_object(object_to_patch),
                fullgraph=fullgraph,
                dynamic=dynamic,
                mode=mode,
                options=options,
                disable=disable,
                backend=backend,
            ),
        )

        if is_patcher:
            return (patcher,)
        else:
            model.patcher = patcher
            return (model,)