# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The TensorBoard Graphs plugin."""
import json | |
from werkzeug import wrappers | |
from tensorboard import errors | |
from tensorboard import plugin_util | |
from tensorboard.backend import http_util | |
from tensorboard.backend import process_graph | |
from tensorboard.compat.proto import config_pb2 | |
from tensorboard.compat.proto import graph_pb2 | |
from tensorboard.data import provider | |
from tensorboard.plugins import base_plugin | |
from tensorboard.plugins.graph import graph_util | |
from tensorboard.plugins.graph import keras_util | |
from tensorboard.plugins.graph import metadata | |
from tensorboard.util import tb_logging | |
# Module-level logger from TensorBoard's logging utility; the plugin below
# uses it to warn about blob sequences with an unrecognized version marker.
logger = tb_logging.get_logger()
class GraphsPlugin(base_plugin.TBPlugin):
    """Graphs Plugin for TensorBoard.

    Serves run-wide and tagged op graphs, conceptual (Keras) graphs, and
    profile (`RunMetadata`) data read from the context's data provider.
    """

    plugin_name = metadata.PLUGIN_NAME

    def __init__(self, context):
        """Instantiates GraphsPlugin via TensorBoard core.

        Args:
          context: A base_plugin.TBContext instance.
        """
        self._data_provider = context.data_provider

    def get_plugin_apps(self):
        """Returns a mapping from route suffix to WSGI application."""
        return {
            "/graph": self.graph_route,
            "/info": self.info_route,
            "/run_metadata": self.run_metadata_route,
        }

    def is_active(self):
        """The graphs plugin is active iff any run has a graph or metadata."""
        return False  # `list_plugins` as called by TB core suffices

    def data_plugin_names(self):
        """Returns the data-provider plugin names this plugin reads."""
        return (
            metadata.PLUGIN_NAME,
            metadata.PLUGIN_NAME_RUN_METADATA,
            metadata.PLUGIN_NAME_RUN_METADATA_WITH_GRAPH,
            metadata.PLUGIN_NAME_KERAS_MODEL,
            metadata.PLUGIN_NAME_TAGGED_RUN_METADATA,
        )

    def frontend_metadata(self):
        """Returns the frontend configuration for the graph dashboard."""
        return base_plugin.FrontendMetadata(
            element_name="tf-graph-dashboard",
            # TODO(@chihuahua): Reconcile this setting with Health Pills.
            disable_reload=True,
        )

    def _list_versioned_tags(self, ctx, experiment, plugin_name, description):
        """Yields `(run, tag)` pairs whose plugin content is version `b"1"`.

        Sequences carrying any other plugin content are skipped and a
        warning is logged.

        Args:
          ctx: A request context.
          experiment: Experiment ID, or `None`.
          plugin_name: Data-provider plugin name whose blob sequences to list.
          description: Human-readable name of the data kind, used in the
            warning for unrecognized versions.

        Yields:
          `(run_name, tag)` tuples.
        """
        mapping = self._data_provider.list_blob_sequences(
            ctx,
            experiment_id=experiment,
            plugin_name=plugin_name,
        )
        for run_name, tags in mapping.items():
            for tag, tag_data in tags.items():
                # The Summary op is defined in TensorFlow and does not use a
                # stringified proto as its plugin content; the content is a
                # single string that denotes a version.
                # https://github.com/tensorflow/tensorflow/blob/11f4ecb54708865ec757ca64e4805957b05d7570/tensorflow/python/ops/summary_ops_v2.py#L789-L790
                if tag_data.plugin_content != b"1":
                    logger.warning(
                        "Ignoring unrecognizable version of %s.", description
                    )
                    continue
                yield (run_name, tag)

    def info_impl(self, ctx, experiment=None):
        """Returns a dict of all runs and their data availabilities.

        Args:
          ctx: A request context.
          experiment: Experiment ID, or `None`.

        Returns:
          A dict mapping run name to a dict with keys `"run"` (the name),
          `"tags"`, and `"run_graph"` (bool: a run-wide GraphDef exists).
          `"tags"` maps tag name to a dict with keys `"tag"` (the name) and
          the booleans `"conceptual_graph"`, `"op_graph"`, and `"profile"`.
        """
        result = {}

        def add_row_item(run, tag=None):
            """Returns `(run_item, tag_item)`, creating entries as needed.

            `tag_item` is `None` when `tag` is falsy.
            """
            run_item = result.setdefault(
                run,
                {
                    "run": run,
                    "tags": {},
                    # A run-wide GraphDef of ops.
                    "run_graph": False,
                },
            )
            tag_item = None
            if tag:
                tag_item = run_item["tags"].setdefault(
                    tag,
                    {
                        "tag": tag,
                        "conceptual_graph": False,
                        # A tagged GraphDef of ops.
                        "op_graph": False,
                        "profile": False,
                    },
                )
            return (run_item, tag_item)

        # Op graphs recorded with plugin name
        # metadata.PLUGIN_NAME_RUN_METADATA_WITH_GRAPH.
        for run_name, tag in self._list_versioned_tags(
            ctx,
            experiment,
            metadata.PLUGIN_NAME_RUN_METADATA_WITH_GRAPH,
            "RunMetadata",
        ):
            (_, tag_item) = add_row_item(run_name, tag)
            tag_item["op_graph"] = True

        # Tensors associated with plugin name metadata.PLUGIN_NAME_RUN_METADATA
        # contain both op graph and profile information.
        for run_name, tag in self._list_versioned_tags(
            ctx,
            experiment,
            metadata.PLUGIN_NAME_RUN_METADATA,
            "RunMetadata",
        ):
            (_, tag_item) = add_row_item(run_name, tag)
            tag_item["profile"] = True
            tag_item["op_graph"] = True

        # Tensors associated with plugin name metadata.PLUGIN_NAME_KERAS_MODEL
        # contain a serialized Keras model in JSON format.
        for run_name, tag in self._list_versioned_tags(
            ctx,
            experiment,
            metadata.PLUGIN_NAME_KERAS_MODEL,
            "Keras model",
        ):
            (_, tag_item) = add_row_item(run_name, tag)
            tag_item["conceptual_graph"] = True

        # A run-wide GraphDef is stored under the reserved tag
        # metadata.RUN_GRAPH_NAME.
        mapping = self._data_provider.list_blob_sequences(
            ctx,
            experiment_id=experiment,
            plugin_name=metadata.PLUGIN_NAME,
        )
        for run_name, tags in mapping.items():
            if metadata.RUN_GRAPH_NAME in tags:
                (run_item, _) = add_row_item(run_name, None)
                run_item["run_graph"] = True

        # Top level `Event.tagged_run_metadata` represents profile data only.
        mapping = self._data_provider.list_blob_sequences(
            ctx,
            experiment_id=experiment,
            plugin_name=metadata.PLUGIN_NAME_TAGGED_RUN_METADATA,
        )
        for run_name, tags in mapping.items():
            for tag in tags:
                (_, tag_item) = add_row_item(run_name, tag)
                tag_item["profile"] = True

        return result

    def _read_blob(self, ctx, experiment, plugin_names, run, tag):
        """Reads one blob for `run`/`tag` from the first plugin that has one.

        Args:
          ctx: A request context.
          experiment: Experiment ID, or `None`.
          plugin_names: Plugin names to try, in order.
          run: Run name.
          tag: Tag name.

        Returns:
          The raw blob contents returned by the data provider.

        Raises:
          errors.NotFoundError: If none of the plugins has data for
            `run`/`tag`.
        """
        for plugin_name in plugin_names:
            blob_sequences = self._data_provider.read_blob_sequences(
                ctx,
                experiment_id=experiment,
                plugin_name=plugin_name,
                run_tag_filter=provider.RunTagFilter(runs=[run], tags=[tag]),
                downsample=1,
            )
            blob_sequence_data = blob_sequences.get(run, {}).get(tag, ())
            try:
                blob_ref = blob_sequence_data[0].values[0]
            except IndexError:
                # No datum (or an empty datum) for this plugin; try the next.
                continue
            return self._data_provider.read_blob(
                ctx, blob_key=blob_ref.blob_key
            )
        raise errors.NotFoundError()

    def graph_impl(
        self,
        ctx,
        run,
        tag,
        is_conceptual,
        experiment=None,
        limit_attr_size=None,
        large_attrs_key=None,
    ):
        """Returns a graph as a `(body, mime_type)` pair of pbtxt text.

        Args:
          ctx: A request context.
          run: Run name.
          tag: Tag name, or `None` to read the run-wide GraphDef.
          is_conceptual: If true, build the graph from the Keras model
            config stored under `tag`.
          experiment: Experiment ID, or `None`.
          limit_attr_size: Optional limit forwarded to
            `process_graph.prepare_graph_for_ui`.
          large_attrs_key: Optional key forwarded to
            `process_graph.prepare_graph_for_ui`.

        Returns:
          A `(body, mime_type)` tuple; `body` is the text-format GraphDef.

        Raises:
          errors.NotFoundError: If no matching graph data exists.
          ValueError: If the limit parameters are invalid.
        """
        if is_conceptual:
            keras_model_config = json.loads(
                self._read_blob(
                    ctx,
                    experiment,
                    [metadata.PLUGIN_NAME_KERAS_MODEL],
                    run,
                    tag,
                )
            )
            graph = keras_util.keras_model_to_graph_def(keras_model_config)
        elif tag is None:
            graph_raw = self._read_blob(
                ctx,
                experiment,
                [metadata.PLUGIN_NAME],
                run,
                metadata.RUN_GRAPH_NAME,
            )
            graph = graph_pb2.GraphDef.FromString(graph_raw)
        else:
            # Op graph: could be either of two plugins. (Cf. `info_impl`.)
            plugins = [
                metadata.PLUGIN_NAME_RUN_METADATA,
                metadata.PLUGIN_NAME_RUN_METADATA_WITH_GRAPH,
            ]
            raw_run_metadata = self._read_blob(
                ctx, experiment, plugins, run, tag
            )
            run_metadata = config_pb2.RunMetadata.FromString(raw_run_metadata)
            graph = graph_util.merge_graph_defs(
                [
                    func_graph.pre_optimization_graph
                    for func_graph in run_metadata.function_graphs
                ]
            )

        # This next line might raise a ValueError if the limit parameters
        # are invalid (size is negative, size present but key absent, etc.).
        process_graph.prepare_graph_for_ui(
            graph, limit_attr_size, large_attrs_key
        )
        return (str(graph), "text/x-protobuf")  # pbtxt

    def run_metadata_impl(self, ctx, experiment, run, tag):
        """Returns profile `RunMetadata` as a `(body, mime_type)` pair.

        Raises:
          errors.NotFoundError: If no matching run metadata exists.
        """
        # Profile graph: could be either of two plugins. (Cf. `info_impl`.)
        plugins = [
            metadata.PLUGIN_NAME_TAGGED_RUN_METADATA,
            metadata.PLUGIN_NAME_RUN_METADATA,
        ]
        raw_run_metadata = self._read_blob(ctx, experiment, plugins, run, tag)
        run_metadata = config_pb2.RunMetadata.FromString(raw_run_metadata)
        return (str(run_metadata), "text/x-protobuf")  # pbtxt

    @wrappers.Request.application
    def info_route(self, request):
        """Serves the result of `info_impl` as JSON."""
        ctx = plugin_util.context(request.environ)
        experiment = plugin_util.experiment_id(request.environ)
        info = self.info_impl(ctx, experiment)
        return http_util.Respond(request, info, "application/json")

    @wrappers.Request.application
    def graph_route(self, request):
        """Given a single run, return the graph definition in protobuf format.

        Query parameters: `run` (required), `tag`, `conceptual` ("true" to
        request the Keras conceptual graph), `limit_attr_size` (integer),
        and `large_attrs_key`. Responds 400 on invalid parameters.
        """
        ctx = plugin_util.context(request.environ)
        experiment = plugin_util.experiment_id(request.environ)
        run = request.args.get("run")
        tag = request.args.get("tag")
        is_conceptual = request.args.get("conceptual") == "true"

        if run is None:
            return http_util.Respond(
                request, 'query parameter "run" is required', "text/plain", 400
            )

        limit_attr_size = request.args.get("limit_attr_size", None)
        if limit_attr_size is not None:
            try:
                limit_attr_size = int(limit_attr_size)
            except ValueError:
                return http_util.Respond(
                    request,
                    "query parameter `limit_attr_size` must be an integer",
                    "text/plain",
                    400,
                )

        large_attrs_key = request.args.get("large_attrs_key", None)

        try:
            result = self.graph_impl(
                ctx,
                run,
                tag,
                is_conceptual,
                experiment,
                limit_attr_size,
                large_attrs_key,
            )
        except ValueError as e:
            # Python 3 exceptions have no `.message`; use the string form.
            return http_util.Respond(request, str(e), "text/plain", code=400)
        (body, mime_type) = result
        return http_util.Respond(request, body, mime_type)

    @wrappers.Request.application
    def run_metadata_route(self, request):
        """Given a tag and a run, return the session.run() metadata."""
        ctx = plugin_util.context(request.environ)
        experiment = plugin_util.experiment_id(request.environ)
        tag = request.args.get("tag")
        run = request.args.get("run")
        if tag is None:
            return http_util.Respond(
                request, 'query parameter "tag" is required', "text/plain", 400
            )
        if run is None:
            return http_util.Respond(
                request, 'query parameter "run" is required', "text/plain", 400
            )
        (body, mime_type) = self.run_metadata_impl(ctx, experiment, run, tag)
        return http_util.Respond(request, body, mime_type)