File size: 5,204 Bytes
cf2a15a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Internal information about the mesh plugin."""

import dataclasses

from typing import Any

from tensorboard.compat.proto import summary_pb2
from tensorboard.plugins.mesh import plugin_data_pb2

# Plugin name written to the `plugin_name` field of the
# `SummaryMetadata.PluginData` created by this module.
PLUGIN_NAME = "mesh"

# The most recent value for the `version` field of the
# `MeshPluginData` proto.
_PROTO_VERSION = 0


@dataclasses.dataclass(frozen=True)
class MeshTensor:
    """One tensor component of a mesh, tagged with its content and data types.

    Instances are immutable because the dataclass is frozen.

    Attributes:
      data: Tensor of shape `[dim_1, ..., dim_n, 3]` representing the mesh data
        of one of the following:
          - 3D coordinates of vertices
          - Indices of vertices within each triangle
          - Colors for each vertex
      content_type: `MeshPluginData.ContentType` value describing which kind
        of mesh data `data` holds.
      data_type: Data type of the elements in the tensor.
    """

    # Expected to be a `tf.Tensor`; annotated as `Any` to avoid a heavy TF dep.
    data: Any
    content_type: plugin_data_pb2.MeshPluginData.ContentType
    # Expected to be a `tf.DType`; annotated as `Any` to avoid a heavy TF dep.
    data_type: Any


def get_components_bitmask(content_types):
    """Builds a bitmask with one bit set per component of the summary.

    Each content type sets the bit whose position equals its enum value.

    Args:
      content_types: Iterable of plugin_data_pb2.MeshPluginData.ContentType
        values, one for every component related to the summary.

    Returns:
      An integer bitmask; `0` when `content_types` is empty.

    Raises:
      ValueError: If `UNDEFINED` is among the given content types.
    """
    bitmask = 0
    for component in content_types:
        is_undefined = component == plugin_data_pb2.MeshPluginData.UNDEFINED
        if is_undefined:
            raise ValueError("Cannot include UNDEFINED content type in mask.")
        bitmask |= 1 << component
    return bitmask


def get_current_version():
    """Returns the current version of the `MeshPluginData` proto.

    Returns:
      The value of the module-level `_PROTO_VERSION` constant.
    """
    return _PROTO_VERSION


def get_instance_name(name, content_type):
    """Builds the instance name for one component of a mesh summary.

    Args:
      name: Base name of the summary.
      content_type: `MeshPluginData.ContentType` enum value of the component.

    Returns:
      A string of the form `<name>_<CONTENT_TYPE>`, where the suffix is the
      enum name of `content_type`.
    """
    type_label = plugin_data_pb2.MeshPluginData.ContentType.Name(content_type)
    return f"{name}_{type_label}"


def create_summary_metadata(
    name,
    display_name,
    content_type,
    components,
    shape,
    description=None,
    json_config=None,
):
    """Builds the `SummaryMetadata` for one component of a mesh summary.

    The mesh-specific fields are packed into a `MeshPluginData` proto, which
    is serialized and stored as the plugin data content.

    Arguments:
      name: Original merged (summaries of different types) summary name.
      display_name: The display name used in TensorBoard.
      content_type: Value from MeshPluginData.ContentType enum describing data.
      components: Bitmask representing present parts (vertices, colors, etc.)
        that belong to the summary.
      shape: List of dimension sizes of the tensor; must have exactly three
        entries.
      description: The description to show in TensorBoard.
      json_config: A string, JSON-serialized dictionary of ThreeJS classes
        configuration.

    Returns:
      A `summary_pb2.SummaryMetadata` protobuf object.

    Raises:
      ValueError: If `shape` does not have exactly three dimensions.
    """
    # The expected layout is BxNx3: B is the batch dimension and N the number
    # of points, each with x,y,z coordinates. Only the number of dimensions
    # is validated here.
    rank = len(shape)
    if rank != 3:
        message = "Tensor shape should be of shape BxNx3, but got %s." % str(
            shape
        )
        raise ValueError(message)

    serialized_plugin_data = plugin_data_pb2.MeshPluginData(
        version=get_current_version(),
        name=name,
        content_type=content_type,
        components=components,
        shape=shape,
        json_config=json_config,
    ).SerializeToString()
    plugin_data = summary_pb2.SummaryMetadata.PluginData(
        plugin_name=PLUGIN_NAME, content=serialized_plugin_data
    )
    return summary_pb2.SummaryMetadata(
        display_name=display_name,  # Will not be used in TensorBoard UI.
        summary_description=description,
        plugin_data=plugin_data,
    )


def parse_plugin_metadata(content):
    """Deserializes mesh plugin metadata into a `MeshPluginData` proto.

    Arguments:
      content: The `content` field of a `SummaryMetadata` proto
        corresponding to the mesh plugin. Must be a `bytes` instance.

    Returns:
      A `MeshPluginData` protobuf object. If its parsed `components` field is
      0, the field is set to the bitmask of VERTEX, FACE and COLOR.

    Raises:
      TypeError: If `content` is not a `bytes` instance.
    """
    if not isinstance(content, bytes):
        raise TypeError("Content type must be bytes.")

    mesh_data = plugin_data_pb2.MeshPluginData.FromString(content)
    if mesh_data.components != 0:
        return mesh_data

    # A zero bitmask is treated as data from an older version of the proto
    # that lacks the components field, so fill it in with all three
    # component types.
    mesh_proto = plugin_data_pb2.MeshPluginData
    mesh_data.components = get_components_bitmask(
        [mesh_proto.VERTEX, mesh_proto.FACE, mesh_proto.COLOR]
    )
    return mesh_data