# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Precision--recall curves and TensorFlow operations to create them.

NOTE: This module is in beta, and its API is subject to change, but the
data that it stores to disk will be supported forever.
"""


import numpy as np

from tensorboard.plugins.pr_curve import metadata


# The minimum denominator used when dividing counts, to prevent division by 0.
# A floor of 1.0 does not work: fractional weights can produce counts below 1.
_MINIMUM_COUNT = 1e-7
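

# Illustrative sketch only (hypothetical helper, not part of the TensorBoard
# API): why the floor above must be tiny rather than 1.0. With fractional
# weights, tp + fp can legitimately fall below 1; a floor of 1.0 would then
# shrink the ratio, whereas a tiny floor only changes the result when the
# denominator is (nearly) zero, e.g. tp + fp == 0 yields 0.0 instead of a
# division by zero. For example, _sketch_safe_ratio(0.25, 0.25) is 1.0, but
# with minimum_count=1.0 the same call would return 0.25.
def _sketch_safe_ratio(numerator, denominator, minimum_count=1e-7):
    """Returns numerator / max(minimum_count, denominator)."""
    return numerator / max(minimum_count, denominator)
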

# The default number of thresholds.
_DEFAULT_NUM_THRESHOLDS = 201


def op(
    name,
    labels,
    predictions,
    num_thresholds=None,
    weights=None,
    display_name=None,
    description=None,
    collections=None,
):
    """Create a PR curve summary op for a single binary classifier.

    Computes true/false positive/negative counts for the given `predictions`
    against the ground truth `labels`, at each of `num_thresholds` threshold
    values evenly distributed in `[0, 1]`.

    Each value in `predictions`, a float in `[0, 1]`, is compared with its
    corresponding boolean label in `labels` and contributes a single
    tp/fp/tn/fn count at each threshold. That count is multiplied by the
    corresponding entry of `weights`, which can reweight individual values
    or, more commonly, mask them out.

    Args:
      name: A tag attached to the summary. Used by TensorBoard for organization.
      labels: The ground truth values. A Tensor of `bool` values with arbitrary
          shape.
      predictions: A float32 `Tensor` whose values are in the range `[0, 1]`.
          Dimensions must match those of `labels`.
      num_thresholds: Number of thresholds, evenly distributed in `[0, 1]`, to
          compute PR metrics for. Should be `>= 2`. This value should be a
          constant integer value, not a Tensor that stores an integer.
      weights: Optional float32 `Tensor`. Individual counts are multiplied by this
          value. This tensor must be either the same shape as or broadcastable to
          the `labels` tensor.
      display_name: Optional name for this summary in TensorBoard, as a
          constant `str`. Defaults to `name`.
      description: Optional long-form description for this summary, as a
          constant `str`. Markdown is supported. Defaults to empty.
      collections: Optional list of graph collection keys. The new
          summary op is added to these collections. Defaults to
          `[GraphKeys.SUMMARIES]`.

    Returns:
      A summary operation for use in a TensorFlow graph. The float32 tensor
      produced by the summary operation is of dimension (6, num_thresholds). The
      first dimension (of length 6) is of the order: true positives,
      false positives, true negatives, false negatives, precision, recall.
    """
    # TODO(nickfelt): remove on-demand imports once dep situation is fixed.
    import tensorflow.compat.v1 as tf

    if num_thresholds is None:
        num_thresholds = _DEFAULT_NUM_THRESHOLDS

    if weights is None:
        weights = 1.0

    dtype = predictions.dtype

    with tf.name_scope(name, values=[labels, predictions, weights]):
        tf.assert_type(labels, tf.bool)
        # We cast to float to ensure we have 0.0 or 1.0.
        f_labels = tf.cast(labels, dtype)
        # Ensure predictions are all in range [0.0, 1.0].
        predictions = tf.minimum(1.0, tf.maximum(0.0, predictions))
        # Get weighted true/false labels.
        true_labels = f_labels * weights
        false_labels = (1.0 - f_labels) * weights

        # Before we begin, flatten predictions.
        predictions = tf.reshape(predictions, [-1])

        # Shape the labels so they are broadcast-able for later multiplication.
        true_labels = tf.reshape(true_labels, [-1, 1])
        false_labels = tf.reshape(false_labels, [-1, 1])

        # To compute TP/FP/TN/FN, we are measuring a binary classifier
        #   C(t) = (predictions >= t)
        # at each threshold 't'. So we have
        #   TP(t) = sum( C(t) * true_labels )
        #   FP(t) = sum( C(t) * false_labels )
        #
        # But, computing C(t) requires computation for each t. To make it fast,
        # observe that C(t) is a cumulative integral, and so if we have
        #   thresholds = [t_0, ..., t_{n-1}];  t_0 < ... < t_{n-1}
        # where n = num_thresholds, and if we can compute the bucket function
        #   B(i) = Sum( (predictions == t), t_i <= t < t_{i+1} )
        # then we get
        #   C(t_i) = sum( B(j), j >= i )
        # which is the reversed cumulative sum in tf.cumsum().
        #
        # We can compute B(i) efficiently by taking advantage of the fact that
        # our thresholds are evenly distributed, in that
        #   width = 1.0 / (num_thresholds - 1)
        #   thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0]
        # Given a prediction value p, we can map it to its bucket by
        #   bucket_index(p) = floor( p * (num_thresholds - 1) )
        # so all buckets can be filled in a single pass; below, this is done by
        # summing label-weighted tf.one_hot() rows over all predictions.

        # Compute the bucket indices for each prediction value.
        bucket_indices = tf.cast(
            tf.floor(predictions * (num_thresholds - 1)), tf.int32
        )

        # Bucket predictions.
        tp_buckets = tf.reduce_sum(
            input_tensor=tf.one_hot(bucket_indices, depth=num_thresholds)
            * true_labels,
            axis=0,
        )
        fp_buckets = tf.reduce_sum(
            input_tensor=tf.one_hot(bucket_indices, depth=num_thresholds)
            * false_labels,
            axis=0,
        )

        # Set up the cumulative sums to compute the actual metrics.
        tp = tf.cumsum(tp_buckets, reverse=True, name="tp")
        fp = tf.cumsum(fp_buckets, reverse=True, name="fp")
        # fn = sum(true_labels) - tp
        #    = sum(tp_buckets) - tp
        #    = tp[0] - tp
        # Similarly,
        # tn = fp[0] - fp
        tn = fp[0] - fp
        fn = tp[0] - tp

        precision = tp / tf.maximum(_MINIMUM_COUNT, tp + fp)
        recall = tp / tf.maximum(_MINIMUM_COUNT, tp + fn)

        return _create_tensor_summary(
            name,
            tp,
            fp,
            tn,
            fn,
            precision,
            recall,
            num_thresholds,
            display_name,
            description,
            collections,
        )
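

# Illustrative sketch only (hypothetical helper, not part of the TensorBoard
# API): a pure-Python version of the bucketing trick that `op` describes.
# Each prediction p is clamped to [0, 1] and assigned to bucket
# floor(p * (num_thresholds - 1)); a reversed cumulative sum over the buckets
# then gives TP(t_i) and FP(t_i), the weighted counts of positive and negative
# labels with p >= t_i, without comparing every prediction to every threshold.
# For example, labels [True, False, True, False] with predictions
# [0.9, 0.6, 0.3, 0.1] and num_thresholds=5 give tp == [2, 2, 1, 1, 0] and
# fp == [2, 1, 1, 0, 0].
def _sketch_bucketed_counts(labels, predictions, num_thresholds, weights=None):
    """Returns (tp, fp) as lists of length `num_thresholds`."""
    import math

    if weights is None:
        weights = [1.0] * len(labels)
    tp_buckets = [0.0] * num_thresholds
    fp_buckets = [0.0] * num_thresholds
    for label, p, w in zip(labels, predictions, weights):
        p = min(1.0, max(0.0, p))  # Clamp to [0, 1], as `op` does.
        index = int(math.floor(p * (num_thresholds - 1)))
        if label:
            tp_buckets[index] += w
        else:
            fp_buckets[index] += w
    # Reversed cumulative sum: C(t_i) = sum(B(j) for j >= i).
    tp = [0.0] * num_thresholds
    fp = [0.0] * num_thresholds
    running_tp = running_fp = 0.0
    for i in reversed(range(num_thresholds)):
        running_tp += tp_buckets[i]
        running_fp += fp_buckets[i]
        tp[i] = running_tp
        fp[i] = running_fp
    return tp, fp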


def pb(
    name,
    labels,
    predictions,
    num_thresholds=None,
    weights=None,
    display_name=None,
    description=None,
):
    """Create a PR curves summary protobuf.

    Args:
      name: A name for the generated node. Will also serve as a series name in
          TensorBoard.
      labels: The ground truth values. A bool numpy array.
      predictions: A float32 numpy array whose values are in the range `[0, 1]`.
          Dimensions must match those of `labels`.
      num_thresholds: Optional number of thresholds, evenly distributed in
          `[0, 1]`, to compute PR metrics for. When provided, should be an int of
          value at least 2. Defaults to 201.
      weights: Optional float or float32 numpy array. Individual counts are
          multiplied by this value. This array must be either the same shape as
          or broadcastable to the `labels` numpy array.
      display_name: Optional name for this summary in TensorBoard, as a `str`.
          Defaults to `name`.
      description: Optional long-form description for this summary, as a `str`.
          Markdown is supported. Defaults to empty.

    Returns:
      A `tf.Summary` protobuf containing the PR curve data. See docs for the
      `op` method for details on the float32 tensor stored in this summary.
    """
    # TODO(nickfelt): remove on-demand imports once dep situation is fixed.
    import tensorflow.compat.v1 as tf  # noqa: F401

    if num_thresholds is None:
        num_thresholds = _DEFAULT_NUM_THRESHOLDS

    if weights is None:
        weights = 1.0

    # Compute bins of true positives and false positives.
    bucket_indices = np.int32(np.floor(predictions * (num_thresholds - 1)))
    float_labels = labels.astype(float)
    histogram_range = (0, num_thresholds - 1)
    tp_buckets, _ = np.histogram(
        bucket_indices,
        bins=num_thresholds,
        range=histogram_range,
        weights=float_labels * weights,
    )
    fp_buckets, _ = np.histogram(
        bucket_indices,
        bins=num_thresholds,
        range=histogram_range,
        weights=(1.0 - float_labels) * weights,
    )

    # Obtain the reverse cumulative sum.
    tp = np.cumsum(tp_buckets[::-1])[::-1]
    fp = np.cumsum(fp_buckets[::-1])[::-1]
    tn = fp[0] - fp
    fn = tp[0] - tp
    precision = tp / np.maximum(_MINIMUM_COUNT, tp + fp)
    recall = tp / np.maximum(_MINIMUM_COUNT, tp + fn)

    return raw_data_pb(
        name,
        true_positive_counts=tp,
        false_positive_counts=fp,
        true_negative_counts=tn,
        false_negative_counts=fn,
        precision=precision,
        recall=recall,
        num_thresholds=num_thresholds,
        display_name=display_name,
        description=description,
    )
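

# Illustrative sketch only (hypothetical helper, not part of the TensorBoard
# API) of the histogram step in `pb`. With bins=num_thresholds over
# range=(0, num_thresholds - 1), each bin is slightly narrower than 1 and the
# last bin includes its right edge, so every integer bucket index k lands in
# bin k. The result should therefore match np.bincount with `minlength`,
# which this sketch returns alongside it for comparison. The reversed
# cumulative sum then turns per-bucket weights into counts at each threshold.
def _sketch_histogram_buckets(bucket_indices, weights, num_thresholds):
    """Returns (histogram_counts, bincount_counts, reversed_cumsum)."""
    import numpy as np

    bucket_indices = np.asarray(bucket_indices, dtype=np.int64)
    weights = np.asarray(weights, dtype=np.float64)
    histogram_counts, _ = np.histogram(
        bucket_indices,
        bins=num_thresholds,
        range=(0, num_thresholds - 1),
        weights=weights,
    )
    bincount_counts = np.bincount(
        bucket_indices, weights=weights, minlength=num_thresholds
    )
    # Same idiom as `pb`: reverse, cumulative sum, reverse back.
    reversed_cumsum = np.cumsum(histogram_counts[::-1])[::-1]
    return histogram_counts, bincount_counts, reversed_cumsum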


def streaming_op(
    name,
    labels,
    predictions,
    num_thresholds=None,
    weights=None,
    metrics_collections=None,
    updates_collections=None,
    display_name=None,
    description=None,
):
    """Computes a precision-recall curve summary across batches of data.

    This function is similar to op() above, but can be used to compute the PR
    curve across multiple batches of labels and predictions, in the same style
    as the metrics found in tf.metrics.

    This function creates multiple local variables for storing true positives,
    true negatives, etc. accumulated over each batch of data, and uses these
    local variables to compute the final PR curve summary. These variables are
    updated by the returned `update_op`.

    Args:
      name: A tag attached to the summary. Used by TensorBoard for organization.
      labels: The ground truth values, a `Tensor` whose dimensions must match
        `predictions`. Will be cast to `bool`.
      predictions: A floating point `Tensor` of arbitrary shape and whose values
        are in the range `[0, 1]`.
      num_thresholds: The number of evenly spaced thresholds to generate for
        computing the PR curve. Defaults to 201.
      weights: Optional `Tensor` whose rank is either 0, or the same rank as
        `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
        be either `1`, or the same as the corresponding `labels` dimension).
      metrics_collections: An optional list of collections that `pr_curve`
        should be added to.
      updates_collections: An optional list of collections that `update_op`
        should be added to.
      display_name: Optional name for this summary in TensorBoard, as a
        constant `str`. Defaults to `name`.
      description: Optional long-form description for this summary, as a
        constant `str`. Markdown is supported. Defaults to empty.

    Returns:
      pr_curve: A string `Tensor` containing a single value: the
        serialized PR curve Tensor summary. The summary contains a
        float32 `Tensor` of dimension (6, num_thresholds). The first
        dimension (of length 6) is of the order: true positives, false
        positives, true negatives, false negatives, precision, recall.
      update_op: An operation that updates the summary with the latest data.
    """
    # TODO(nickfelt): remove on-demand imports once dep situation is fixed.
    import tensorflow.compat.v1 as tf

    if num_thresholds is None:
        num_thresholds = _DEFAULT_NUM_THRESHOLDS

    thresholds = [i / float(num_thresholds - 1) for i in range(num_thresholds)]

    with tf.name_scope(name, values=[labels, predictions, weights]):
        tp, update_tp = tf.metrics.true_positives_at_thresholds(
            labels=labels,
            predictions=predictions,
            thresholds=thresholds,
            weights=weights,
        )
        fp, update_fp = tf.metrics.false_positives_at_thresholds(
            labels=labels,
            predictions=predictions,
            thresholds=thresholds,
            weights=weights,
        )
        tn, update_tn = tf.metrics.true_negatives_at_thresholds(
            labels=labels,
            predictions=predictions,
            thresholds=thresholds,
            weights=weights,
        )
        fn, update_fn = tf.metrics.false_negatives_at_thresholds(
            labels=labels,
            predictions=predictions,
            thresholds=thresholds,
            weights=weights,
        )

        def compute_summary(tp, fp, tn, fn, collections):
            precision = tp / tf.maximum(_MINIMUM_COUNT, tp + fp)
            recall = tp / tf.maximum(_MINIMUM_COUNT, tp + fn)

            return _create_tensor_summary(
                name,
                tp,
                fp,
                tn,
                fn,
                precision,
                recall,
                num_thresholds,
                display_name,
                description,
                collections,
            )

        pr_curve = compute_summary(tp, fp, tn, fn, metrics_collections)
        update_op = tf.group(update_tp, update_fp, update_tn, update_fn)
        if updates_collections:
            for collection in updates_collections:
                tf.add_to_collection(collection, update_op)

        return pr_curve, update_op
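

# Illustrative sketch only (hypothetical class, not part of the TensorBoard
# API) of the idea behind `streaming_op`: per-threshold confusion counts are
# sums over examples, so accumulating them batch by batch (as the local
# variables behind `update_op` do) yields the same totals as a single pass
# over all of the data. This sketch treats an example as predicted positive
# when p >= t, following C(t) in `op`; it does not reproduce tf.metrics
# internals exactly.
class _SketchStreamingCounts(object):
    """Accumulates unweighted confusion counts at evenly spaced thresholds."""

    def __init__(self, num_thresholds):
        self.thresholds = [
            i / float(num_thresholds - 1) for i in range(num_thresholds)
        ]
        self.tp = [0.0] * num_thresholds
        self.fp = [0.0] * num_thresholds
        self.tn = [0.0] * num_thresholds
        self.fn = [0.0] * num_thresholds

    def update(self, labels, predictions):
        """Adds one batch's counts at every threshold."""
        for i, t in enumerate(self.thresholds):
            for label, p in zip(labels, predictions):
                predicted_positive = p >= t
                if label and predicted_positive:
                    self.tp[i] += 1.0
                elif label:
                    self.fn[i] += 1.0
                elif predicted_positive:
                    self.fp[i] += 1.0
                else:
                    self.tn[i] += 1.0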


def raw_data_op(
    name,
    true_positive_counts,
    false_positive_counts,
    true_negative_counts,
    false_negative_counts,
    precision,
    recall,
    num_thresholds=None,
    display_name=None,
    description=None,
    collections=None,
):
    """Create an op that collects data for visualizing PR curves.

    Unlike the op above, this one avoids computing precision, recall, and the
    intermediate counts. Instead, it accepts those tensors as arguments and
    relies on the caller to ensure that the calculations are correct (and the
    counts yield the provided precision and recall values).

    This op is useful when a caller seeks to compute precision and recall
    differently but still use the PR curves plugin.

    Args:
      name: A tag attached to the summary. Used by TensorBoard for organization.
      true_positive_counts: A rank-1 tensor of true positive counts. Must contain
          `num_thresholds` elements and be castable to float32. Values correspond
          to thresholds that increase from left to right (from 0 to 1).
      false_positive_counts: A rank-1 tensor of false positive counts. Must
          contain `num_thresholds` elements and be castable to float32. Values
          correspond to thresholds that increase from left to right (from 0 to 1).
      true_negative_counts: A rank-1 tensor of true negative counts. Must contain
          `num_thresholds` elements and be castable to float32. Values
          correspond to thresholds that increase from left to right (from 0 to 1).
      false_negative_counts: A rank-1 tensor of false negative counts. Must
          contain `num_thresholds` elements and be castable to float32. Values
          correspond to thresholds that increase from left to right (from 0 to 1).
      precision: A rank-1 tensor of precision values. Must contain
          `num_thresholds` elements and be castable to float32. Values correspond
          to thresholds that increase from left to right (from 0 to 1).
      recall: A rank-1 tensor of recall values. Must contain `num_thresholds`
          elements and be castable to float32. Values correspond to thresholds
          that increase from left to right (from 0 to 1).
      num_thresholds: Number of thresholds, evenly distributed in `[0, 1]`, to
          compute PR metrics for. Should be `>= 2`. This value should be a
          constant integer value, not a Tensor that stores an integer.
      display_name: Optional name for this summary in TensorBoard, as a
          constant `str`. Defaults to `name`.
      description: Optional long-form description for this summary, as a
          constant `str`. Markdown is supported. Defaults to empty.
      collections: Optional list of graph collection keys. The new
          summary op is added to these collections. Defaults to
          `[GraphKeys.SUMMARIES]`.

    Returns:
      A summary operation for use in a TensorFlow graph. See docs for the `op`
      method for details on the float32 tensor produced by this summary.
    """
    # TODO(nickfelt): remove on-demand imports once dep situation is fixed.
    import tensorflow.compat.v1 as tf

    with tf.name_scope(
        name,
        values=[
            true_positive_counts,
            false_positive_counts,
            true_negative_counts,
            false_negative_counts,
            precision,
            recall,
        ],
    ):
        return _create_tensor_summary(
            name,
            true_positive_counts,
            false_positive_counts,
            true_negative_counts,
            false_negative_counts,
            precision,
            recall,
            num_thresholds,
            display_name,
            description,
            collections,
        )
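

# Illustrative sketch only (hypothetical helper, not part of the TensorBoard
# API): `raw_data_op` and `raw_data_pb` trust the caller, so a caller might
# validate its inputs first. This checks that all six rows have
# `num_thresholds` entries and that precision and recall agree with the
# counts, using the same denominator floor as `_MINIMUM_COUNT`. The
# tolerance is an arbitrary choice for the sketch.
def _sketch_raw_data_is_consistent(
    true_positive_counts,
    false_positive_counts,
    true_negative_counts,
    false_negative_counts,
    precision,
    recall,
    num_thresholds,
    tolerance=1e-6,
    minimum_count=1e-7,
):
    """Returns True if the rows have the right length and agree."""
    rows = (
        true_positive_counts,
        false_positive_counts,
        true_negative_counts,
        false_negative_counts,
        precision,
        recall,
    )
    if any(len(row) != num_thresholds for row in rows):
        return False
    for i in range(num_thresholds):
        tp = true_positive_counts[i]
        fp = false_positive_counts[i]
        fn = false_negative_counts[i]
        expected_precision = tp / max(minimum_count, tp + fp)
        expected_recall = tp / max(minimum_count, tp + fn)
        if abs(precision[i] - expected_precision) > tolerance:
            return False
        if abs(recall[i] - expected_recall) > tolerance:
            return False
    return True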


def raw_data_pb(
    name,
    true_positive_counts,
    false_positive_counts,
    true_negative_counts,
    false_negative_counts,
    precision,
    recall,
    num_thresholds=None,
    display_name=None,
    description=None,
):
    """Create a PR curves summary protobuf from raw data values.

    Args:
      name: A tag attached to the summary. Used by TensorBoard for organization.
      true_positive_counts: A rank-1 numpy array of true positive counts. Must
          contain `num_thresholds` elements and be castable to float32.
      false_positive_counts: A rank-1 numpy array of false positive counts. Must
          contain `num_thresholds` elements and be castable to float32.
      true_negative_counts: A rank-1 numpy array of true negative counts. Must
          contain `num_thresholds` elements and be castable to float32.
      false_negative_counts: A rank-1 numpy array of false negative counts. Must
          contain `num_thresholds` elements and be castable to float32.
      precision: A rank-1 numpy array of precision values. Must contain
          `num_thresholds` elements and be castable to float32.
      recall: A rank-1 numpy array of recall values. Must contain `num_thresholds`
          elements and be castable to float32.
      num_thresholds: Number of thresholds, evenly distributed in `[0, 1]`, to
          compute PR metrics for. Should be an int `>= 2`.
      display_name: Optional name for this summary in TensorBoard, as a `str`.
          Defaults to `name`.
      description: Optional long-form description for this summary, as a `str`.
          Markdown is supported. Defaults to empty.

    Returns:
      A `tf.Summary` protobuf. See docs for the `op` method for details on the
      float32 tensor stored in this summary.
    """
    # TODO(nickfelt): remove on-demand imports once dep situation is fixed.
    import tensorflow.compat.v1 as tf

    summary_metadata = metadata.create_summary_metadata(
        display_name=display_name if display_name is not None else name,
        description=description or "",
        num_thresholds=num_thresholds,
    )
    tf_summary_metadata = tf.SummaryMetadata.FromString(
        summary_metadata.SerializeToString()
    )
    summary = tf.Summary()
    data = np.stack(
        (
            true_positive_counts,
            false_positive_counts,
            true_negative_counts,
            false_negative_counts,
            precision,
            recall,
        )
    )
    tensor = tf.make_tensor_proto(np.float32(data), dtype=tf.float32)
    summary.value.add(
        tag="%s/pr_curves" % name, metadata=tf_summary_metadata, tensor=tensor
    )
    return summary
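

# Illustrative sketch only (hypothetical helper, not part of the TensorBoard
# API): a caller of `raw_data_pb` that has only true and false positive
# counts per threshold can fill in the other rows the way `pb` does,
# assuming those counts use C(t) = (p >= t) with p in [0, 1]. At t_0 = 0
# every example is then classified positive, so tp[0] is the total weight of
# positive labels and fp[0] the total weight of negative labels; hence
# fn = tp[0] - tp and tn = fp[0] - fp.
def _sketch_remaining_rows(tp, fp, minimum_count=1e-7):
    """Returns (tn, fn, precision, recall) lists derived from tp and fp."""
    tn = [fp[0] - x for x in fp]
    fn = [tp[0] - x for x in tp]
    precision = [t / max(minimum_count, t + f) for t, f in zip(tp, fp)]
    recall = [t / max(minimum_count, t + n) for t, n in zip(tp, fn)]
    return tn, fn, precision, recall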


def _create_tensor_summary(
    name,
    true_positive_counts,
    false_positive_counts,
    true_negative_counts,
    false_negative_counts,
    precision,
    recall,
    num_thresholds=None,
    display_name=None,
    description=None,
    collections=None,
):
    """A private helper method for generating a tensor summary.

    We use a helper method instead of having `op` directly call `raw_data_op`
    to prevent the scope of `raw_data_op` from being embedded within `op`.

    Arguments are the same as for raw_data_op.

    Returns:
      A tensor summary that collects data for PR curves.
    """
    # TODO(nickfelt): remove on-demand imports once dep situation is fixed.
    import tensorflow.compat.v1 as tf

    # Store the number of thresholds within the summary metadata because
    # that value is constant for all PR curve summaries with the same tag.
    summary_metadata = metadata.create_summary_metadata(
        display_name=display_name if display_name is not None else name,
        description=description or "",
        num_thresholds=num_thresholds,
    )

    # Store values within a tensor. We store them in the order:
    # true positives, false positives, true negatives, false
    # negatives, precision, and recall.
    combined_data = tf.stack(
        [
            tf.cast(true_positive_counts, tf.float32),
            tf.cast(false_positive_counts, tf.float32),
            tf.cast(true_negative_counts, tf.float32),
            tf.cast(false_negative_counts, tf.float32),
            tf.cast(precision, tf.float32),
            tf.cast(recall, tf.float32),
        ]
    )

    return tf.summary.tensor_summary(
        name="pr_curves",
        tensor=combined_data,
        collections=collections,
        summary_metadata=summary_metadata,
    )