# Copyright (c) Facebook, Inc. and its affiliates.

import os
import torch

from detectron2.utils.file_io import PathManager

from .torchscript_patch import patch_instances, patch_nonscriptable_classes


def export_torchscript_with_instances(model, fields):
    """
    Run :func:`torch.jit.script` on a model that uses the :class:`Instances` class. Since
    attributes of :class:`Instances` are "dynamically" added in eager mode, it is difficult
    for torchscript to support them out of the box. This function is made to support scripting
    a model that uses :class:`Instances`. It does the following:

    1. Create a scriptable ``new_Instances`` class which behaves similarly to ``Instances``,
       but with all attributes being "static".
       The attributes need to be statically declared in the ``fields`` argument.
    2. Register ``new_Instances`` to torchscript, and force torchscript to
       use it when trying to compile ``Instances``.

    After this function returns, the patching is reverted. Users should then be able to
    script another model using different fields.

    Example:
        Assume that ``Instances`` in the model consists of two attributes named
        ``proposal_boxes`` and ``objectness_logits`` with types :class:`Boxes` and
        :class:`Tensor` respectively during inference. You can call this function like:

        ::
            fields = {"proposal_boxes": Boxes, "objectness_logits": torch.Tensor}
            torchscript_model = export_torchscript_with_instances(model, fields)

    Note:
        Currently we only support models in evaluation mode.

    Args:
        model (nn.Module): The input model to be exported to torchscript.
        fields (Dict[str, type]): Attribute names and their corresponding types that
            ``Instances`` will use in the model. Note that all attributes used in ``Instances``
            need to be added, regardless of whether they are inputs/outputs of the model.
            Data types not defined in detectron2 are not supported for now.

    Returns:
        torch.jit.ScriptModule: the input model in torchscript format
    """
    patch_nonscriptable_classes()

    assert (
        not model.training
    ), "Currently we only support exporting models in evaluation mode to torchscript"

    with patch_instances(fields):
        scripted_model = torch.jit.script(model)
        return scripted_model
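

# A minimal end-to-end usage sketch for `export_torchscript_with_instances`
# (hypothetical model and output path, shown here only for illustration):
#
#   from detectron2.structures import Boxes
#
#   # Every attribute that `Instances` carries anywhere inside the model must be
#   # declared, not only the attributes returned to the caller.
#   fields = {"proposal_boxes": Boxes, "objectness_logits": torch.Tensor}
#   model.eval()
#   scripted_model = export_torchscript_with_instances(model, fields)
#   scripted_model.save("model.ts")  # standard torch.jit.ScriptModule.save()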


def dump_torchscript_IR(model, dir):
    """
    Dump IR of a TracedModule/ScriptModule at various levels.
    Useful for debugging.

    Args:
        model (TracedModule or ScriptModule): traced or scripted module
        dir (str): output directory to dump files.
    """
    PathManager.mkdirs(dir)

    def _get_script_mod(mod):
        if isinstance(mod, torch.jit.TracedModule):
            return mod._actual_script_module
        return mod

    # Dump pretty-printed code: https://pytorch.org/docs/stable/jit.html#inspecting-code
    with PathManager.open(os.path.join(dir, "model_ts_code.txt"), "w") as f:

        def get_code(mod):
            # Try a few ways to get code using private attributes.
            try:
                # This contains more information than just `mod.code`
                return _get_script_mod(mod)._c.code
            except AttributeError:
                pass
            try:
                return mod.code
            except AttributeError:
                return None

        def dump_code(prefix, mod):
            code = get_code(mod)
            name = prefix or "root model"
            if code is None:
                f.write(f"Could not found code for {name} (type={mod.original_name})\n")
                f.write("\n")
            else:
                f.write(f"\nCode for {name}, type={mod.original_name}:\n")
                f.write(code)
                f.write("\n")
                f.write("-" * 80)

            for name, m in mod.named_children():
                dump_code(prefix + "." + name, m)

        dump_code("", model)

    # Recursively dump IR of all modules
    with PathManager.open(os.path.join(dir, "model_ts_IR.txt"), "w") as f:
        try:
            f.write(_get_script_mod(model)._c.dump_to_str(True, False, False))
        except AttributeError:
            pass

    # Dump IR of the entire graph (all submodules inlined)
    with PathManager.open(os.path.join(dir, "model_ts_IR_inlined.txt"), "w") as f:
        f.write(str(model.inlined_graph))

    # Dump the model structure in pytorch style
    with PathManager.open(os.path.join(dir, "model.txt"), "w") as f:
        f.write(str(model))
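

# A minimal usage sketch for `dump_torchscript_IR` (the toy traced module and the
# output directory are hypothetical, shown here only for illustration):
#
#   traced = torch.jit.trace(torch.nn.Linear(4, 2).eval(), torch.randn(1, 4))
#   dump_torchscript_IR(traced, "./ts_dump")
#   # ./ts_dump now contains model_ts_code.txt, model_ts_IR.txt,
#   # model_ts_IR_inlined.txt and model.txt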