import evaluate
from evaluate.utils.file_utils import add_start_docstrings
import datasets
import torch
from transformers import CLIPProcessor, CLIPModel
from tqdm import tqdm
_DESCRIPTION = """ |
|
This metric evaluates CLIP models on image-text retrieval tasks using standard datasets. |
|
It calculates Recall@K metrics for both text-to-image and image-to-text retrieval. |
|
""" |
|
|
|
_KWARGS_DESCRIPTION = """ |
|
Args: |
|
model_name: Name or path of the CLIP model to evaluate (e.g., "openai/clip-vit-base-patch32") |
|
dataset_names: List of dataset names to evaluate on (choices: "mscoco", "flickr") |
|
n_examples: Number of examples to use for evaluation (-1 for all) |
|
|
|
Returns: |
|
Dictionary containing Recall@K metrics for each dataset and retrieval direction |
|
""" |
|
|
|
_CITATION = """ |
|
@inproceedings{radford2021learning, |
|
title={Learning transferable visual models from natural language supervision}, |
|
author={Radford, Alec and Kim, Jong Wook and Hallacy, Chris and Ramesh, Aditya and others}, |
|
booktitle={International Conference on Machine Learning}, |
|
year={2021}, |
|
} |
|
""" |


@add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class DmxClipEval(evaluate.Metric):
    def _info(self):
        return evaluate.MetricInfo(
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=[
                datasets.Features(
                    {
                        "model_name": datasets.Value("string"),
                        "dataset_names": datasets.Value("string"),
                        "n_examples": datasets.Value("int32"),
                    }
                ),
            ],
        )

    def clip_dataset_evaluator(
        self, model, device, desc, dataset_name="mscoco", n_examples=-1
    ):
        processor = CLIPProcessor.from_pretrained(model.config._name_or_path)
        if dataset_name == "mscoco":
            ds = datasets.load_dataset(
                "clip-benchmark/wds_mscoco_captions", split="test"
            )
        elif dataset_name == "flickr":
            ds = datasets.load_dataset("clip-benchmark/wds_flickr8k", split="test")
        else:
            raise ValueError(f"Invalid dataset name: {dataset_name}")

        if n_examples != -1:
            ds = ds.select(range(min(n_examples, len(ds))))

        # Iterate over dataset indices in batches; each example pairs one image ("jpg")
        # with one caption ("txt").
        dl = torch.utils.data.DataLoader(torch.arange(len(ds)), batch_size=8)

        all_image_embeds = []
        all_text_embeds = []

        for indices in tqdm(dl, desc=f"Processing {dataset_name}"):
            batch = ds[indices.tolist()]
            inputs = processor(
                text=batch["txt"],
                images=batch["jpg"],
                return_tensors="pt",
                padding=True,
            )
            # Truncate captions to CLIP's maximum text length of 77 tokens.
            inputs["input_ids"] = inputs["input_ids"][:, :77]
            inputs["attention_mask"] = inputs["attention_mask"][:, :77]
            inputs = {k: v.to(device) for k, v in inputs.items()}

            with torch.no_grad():
                output = model(**inputs)

            all_image_embeds.append(output.image_embeds.cpu())
            all_text_embeds.append(output.text_embeds.cpu())

        # CLIPModel returns L2-normalized projections, so the dot product below gives
        # the cosine similarity between every caption (row) and every image (column).
        all_image_embeds = torch.cat(all_image_embeds, dim=0)
        all_text_embeds = torch.cat(all_text_embeds, dim=0)
        text_img_sim = all_text_embeds @ all_image_embeds.t()

        def get_top_k(sim_mat, k_arr):
            # Rank candidates for each query; the correct match for row i is column i,
            # since images and captions are paired one-to-one in dataset order.
            ordered_winners = torch.argsort(sim_mat, dim=-1, descending=True)
            correct_winner_mask = (
                ordered_winners
                == torch.arange(ordered_winners.shape[0])
                .unsqueeze(1)
                .to(ordered_winners.device)
            ).long()
            # Recall@K: fraction of queries whose correct match appears in the top K.
            return [
                correct_winner_mask[:, :k].sum(-1).float().mean().item() for k in k_arr
            ]
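
        # Illustrative sanity check of get_top_k (a sketch, not executed here): with
        #   sim = torch.tensor([[0.9, 0.1, 0.2],
        #                       [0.3, 0.2, 0.8],
        #                       [0.1, 0.4, 0.7]])
        # get_top_k(sim, [1, 3]) returns [2/3, 1.0]: rows 0 and 2 rank their own column
        # first, while row 1 only recovers its match (column 1) within the top 3.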
        k_arr = [1, 5, 10]
        # "image_recall" scores text-to-image retrieval (rows of text_img_sim are
        # captions); "text_recall" scores image-to-text retrieval via the transpose.
        metrics = {
            **{
                f"{dataset_name}:image_recall@{k}": val
                for k, val in zip(k_arr, get_top_k(text_img_sim, k_arr))
            },
            **{
                f"{dataset_name}:text_recall@{k}": val
                for k, val in zip(k_arr, get_top_k(text_img_sim.t(), k_arr))
            },
        }
        return metrics

    def clip_evaluator(self, model, device, desc, n_examples=-1):
        # Convenience helper: evaluate the same model on both supported datasets.
        metrics = {}
        for name in ["mscoco", "flickr"]:
            metrics.update(
                self.clip_dataset_evaluator(model, device, desc, name, n_examples)
            )
        return metrics

    def _compute(self, model_name, dataset_names, n_examples):
        # evaluate passes each input as a per-example list; only the first entry of
        # each list is used here.
        actual_model_name = model_name[0]
        actual_dataset_name_str = dataset_names[0]
        actual_n_examples = n_examples[0]

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = CLIPModel.from_pretrained(actual_model_name).to(device)

        datasets_to_evaluate = [actual_dataset_name_str]

        metrics = {}
        for ds_name_loop_var in datasets_to_evaluate:
            dataset_metrics = self.clip_dataset_evaluator(
                model=model,
                device=device,
                desc=actual_model_name,
                dataset_name=ds_name_loop_var,
                n_examples=actual_n_examples,
            )
            metrics.update(dataset_metrics)

        return metrics