# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
#
# Based on:
# --------------------------------------------------------
# ActivityNet
# Copyright (c) 2015 ActivityNet
# Licensed under The MIT License
# [see https://github.com/activitynet/ActivityNet/blob/master/LICENSE for details]
# --------------------------------------------------------

"""Helper functions for AVA evaluation."""

from __future__ import (
    absolute_import,
    division,
    print_function,
    unicode_literals,
)

import csv
import logging
import numpy as np
import pprint
import time
from collections import defaultdict
from fvcore.common.file_io import PathManager

import timesformer.utils.distributed as du
from timesformer.utils.ava_evaluation import (
    object_detection_evaluation,
    standard_fields,
)

logger = logging.getLogger(__name__)


def make_image_key(video_id, timestamp):
    """Returns a unique identifier for a video id & timestamp."""
    return "%s,%04d" % (video_id, int(timestamp))


def read_csv(csv_file, class_whitelist=None, load_score=False):
    """Loads boxes and class labels from a CSV file in the AVA format.

    CSV file format described at https://research.google.com/ava/download.html.

    Args:
        csv_file: A file object.
        class_whitelist: If provided, boxes corresponding to (integer) class
            labels not in this set are skipped.
        load_score: If True, read a confidence score from the 8th column of
            each row; otherwise every box gets a default score of 1.0.

    Returns:
        boxes: A dictionary mapping each unique image key (string) to a list of
            boxes, given as coordinates [y1, x1, y2, x2].
        labels: A dictionary mapping each unique image key (string) to a list
            of integer class labels, matching the corresponding box in `boxes`.
        scores: A dictionary mapping each unique image key (string) to a list
            of score values, matching the corresponding label in `labels`. If
            scores are not provided in the csv, then they will default to 1.0.
    """
    boxes = defaultdict(list)
    labels = defaultdict(list)
    scores = defaultdict(list)
    with PathManager.open(csv_file, "r") as f:
        reader = csv.reader(f)
        for row in reader:
            assert len(row) in [7, 8], "Wrong number of columns: " + str(row)
            image_key = make_image_key(row[0], row[1])
            x1, y1, x2, y2 = [float(n) for n in row[2:6]]
            action_id = int(row[6])
            if class_whitelist and action_id not in class_whitelist:
                continue
            score = 1.0
            if load_score:
                score = float(row[7])
            boxes[image_key].append([y1, x1, y2, x2])
            labels[image_key].append(action_id)
            scores[image_key].append(score)
    return boxes, labels, scores


def read_exclusions(exclusions_file):
    """Reads a CSV file of excluded timestamps.

    Args:
        exclusions_file: A file object containing a csv of video-id,timestamp.

    Returns:
        A set of strings containing excluded image keys, e.g.
        "aaaaaaaaaaa,0904", or an empty set if exclusions file is None.
    """
    excluded = set()
    if exclusions_file:
        with PathManager.open(exclusions_file, "r") as f:
            reader = csv.reader(f)
            for row in reader:
                assert len(row) == 2, "Expected only 2 columns, got: " + str(row)
                excluded.add(make_image_key(row[0], row[1]))
    return excluded


def read_labelmap(labelmap_file):
    """Read label map and class ids."""
    labelmap = []
    class_ids = set()
    name = ""
    class_id = ""
    with PathManager.open(labelmap_file, "r") as f:
        for line in f:
            if line.startswith("  name:"):
                name = line.split('"')[1]
            elif line.startswith("  id:") or line.startswith("  label_id:"):
                class_id = int(line.strip().split(" ")[-1])
                labelmap.append({"id": class_id, "name": name})
                class_ids.add(class_id)
    return labelmap, class_ids
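

# Illustrative note (not from the original file): `read_labelmap` is a minimal
# parser for the AVA label map in protobuf text format, where each entry looks
# roughly like the snippet below (two-space field indentation, which is what
# the `startswith` checks above rely on); files that use "label_id:" instead of
# "id:" are handled by the second branch:
#
#     item {
#       name: "bend/bow (at the waist)"
#       id: 1
#     }
#
# For such an entry it would append {"id": 1, "name": "bend/bow (at the waist)"}
# to `labelmap` and add 1 to `class_ids`.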


def evaluate_ava_from_files(labelmap, groundtruth, detections, exclusions):
    """Run AVA evaluation given annotation/prediction files."""
    categories, class_whitelist = read_labelmap(labelmap)
    excluded_keys = read_exclusions(exclusions)
    groundtruth = read_csv(groundtruth, class_whitelist, load_score=False)
    detections = read_csv(detections, class_whitelist, load_score=True)
    run_evaluation(categories, groundtruth, detections, excluded_keys)


def evaluate_ava(
    preds,
    original_boxes,
    metadata,
    excluded_keys,
    class_whitelist,
    categories,
    groundtruth=None,
    video_idx_to_name=None,
    name="latest",
):
    """Run AVA evaluation given numpy arrays."""
    eval_start = time.time()

    detections = get_ava_eval_data(
        preds,
        original_boxes,
        metadata,
        class_whitelist,
        video_idx_to_name=video_idx_to_name,
    )

    logger.info("Evaluating with %d unique GT frames." % len(groundtruth[0]))
    logger.info(
        "Evaluating with %d unique detection frames" % len(detections[0])
    )

    write_results(detections, "detections_%s.csv" % name)
    write_results(groundtruth, "groundtruth_%s.csv" % name)

    results = run_evaluation(categories, groundtruth, detections, excluded_keys)

    logger.info("AVA eval done in %f seconds." % (time.time() - eval_start))
    return results["PascalBoxes_Precision/[email protected]"]
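

# Illustrative note (not from the original file): judging from how
# `get_ava_eval_data` indexes its inputs, `evaluate_ava` expects
#     preds          -> array of shape (N, num_classes), per-box class scores
#     original_boxes -> array of shape (N, 5), rows [batch_idx, x1, y1, x2, y2]
#     metadata       -> array of shape (N, 2), rows [video_idx, sec]
# with `video_idx_to_name` mapping video_idx back to the AVA video id string.
# These shapes are inferred from this file, not stated by the original authors.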


def run_evaluation(
    categories, groundtruth, detections, excluded_keys, verbose=True
):
    """AVA evaluation main logic."""
    pascal_evaluator = object_detection_evaluation.PascalDetectionEvaluator(
        categories
    )

    boxes, labels, _ = groundtruth

    gt_keys = []
    pred_keys = []

    for image_key in boxes:
        if image_key in excluded_keys:
            logging.info(
                (
                    "Found excluded timestamp in ground truth: %s. "
                    "It will be ignored."
                ),
                image_key,
            )
            continue
        pascal_evaluator.add_single_ground_truth_image_info(
            image_key,
            {
                standard_fields.InputDataFields.groundtruth_boxes: np.array(
                    boxes[image_key], dtype=float
                ),
                standard_fields.InputDataFields.groundtruth_classes: np.array(
                    labels[image_key], dtype=int
                ),
                standard_fields.InputDataFields.groundtruth_difficult: np.zeros(
                    len(boxes[image_key]), dtype=bool
                ),
            },
        )
        gt_keys.append(image_key)

    boxes, labels, scores = detections

    for image_key in boxes:
        if image_key in excluded_keys:
            logging.info(
                (
                    "Found excluded timestamp in detections: %s. "
                    "It will be ignored."
                ),
                image_key,
            )
            continue
        pascal_evaluator.add_single_detected_image_info(
            image_key,
            {
                standard_fields.DetectionResultFields.detection_boxes: np.array(
                    boxes[image_key], dtype=float
                ),
                standard_fields.DetectionResultFields.detection_classes: np.array(
                    labels[image_key], dtype=int
                ),
                standard_fields.DetectionResultFields.detection_scores: np.array(
                    scores[image_key], dtype=float
                ),
            },
        )
        pred_keys.append(image_key)

    metrics = pascal_evaluator.evaluate()

    if du.is_master_proc():
        pprint.pprint(metrics, indent=2)
    return metrics


def get_ava_eval_data(
    scores,
    boxes,
    metadata,
    class_whitelist,
    verbose=False,
    video_idx_to_name=None,
):
    """
    Convert our data format into the data format used in official AVA
    evaluation.
    """
    out_scores = defaultdict(list)
    out_labels = defaultdict(list)
    out_boxes = defaultdict(list)
    count = 0
    for i in range(scores.shape[0]):
        video_idx = int(np.round(metadata[i][0]))
        sec = int(np.round(metadata[i][1]))

        video = video_idx_to_name[video_idx]

        key = video + "," + "%04d" % (sec)
        batch_box = boxes[i].tolist()
        # The first element is the batch index; keep it and swap the box from
        # [x1, y1, x2, y2] to the [y1, x1, y2, x2] order used by the evaluator.
        batch_box = [batch_box[j] for j in [0, 2, 1, 4, 3]]

        one_scores = scores[i].tolist()
        for cls_idx, score in enumerate(one_scores):
            if cls_idx + 1 in class_whitelist:
                out_scores[key].append(score)
                out_labels[key].append(cls_idx + 1)
                out_boxes[key].append(batch_box[1:])
                count += 1

    return out_boxes, out_labels, out_scores
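

# Illustrative note (not from the original file): for a single hypothetical row
# with boxes[i] = [0, 0.1, 0.2, 0.3, 0.4], i.e. (batch_idx, x1, y1, x2, y2),
# the permutation above yields [0, 0.2, 0.1, 0.4, 0.3], and batch_box[1:]
# stores [0.2, 0.1, 0.4, 0.3], i.e. [y1, x1, y2, x2], matching the box order
# used by `run_evaluation` and `read_csv` for ground truth.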


def write_results(detections, filename):
    """Write prediction results into official formats."""
    start = time.time()
    boxes, labels, scores = detections
    with PathManager.open(filename, "w") as f:
        for key in boxes.keys():
            for box, label, score in zip(boxes[key], labels[key], scores[key]):
                f.write(
                    "%s,%.03f,%.03f,%.03f,%.03f,%d,%.04f\n"
                    % (key, box[1], box[0], box[3], box[2], label, score)
                )

    logger.info("AVA results written to %s" % filename)
    logger.info("\ttook %d seconds." % (time.time() - start))