Spaces:
Sleeping
Sleeping
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The TensorBoard Histograms plugin.

See `http_api.md` in this directory for specifications of the routes for
this plugin.
"""
from werkzeug import wrappers | |
from tensorboard import errors | |
from tensorboard import plugin_util | |
from tensorboard.backend import http_util | |
from tensorboard.data import provider | |
from tensorboard.plugins import base_plugin | |
from tensorboard.plugins.histogram import metadata | |
# Fallback downsampling limit, used by `HistogramsPlugin.__init__` when the
# context's sampling hints have no entry for this plugin.
_DEFAULT_DOWNSAMPLING = 500  # histograms per time series
class HistogramsPlugin(base_plugin.TBPlugin):
    """Histograms Plugin for TensorBoard.

    This supports both old-style summaries (created with TensorFlow ops
    that output directly to the `histo` field of the proto) and new-
    style summaries (as created by the
    `tensorboard.plugins.histogram.summary` module).
    """

    plugin_name = metadata.PLUGIN_NAME

    # Use a round number + 1 since sampling includes both start and end steps,
    # so N+1 samples corresponds to dividing the step sequence into N intervals.
    SAMPLE_SIZE = 51

    def __init__(self, context):
        """Instantiates HistogramsPlugin via TensorBoard core.

        Args:
          context: A base_plugin.TBContext instance.
        """
        # `sampling_hints` may be falsy (e.g. `None`); fall back to the
        # module-level default when no hint is given for this plugin.
        self._downsample_to = (context.sampling_hints or {}).get(
            self.plugin_name, _DEFAULT_DOWNSAMPLING
        )
        self._data_provider = context.data_provider
        self._version_checker = plugin_util._MetadataVersionChecker(
            data_kind="histogram",
            latest_known_version=0,
        )

    def get_plugin_apps(self):
        """Return the route handlers of this plugin, keyed by route path."""
        return {
            "/histograms": self.histograms_route,
            "/tags": self.tags_route,
        }

    def is_active(self):
        """Always return False; see the inline note for the rationale."""
        return False  # `list_plugins` as called by TB core suffices

    def index_impl(self, ctx, experiment):
        """Return {runName: {tagName: {displayName: ..., description: ...}}}.

        Tags whose plugin metadata version is rejected by the version
        checker are omitted. Every run reported by the data provider is
        present in the result, even if all of its tags were omitted.

        Args:
          ctx: Request context forwarded to the data provider.
          experiment: Experiment ID forwarded to the data provider.

        Returns:
          A dict mapping run names to dicts mapping tag names to dicts
          with keys `displayName` and `description` (HTML rendered from
          the tag's markdown description).
        """
        mapping = self._data_provider.list_tensors(
            ctx,
            experiment_id=experiment,
            plugin_name=metadata.PLUGIN_NAME,
        )
        result = {run: {} for run in mapping}
        for run, tag_to_content in mapping.items():
            for tag, metadatum in tag_to_content.items():
                md = metadata.parse_plugin_metadata(metadatum.plugin_content)
                # Check the version first so that no markdown rendering is
                # done for tags that are going to be skipped anyway.
                if not self._version_checker.ok(md.version, run, tag):
                    continue
                result[run][tag] = {
                    "displayName": metadatum.display_name,
                    "description": plugin_util.markdown_to_safe_html(
                        metadatum.description
                    ),
                }
        return result

    def frontend_metadata(self):
        """Return the frontend metadata naming this plugin's dashboard element."""
        return base_plugin.FrontendMetadata(
            element_name="tf-histogram-dashboard"
        )

    def histograms_impl(self, ctx, tag, run, experiment, downsample_to=None):
        """Result of the form `(body, mime_type)`.

        At most `downsample_to` events will be returned. If this value is
        `None`, then default downsampling will be performed.

        Args:
          ctx: Request context forwarded to the data provider.
          tag: Name of the histogram tag to read.
          run: Name of the run to read.
          experiment: Experiment ID forwarded to the data provider.
          downsample_to: Optional downsampling limit passed to the data
            provider; when `None`, the limit configured at construction
            time is used.

        Returns:
          A pair `(events, "application/json")`, where `events` is a list
          of `(wall_time, step, histogram_as_nested_list)` tuples.

        Raises:
          tensorboard.errors.NotFoundError: If the data provider returns
            no histograms for the given run and tag.
        """
        sample_count = (
            downsample_to if downsample_to is not None else self._downsample_to
        )
        all_histograms = self._data_provider.read_tensors(
            ctx,
            experiment_id=experiment,
            plugin_name=metadata.PLUGIN_NAME,
            downsample=sample_count,
            run_tag_filter=provider.RunTagFilter(runs=[run], tags=[tag]),
        )
        histograms = all_histograms.get(run, {}).get(tag, None)
        if histograms is None:
            raise errors.NotFoundError(
                "No histogram tag %r for run %r" % (tag, run)
            )
        events = [(e.wall_time, e.step, e.numpy.tolist()) for e in histograms]
        return (events, "application/json")

    # The handlers returned by `get_plugin_apps` are invoked as WSGI
    # applications; `Request.application` adapts the WSGI call into a call
    # with a single werkzeug `Request` and invokes the returned response.
    @wrappers.Request.application
    def tags_route(self, request):
        """Serve the run/tag index produced by `index_impl` as JSON."""
        ctx = plugin_util.context(request.environ)
        experiment = plugin_util.experiment_id(request.environ)
        index = self.index_impl(ctx, experiment=experiment)
        return http_util.Respond(request, index, "application/json")

    @wrappers.Request.application
    def histograms_route(self, request):
        """Given a tag and single run, return array of histogram values.

        Reads the `tag` and `run` query parameters and responds with at
        most `SAMPLE_SIZE` histogram events.

        Raises:
          tensorboard.errors.InvalidArgumentError: If the `tag` or `run`
            query parameter is missing.
          tensorboard.errors.NotFoundError: If there is no histogram data
            for the requested run and tag.
        """
        ctx = plugin_util.context(request.environ)
        experiment = plugin_util.experiment_id(request.environ)
        tag = request.args.get("tag")
        run = request.args.get("run")
        # Reject missing parameters explicitly instead of forwarding `None`
        # to the data provider as a run or tag name.
        if tag is None:
            raise errors.InvalidArgumentError(
                "Missing required query parameter: tag"
            )
        if run is None:
            raise errors.InvalidArgumentError(
                "Missing required query parameter: run"
            )
        (body, mime_type) = self.histograms_impl(
            ctx, tag, run, experiment=experiment, downsample_to=self.SAMPLE_SIZE
        )
        return http_util.Respond(request, body, mime_type)