# NOTE: page-scrape residue ("Spaces: Sleeping Sleeping"); not part of the module.
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The TensorBoard Scalars plugin. | |
See `http_api.md` in this directory for specifications of the routes for | |
this plugin. | |
""" | |
import csv
import io

import werkzeug.exceptions
from werkzeug import wrappers

from tensorboard import errors
from tensorboard import plugin_util
from tensorboard.backend import http_util
from tensorboard.data import provider
from tensorboard.plugins import base_plugin
from tensorboard.plugins.scalar import metadata
# Fallback `downsample` value handed to the data provider when the context's
# sampling hints carry no entry for this plugin.
_DEFAULT_DOWNSAMPLING = 1000  # scalars per time series
class OutputFormat:
    """Namespace listing the valid output formats for API calls.

    The members are plain strings rather than `enum.Enum` values.
    """

    JSON, CSV = "json", "csv"
class ScalarsPlugin(base_plugin.TBPlugin):
    """Scalars Plugin for TensorBoard.

    Serves scalar time series read from the context's data provider over
    three routes: `/tags`, `/scalars`, and `/scalars_multirun`.
    """

    plugin_name = metadata.PLUGIN_NAME

    def __init__(self, context):
        """Instantiates ScalarsPlugin via TensorBoard core.

        Args:
          context: A base_plugin.TBContext instance.
        """
        # `sampling_hints` may be None or empty; fall back to the module
        # default when it has no entry for this plugin.
        self._downsample_to = (context.sampling_hints or {}).get(
            self.plugin_name, _DEFAULT_DOWNSAMPLING
        )
        self._data_provider = context.data_provider
        self._version_checker = plugin_util._MetadataVersionChecker(
            data_kind="scalar",
            latest_known_version=0,
        )

    def get_plugin_apps(self):
        """Return a dict mapping each route path to the app serving it."""
        return {
            "/scalars": self.scalars_route,
            "/scalars_multirun": self.scalars_multirun_route,
            "/tags": self.tags_route,
        }

    def is_active(self):
        """Always return False; see the inline note for why."""
        return False  # `list_plugins` as called by TB core suffices

    def frontend_metadata(self):
        """Return the frontend metadata naming this plugin's dashboard element."""
        return base_plugin.FrontendMetadata(element_name="tf-scalar-dashboard")

    def index_impl(self, ctx, experiment=None):
        """Return {runName: {tagName: {displayName: ..., description: ...}}}.

        Every run listed by the data provider gets an entry, even if all of
        its tags are skipped. A tag is skipped when the version checker
        rejects the version in its plugin metadata. Descriptions are passed
        through `plugin_util.markdown_to_safe_html`.

        Args:
          ctx: Request context forwarded to the data provider.
          experiment: Experiment ID forwarded to the data provider.

        Returns:
          A nested dict keyed by run name, then tag name.
        """
        mapping = self._data_provider.list_scalars(
            ctx,
            experiment_id=experiment,
            plugin_name=metadata.PLUGIN_NAME,
        )
        result = {run: {} for run in mapping}
        for run, tag_to_content in mapping.items():
            for tag, metadatum in tag_to_content.items():
                md = metadata.parse_plugin_metadata(metadatum.plugin_content)
                if not self._version_checker.ok(md.version, run, tag):
                    continue
                description = plugin_util.markdown_to_safe_html(
                    metadatum.description
                )
                result[run][tag] = {
                    "displayName": metadatum.display_name,
                    "description": description,
                }
        return result

    def scalars_impl(self, ctx, tag, run, experiment, output_format):
        """Read one time series; result of the form `(body, mime_type)`.

        Args:
          ctx: Request context forwarded to the data provider.
          tag: Tag name of the time series.
          run: Run name of the time series.
          experiment: Experiment ID forwarded to the data provider.
          output_format: `OutputFormat.CSV` for CSV text; any other value
            (including None) yields the JSON form.

        Returns:
          A `(body, mime_type)` tuple. For CSV, `body` is a string with a
          header row followed by one row per point. Otherwise `body` is a
          list of `(wall_time, step, value)` tuples.

        Raises:
          errors.NotFoundError: If the provider has no data for the given
            run and tag.
        """
        all_scalars = self._data_provider.read_scalars(
            ctx,
            experiment_id=experiment,
            plugin_name=metadata.PLUGIN_NAME,
            downsample=self._downsample_to,
            run_tag_filter=provider.RunTagFilter(runs=[run], tags=[tag]),
        )
        scalars = all_scalars.get(run, {}).get(tag, None)
        if scalars is None:
            raise errors.NotFoundError(
                "No scalar data for run=%r, tag=%r" % (run, tag)
            )
        values = [(x.wall_time, x.step, x.value) for x in scalars]
        if output_format == OutputFormat.CSV:
            string_io = io.StringIO()
            writer = csv.writer(string_io)
            writer.writerow(["Wall time", "Step", "Value"])
            writer.writerows(values)
            return (string_io.getvalue(), "text/csv")
        return (values, "application/json")

    def scalars_multirun_impl(self, ctx, tag, runs, experiment):
        """Read one tag across several runs; result is `(body, mime_type)`.

        Args:
          ctx: Request context forwarded to the data provider.
          tag: Tag name to read in every run.
          runs: List of run names to read.
          experiment: Experiment ID forwarded to the data provider.

        Returns:
          A `(body, "application/json")` tuple where `body` maps each run
          name to a list of `(wall_time, step, value)` tuples. Runs for
          which the provider returned no data for `tag` are omitted.
        """
        all_scalars = self._data_provider.read_scalars(
            ctx,
            experiment_id=experiment,
            plugin_name=metadata.PLUGIN_NAME,
            downsample=self._downsample_to,
            run_tag_filter=provider.RunTagFilter(runs=runs, tags=[tag]),
        )
        # Guard the tag lookup so a run returned without this tag is left
        # out of the response instead of raising KeyError.
        body = {
            run: [(x.wall_time, x.step, x.value) for x in run_data[tag]]
            for (run, run_data) in all_scalars.items()
            if tag in run_data
        }
        return (body, "application/json")

    # The route handlers below are registered by `get_plugin_apps`, so each
    # is wrapped with `Request.application` to turn the request-taking
    # method into a WSGI callable.

    @wrappers.Request.application
    def tags_route(self, request):
        """Respond with the JSON-encoded result of `index_impl`."""
        ctx = plugin_util.context(request.environ)
        experiment = plugin_util.experiment_id(request.environ)
        index = self.index_impl(ctx, experiment=experiment)
        return http_util.Respond(request, index, "application/json")

    @wrappers.Request.application
    def scalars_route(self, request):
        """Given a tag and single run, return array of ScalarEvents.

        Reads the `tag`, `run`, and optional `format` query parameters.

        Raises:
          errors.InvalidArgumentError: If `tag` or `run` is missing.
        """
        tag = request.args.get("tag")
        run = request.args.get("run")
        if tag is None or run is None:
            raise errors.InvalidArgumentError(
                "Both run and tag must be specified: tag=%r, run=%r"
                % (tag, run)
            )

        ctx = plugin_util.context(request.environ)
        experiment = plugin_util.experiment_id(request.environ)
        output_format = request.args.get("format")
        (body, mime_type) = self.scalars_impl(
            ctx, tag, run, experiment, output_format
        )
        return http_util.Respond(request, body, mime_type)

    @wrappers.Request.application
    def scalars_multirun_route(self, request):
        """Given a tag and list of runs, return dict of ScalarEvent arrays.

        Reads the `tag` and `runs` form fields of a POST request.

        Raises:
          werkzeug.exceptions.MethodNotAllowed: If the method is not POST.
          errors.InvalidArgumentError: If `tag` is not given exactly once.
        """
        if request.method != "POST":
            raise werkzeug.exceptions.MethodNotAllowed(["POST"])
        tags = request.form.getlist("tag")
        runs = request.form.getlist("runs")
        if len(tags) != 1:
            raise errors.InvalidArgumentError(
                "tag must be specified exactly once"
            )
        tag = tags[0]

        ctx = plugin_util.context(request.environ)
        experiment = plugin_util.experiment_id(request.environ)
        (body, mime_type) = self.scalars_multirun_impl(
            ctx, tag, runs, experiment
        )
        return http_util.Respond(request, body, mime_type)