""" | |
Copyright (c) 2022, salesforce.com, inc. | |
All rights reserved. | |
SPDX-License-Identifier: BSD-3-Clause | |
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
""" | |
import json
import logging
import os

import numpy as np
import torch

from lavis.common.dist_utils import is_main_process
from lavis.common.registry import registry
from lavis.tasks.base_task import BaseTask


@registry.register_task("retrieval")
class RetrievalTask(BaseTask):
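    """Image-text retrieval evaluation task.

    Computes image-to-text and text-to-image similarity matrices with the
    model, then reports recall@{1,5,10} in both directions.
    """
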
    def __init__(self, cfg):
        super().__init__()

        self.cfg = cfg

    @classmethod
    def setup_task(cls, cfg):
        run_cfg = cfg.run_cfg

        return cls(cfg=run_cfg)

    def evaluation(self, model, data_loader, **kwargs):
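        """Compute pairwise similarity scores and, on the main process,
        report the retrieval metrics derived from them."""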
        # score_i2t, score_t2i = model.compute_sim_matrix(model, data_loader)
        score_i2t, score_t2i = model.compute_sim_matrix(data_loader, task_cfg=self.cfg)

        if is_main_process():
            eval_result = self._report_metrics(
                score_i2t,
                score_t2i,
                data_loader.dataset.txt2img,
                data_loader.dataset.img2txt,
            )
            logging.info(eval_result)
        else:
            eval_result = None

        return eval_result

    def after_evaluation(self, val_result, **kwargs):
        return val_result

    @staticmethod
    @torch.no_grad()
    def _report_metrics(scores_i2t, scores_t2i, txt2img, img2txt):
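        """Compute recall@{1,5,10} for image->text and text->image retrieval.

        `img2txt` maps each image index to its ground-truth caption indices;
        `txt2img` maps each caption index to its ground-truth image index.
        Results are also appended to `evaluate.txt` in the output directory.
        """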
        # Images->Text
        ranks = np.zeros(scores_i2t.shape[0])
        for index, score in enumerate(scores_i2t):
            inds = np.argsort(score)[::-1]
            # Rank of the highest-ranked ground-truth caption for this image
            rank = 1e20
            for i in img2txt[index]:
                tmp = np.where(inds == i)[0][0]
                if tmp < rank:
                    rank = tmp
            ranks[index] = rank

        # Compute metrics
        tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
        tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
        tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)

        # Text->Images
        ranks = np.zeros(scores_t2i.shape[0])
        for index, score in enumerate(scores_t2i):
            inds = np.argsort(score)[::-1]
            ranks[index] = np.where(inds == txt2img[index])[0][0]

        # Compute metrics
        ir1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
        ir5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
        ir10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)

        tr_mean = (tr1 + tr5 + tr10) / 3
        ir_mean = (ir1 + ir5 + ir10) / 3
        r_mean = (tr_mean + ir_mean) / 2

        agg_metrics = tr_mean  # aggregate score; equals the text-retrieval recall mean

        eval_result = {
            "txt_r1": tr1,
            "txt_r5": tr5,
            "txt_r10": tr10,
            "txt_r_mean": tr_mean,
            "img_r1": ir1,
            "img_r5": ir5,
            "img_r10": ir10,
            "img_r_mean": ir_mean,
            "r_mean": r_mean,
            "agg_metrics": agg_metrics,
        }

        with open(
            os.path.join(registry.get_path("output_dir"), "evaluate.txt"), "a"
        ) as f:
            f.write(json.dumps(eval_result) + "\n")

        return eval_result
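

# A minimal sanity-check sketch, not part of the task itself (it assumes a
# LAVIS run has registered an "output_dir" path, since _report_metrics appends
# to evaluate.txt there):
#
#   scores_i2t = np.array([[0.9, 0.1, 0.8, 0.2], [0.1, 0.9, 0.2, 0.8]])
#   scores_t2i = scores_i2t.T
#   img2txt = {0: [0, 2], 1: [1, 3]}    # each image has two ground-truth captions
#   txt2img = {0: 0, 1: 1, 2: 0, 3: 1}  # each caption has one ground-truth image
#   RetrievalTask._report_metrics(scores_i2t, scores_t2i, txt2img, img2txt)
#   # every ground-truth pair ranks first, so all recall values come out as 100.0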