Spaces:
Sleeping
Sleeping
File size: 5,769 Bytes
cf2a15a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classes and functions for handling the DownloadData API call."""
import csv
import io
import math
from tensorboard.plugins.hparams import error
class OutputFormat:
    """Names of the output formats a DownloadData call may request.

    These are plain string constants rather than `enum.Enum` members, so a
    format string taken straight from a request compares equal to them.
    """

    CSV = "csv"
    JSON = "json"
    LATEX = "latex"
class Handler:
    """Handles a DownloadData request."""

    # Characters with special meaning in LaTeX text mode, mapped to their
    # escaped forms. str.translate applies every replacement in one pass, so
    # backslashes introduced by an escape are never re-escaped.
    _LATEX_ESCAPES = str.maketrans(
        {
            "\\": r"\textbackslash{}",
            "&": r"\&",
            "%": r"\%",
            "$": r"\$",
            "#": r"\#",
            "_": r"\_",
            "{": r"\{",
            "}": r"\}",
            "~": r"\textasciitilde{}",
            "^": r"\textasciicircum{}",
        }
    )

    def __init__(
        self,
        context,
        experiment,
        session_groups,
        response_format,
        columns_visibility,
    ):
        """Constructor.

        Args:
          context: A backend_context.Context instance.
          experiment: Experiment proto.
          session_groups: ListSessionGroupsResponse proto.
          response_format: A string in the OutputFormat enum.
          columns_visibility: A list of boolean values to filter columns.
            Entry i controls column i of the combined list of hparam columns
            followed by metric columns; columns past the end of this list
            are dropped.
        """
        self._context = context
        self._experiment = experiment
        self._session_groups = session_groups
        self._response_format = response_format
        self._columns_visibility = columns_visibility

    def run(self):
        """Handles the request specified on construction.

        Builds a header (hparam columns followed by metric columns) and one
        row per session group. It then drops the columns marked invisible and
        serializes the table in the requested format.

        Returns:
          A response body.
          A mime type (string) for the response.

        Raises:
          error.HParamsError: if the response format is not one of the
            OutputFormat values.
        """
        experiment = self._experiment
        visibility = self._columns_visibility

        def _filter_columns(row):
            # zip() stops at the shorter input, so columns without a
            # visibility entry are dropped too.
            return [value for value, visible in zip(row, visibility) if visible]

        header = [
            hparam_info.display_name or hparam_info.name
            for hparam_info in experiment.hparam_infos
        ]
        header.extend(
            metric_info.display_name or metric_info.name.tag
            for metric_info in experiment.metric_infos
        )
        header = _filter_columns(header)

        rows = [
            _filter_columns(self._session_group_row(experiment, group))
            for group in self._session_groups.session_groups
        ]

        response_format = self._response_format
        if response_format == OutputFormat.JSON:
            return dict(header=header, rows=rows), "application/json"
        if response_format == OutputFormat.LATEX:
            return self._to_latex(header, rows), "application/x-latex"
        if response_format == OutputFormat.CSV:
            return self._to_csv(header, rows), "text/csv"
        raise error.HParamsError(
            "Invalid response format: %s" % response_format
        )

    @staticmethod
    def _get_metric_id(metric):
        """Returns a "group.tag" key for a metric name message."""
        return metric.group + "." + metric.tag

    @staticmethod
    def _get_value(value):
        """Returns the number, string or bool held by `value`, else ""."""
        if value.HasField("number_value"):
            return value.number_value
        if value.HasField("string_value"):
            return value.string_value
        if value.HasField("bool_value"):
            return value.bool_value
        return ""

    @classmethod
    def _session_group_row(cls, experiment, group):
        """Returns the unfiltered row (hparams, then metrics) for `group`."""
        row = []
        for hparam_info in experiment.hparam_infos:
            # Hyperparameter values can be optional in a session group. Check
            # membership first: indexing a protobuf message map with a missing
            # key inserts a default entry, which would mutate the response.
            if hparam_info.name in group.hparams:
                row.append(cls._get_value(group.hparams[hparam_info.name]))
            else:
                row.append("")
        metric_values = {
            cls._get_metric_id(metric_value.name): metric_value.value
            for metric_value in group.metric_values
        }
        for metric_info in experiment.metric_infos:
            # Metrics the group did not report become None.
            row.append(metric_values.get(cls._get_metric_id(metric_info.name)))
        return row

    @classmethod
    def _latex_format(cls, value):
        """Formats a single header entry or table cell as LaTeX source."""
        if value is None:
            return "-"
        if isinstance(value, bool):
            # bool is a subclass of int. Handle it first so booleans do not
            # render as $1$/$0$.
            return str(value)
        if isinstance(value, int):
            return "$%d$" % value
        if isinstance(value, float):
            if math.isnan(value):
                return r"$\mathrm{NaN}$"
            if math.isinf(value):
                return r"$%s\infty$" % ("-" if value < 0 else "+")
            scientific = "%.3g" % value
            if "e" in scientific:
                coefficient, exponent = scientific.split("e")
                return "$%s\\cdot 10^{%d}$" % (coefficient, int(exponent))
            return "$%s$" % scientific
        return value.translate(cls._LATEX_ESCAPES)

    @classmethod
    def _to_latex(cls, header, rows):
        """Renders `header` and `rows` as a LaTeX table environment."""

        def format_row(row):
            return " & ".join(cls._latex_format(value) for value in row)

        top_part = "\\begin{table}[tbp]\n\\begin{tabular}{%s}\n" % (
            "l" * len(header)
        )
        header_part = format_row(header) + " \\\\ \\hline\n"
        middle_part = "".join(format_row(row) + " \\\\\n" for row in rows)
        bottom_part = "\\hline\n\\end{tabular}\n\\end{table}\n"
        return top_part + header_part + middle_part + bottom_part

    @staticmethod
    def _to_csv(header, rows):
        """Renders `header` and `rows` as CSV text."""
        string_io = io.StringIO()
        writer = csv.writer(string_io)
        writer.writerow(header)
        writer.writerows(rows)
        return string_io.getvalue()
|