File size: 11,217 Bytes
cf2a15a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorBoard 3D mesh visualizer plugin."""

import numpy as np
from werkzeug import wrappers

from tensorboard.backend import http_util
from tensorboard.data import provider
from tensorboard.plugins import base_plugin
from tensorboard.plugins.mesh import metadata
from tensorboard.plugins.mesh import plugin_data_pb2
from tensorboard import plugin_util

_DEFAULT_DOWNSAMPLING = 100  # meshes per time series


class MeshPlugin(base_plugin.TBPlugin):
    """A plugin that serves 3D visualization of meshes.

    A user-facing mesh tag is stored under one or more "instance tags". Each
    instance tag carries a `MeshPluginData` proto whose `name` field is the
    user-facing tag and whose `content_type` says which part of the mesh
    (vertices, faces, colors, ...) that instance tag holds.
    """

    plugin_name = metadata.PLUGIN_NAME

    def __init__(self, context):
        """Instantiates a MeshPlugin via TensorBoard core.

        Args:
          context: A base_plugin.TBContext instance. A magic container that
            TensorBoard uses to make objects available to the plugin.
        """
        self._data_provider = context.data_provider
        # Per-time-series read cap: the sampling hint for this plugin if one
        # was given, else the module default.
        self._downsample_to = (context.sampling_hints or {}).get(
            self.plugin_name, _DEFAULT_DOWNSAMPLING
        )
        self._version_checker = plugin_util._MetadataVersionChecker(
            data_kind="mesh",
            latest_known_version=0,
        )

    def _instance_tag_metadata(self, ctx, experiment, run, instance_tag):
        """Gets the `MeshPluginData` proto for an instance tag.

        Raises:
          KeyError: if `run`/`instance_tag` is not listed by the data provider.
        """
        results = self._data_provider.list_tensors(
            ctx,
            experiment_id=experiment,
            plugin_name=metadata.PLUGIN_NAME,
            run_tag_filter=provider.RunTagFilter(
                runs=[run], tags=[instance_tag]
            ),
        )
        content = results[run][instance_tag].plugin_content
        return metadata.parse_plugin_metadata(content)

    def _tag(self, ctx, experiment, run, instance_tag):
        """Gets the user-facing tag name for an instance tag."""
        return self._instance_tag_metadata(
            ctx, experiment, run, instance_tag
        ).name

    def _instance_tag_metadatas(self, ctx, experiment, run, tag):
        """Maps each instance tag of a user-facing tag to its parsed metadata.

        A single `list_tensors` call covers the whole run, so callers that
        need both the instance tag names and their metadata avoid one extra
        data-provider round trip per instance tag.

        Args:
          ctx: Request context from `plugin_util.context`.
          experiment: Experiment ID.
          run: Run name.
          tag: User-facing tag name to match against `MeshPluginData.name`.

        Returns:
          A dict mapping instance tag name to its `MeshPluginData` proto, in
          the order the data provider listed them. The dict is empty if the
          run has no matching mesh data.
        """
        index = self._data_provider.list_tensors(
            ctx,
            experiment_id=experiment,
            plugin_name=metadata.PLUGIN_NAME,
            run_tag_filter=provider.RunTagFilter(runs=[run]),
        )
        matching = {}
        for instance_tag, ts in index.get(run, {}).items():
            md = metadata.parse_plugin_metadata(ts.plugin_content)
            if md.name == tag:
                matching[instance_tag] = md
        return matching

    def _instance_tags(self, ctx, experiment, run, tag):
        """Gets the instance tag names for a user-facing tag."""
        return list(self._instance_tag_metadatas(ctx, experiment, run, tag))

    @wrappers.Request.application
    def _serve_tags(self, request):
        """A route (HTTP handler) that returns a response with tags.

        Args:
          request: The werkzeug.Request object.

        Returns:
          A response that contains a JSON object. The keys of the object
          are all the runs. Each run is mapped to a (potentially empty)
          dict from user-facing tag name to `{"samples": <size of the first
          dimension of the recorded data shape>}`.
        """
        ctx = plugin_util.context(request.environ)
        experiment = plugin_util.experiment_id(request.environ)
        all_runs = self._data_provider.list_tensors(
            ctx,
            experiment_id=experiment,
            plugin_name=metadata.PLUGIN_NAME,
        )

        # `all_runs` maps run -> instance tag -> time series metadata. The
        # serialized `MeshPluginData` in `plugin_content` already carries the
        # user-facing tag name and the data shape, so no further data-provider
        # calls are needed here.
        response = {}
        for run, tags in all_runs.items():
            response[run] = {}
            for instance_tag, metadatum in tags.items():
                md = metadata.parse_plugin_metadata(metadatum.plugin_content)
                if not self._version_checker.ok(md.version, run, instance_tag):
                    continue
                # Key by the user-facing tag; instance tags of the same mesh
                # share a name and write to the same entry. The batch (first)
                # dimension must be defined, otherwise we don't know how many
                # samples there are.
                response[run][md.name] = {"samples": md.shape[0]}
        return http_util.Respond(request, response, "application/json")

    def get_plugin_apps(self):
        """Gets all routes offered by the plugin.

        This method is called by TensorBoard when retrieving all the
        routes offered by the plugin.

        Returns:
          A dictionary mapping URL path to route that handles it.
        """
        # Note that the methods handling routes are decorated with
        # @wrappers.Request.application.
        return {
            "/tags": self._serve_tags,
            "/meshes": self._serve_mesh_metadata,
            "/data": self._serve_mesh_data,
        }

    def is_active(self):
        return False  # `list_plugins` as called by TB core suffices

    def frontend_metadata(self):
        return base_plugin.FrontendMetadata(element_name="mesh-dashboard")

    def _get_sample(self, tensor_datum, sample):
        """Returns element `sample` (first-axis index) as nested lists.

        Raises:
          IndexError: if `sample` is out of range for `tensor_datum.numpy`.
        """
        return tensor_datum.numpy[sample].tolist()

    def _get_tensor_metadata(
        self, event, content_type, components, data_shape, config
    ):
        """Converts a TensorDatum into a JSON-compatible response.

        Args:
          event: TensorDatum object; only its `wall_time` and `step` are used.
          content_type: enum plugin_data_pb2.MeshPluginData.ContentType value,
            representing content type in TensorDatum.
          components: Bitmask representing all parts (vertices, colors, etc.) that
            belong to the summary.
          data_shape: list of dimensions sizes of the tensor.
          config: rendering scene configuration (passed through unchanged).

        Returns:
          Dictionary of transformed metadata.
        """
        return {
            "wall_time": event.wall_time,
            "step": event.step,
            "content_type": content_type,
            "components": components,
            "config": config,
            "data_shape": list(data_shape),
        }

    def _get_tensor_data(self, event, sample):
        """Converts one sample of a TensorDatum into JSON-compatible lists."""
        return self._get_sample(event, sample)

    def _collect_tensor_events(self, request, step=None):
        """Collects (metadata, tensor datum) pairs for the requested run/tag.

        Args:
          request: werkzeug.Request whose `run` and `tag` query arguments
            select the user-facing mesh tag.
          step: If not None, keep only data whose step equals `step`.
            Otherwise return all (downsampled) data sorted by ascending step.

        Returns:
          A list of `(MeshPluginData, TensorDatum)` tuples.
        """
        ctx = plugin_util.context(request.environ)
        experiment = plugin_util.experiment_id(request.environ)
        run = request.args.get("run")
        tag = request.args.get("tag")

        # One listing call yields both the instance tags and their metadata.
        instance_metadata = self._instance_tag_metadatas(
            ctx, experiment, run, tag
        )
        if not instance_metadata:
            return []

        # Read every instance tag in a single call; downsampling applies per
        # time series, so this matches reading each instance tag separately.
        run_tensors = self._data_provider.read_tensors(
            ctx,
            experiment_id=experiment,
            plugin_name=metadata.PLUGIN_NAME,
            run_tag_filter=provider.RunTagFilter(
                runs=[run], tags=list(instance_metadata)
            ),
            downsample=self._downsample_to,
        ).get(run, {})

        tensor_events = [
            (meta, tensor)
            for instance_tag, meta in instance_metadata.items()
            for tensor in run_tensors.get(instance_tag, [])
        ]

        if step is not None:
            return [event for event in tensor_events if event[1].step == step]
        # Make sure tensors are sorted by step in ascending order (stable, so
        # instance-tag order is preserved within a step).
        return sorted(tensor_events, key=lambda event: event[1].step)

    @wrappers.Request.application
    def _serve_mesh_data(self, request):
        """A route that returns data for particular summary of specified type.

        Data can represent vertices coordinates, vertices indices in faces,
        vertices colors and so on. Each mesh may have different combination of
        abovementioned data and each type/part of mesh summary must be served as
        separate roundtrip to the server.

        Args:
          request: werkzeug.Request containing content_type as a name of enum
            plugin_data_pb2.MeshPluginData.ContentType, plus optional numeric
            `step` (default 0) and integer `sample` (default 0) arguments.

        Returns:
          werkzeug.Response with float32, int32 or uint8 data in binary
          format, or a 400 text response when a query argument is invalid.
        """
        # Validate all cheap query arguments before reading any tensor data.
        try:
            content_type = plugin_data_pb2.MeshPluginData.ContentType.Value(
                request.args.get("content_type")
            )
        except ValueError:
            return http_util.Respond(
                request, "Bad content_type", "text/plain", 400
            )
        np_type = {
            plugin_data_pb2.MeshPluginData.VERTEX: np.float32,
            plugin_data_pb2.MeshPluginData.FACE: np.int32,
            plugin_data_pb2.MeshPluginData.COLOR: np.uint8,
        }.get(content_type)
        if np_type is None:
            # A valid enum name, but no binary encoding is defined for it.
            return http_util.Respond(
                request, "Bad content_type", "text/plain", 400
            )
        try:
            step = float(request.args.get("step", 0.0))
        except ValueError:
            return http_util.Respond(request, "Bad step", "text/plain", 400)
        try:
            sample = int(request.args.get("sample", 0))
        except ValueError:
            return http_util.Respond(request, "Bad sample", "text/plain", 400)

        tensor_events = self._collect_tensor_events(request, step)
        try:
            response = [
                self._get_tensor_data(tensor, sample)
                for meta, tensor in tensor_events
                if meta.content_type == content_type
            ]
        except IndexError:
            # `sample` exceeds the batch size of the stored tensor.
            return http_util.Respond(request, "Bad sample", "text/plain", 400)

        # Serialize as one flat buffer of raw `np_type` values (C order,
        # native byte order). NOTE(review): the client presumably reshapes it
        # using `data_shape` from the /meshes route — confirm.
        payload = np.array(response, dtype=np_type).tobytes()
        return http_util.Respond(request, payload, "arraybuffer")

    @wrappers.Request.application
    def _serve_mesh_metadata(self, request):
        """A route that returns the mesh metadata associated with a tag.

        Metadata consists of wall time, type of elements in tensor, scene
        configuration and so on.

        Args:
          request: The werkzeug.Request object.

        Returns:
          A JSON list of mesh data associated with the run and tag
          combination, ordered by ascending step.
        """
        tensor_events = self._collect_tensor_events(request)

        # Build one JSON-compatible metadata dict per (instance tag, step).
        response = [
            self._get_tensor_metadata(
                tensor,
                meta.content_type,
                meta.components,
                meta.shape,
                meta.json_config,
            )
            for meta, tensor in tensor_events
        ]
        return http_util.Respond(request, response, "application/json")