# clip_eval/clip_eval.py
import evaluate
from evaluate.utils.file_utils import add_start_docstrings
import datasets
import torch
from transformers import CLIPProcessor, CLIPModel
from tqdm import tqdm
_DESCRIPTION = """
This metric evaluates CLIP models on image-text retrieval tasks using standard datasets.
It calculates Recall@K metrics for both text-to-image and image-to-text retrieval.
"""
_KWARGS_DESCRIPTION = """
Args:
model_name: Name or path of the CLIP model to evaluate (e.g., "openai/clip-vit-base-patch32")
dataset_names: List of dataset names to evaluate on (choices: "mscoco", "flickr")
n_examples: Number of examples to use for evaluation (-1 for all)
Returns:
Dictionary containing Recall@K metrics for each dataset and retrieval direction
"""
_CITATION = """
@inproceedings{radford2021learning,
title={Learning transferable visual models from natural language supervision},
author={Radford, Alec and Kim, Jong Wook and Hallacy, Chris and Ramesh, Aditya and others},
booktitle={International Conference on Machine Learning},
year={2021},
}
"""
@add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class DmxClipEval(evaluate.Metric):
def _info(self):
return evaluate.MetricInfo(
module_type="metric",
description=_DESCRIPTION,
citation=_CITATION,
inputs_description=_KWARGS_DESCRIPTION,
features=[
datasets.Features(
{
"model_name": datasets.Value("string"),
"dataset_names": datasets.Value("string"),
"n_examples": datasets.Value("int32"),
}
),
],
)
def clip_dataset_evaluator(
self, model, device, desc, dataset_name="mscoco", n_examples=-1
):
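        # Load the preprocessing pipeline (tokenizer + image processor) that matches the checkpoint under evaluation.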
processor = CLIPProcessor.from_pretrained(model.config._name_or_path)
if dataset_name == "mscoco":
ds = datasets.load_dataset(
"clip-benchmark/wds_mscoco_captions", split="test"
)
elif dataset_name == "flickr":
ds = datasets.load_dataset("clip-benchmark/wds_flickr8k", split="test")
else:
raise ValueError(f"invalid dataset name : {dataset_name}")
if n_examples != -1:
ds = ds.select(range(min(n_examples, len(ds))))
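        # Iterate over dataset indices in mini-batches so images and captions are embedded a few at a time.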
dl = torch.utils.data.DataLoader(torch.arange(len(ds)), batch_size=8)
all_image_embeds = []
all_text_embeds = []
for indices in tqdm(dl, desc=f"Processing {dataset_name}"):
batch = ds[indices.tolist()]
            inputs = processor(
                text=batch["txt"],
                images=batch["jpg"],
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=77,  # CLIP's text encoder accepts at most 77 tokens
            )
inputs = {k: v.to(device) for k, v in inputs.items()}
with torch.no_grad():
output = model(**inputs)
all_image_embeds.append(output.image_embeds.cpu())
all_text_embeds.append(output.text_embeds.cpu())
all_image_embeds = torch.cat(all_image_embeds, dim=0)
all_text_embeds = torch.cat(all_text_embeds, dim=0)
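        # CLIPModel.forward returns L2-normalized embeddings, so this matrix product is
        # the cosine similarity between every caption (rows) and every image (columns).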
text_img_sim = all_text_embeds @ all_image_embeds.t()
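        # Recall@K helper: assumes a one-to-one pairing (the ground-truth match for query i
        # is candidate i) and returns the fraction of queries whose match ranks in the top k.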
def get_top_k(sim_mat, k_arr):
ordered_winners = torch.argsort(sim_mat, dim=-1, descending=True)
correct_winner_mask = (
ordered_winners
== torch.arange(ordered_winners.shape[0])
.unsqueeze(1)
.to(ordered_winners.device)
).long()
return [
correct_winner_mask[:, :k].sum(-1).float().mean().item() for k in k_arr
]
k_arr = [1, 5, 10]
metrics = {
**{
f"{dataset_name}:image_recall@{k}": val
for k, val in zip(k_arr, get_top_k(text_img_sim, k_arr))
},
**{
f"{dataset_name}:text_recall@{k}": val
for k, val in zip(k_arr, get_top_k(text_img_sim.t(), k_arr))
},
}
return metrics
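    # Convenience wrapper that evaluates a model on every supported dataset.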
def clip_evaluator(self, model, device, desc, n_examples=-1):
metrics = {}
for name in ["mscoco", "flickr"]:
metrics.update(
self.clip_dataset_evaluator(model, device, desc, name, n_examples)
)
return metrics
def _compute(self, model_name, dataset_names, n_examples):
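        # evaluate collects each input as a list (one entry per add()/compute() call); only the first entry is used.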
actual_model_name = model_name[0]
actual_dataset_name_str = dataset_names[0]
actual_n_examples = n_examples[0]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CLIPModel.from_pretrained(actual_model_name).to(device)
datasets_to_evaluate = [actual_dataset_name_str]
metrics = {}
for ds_name_loop_var in datasets_to_evaluate:
dataset_metrics = self.clip_dataset_evaluator(
model=model,
device=device,
desc=actual_model_name,
dataset_name=ds_name_loop_var,
n_examples=actual_n_examples,
)
metrics.update(dataset_metrics)
return metrics
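

# A minimal usage sketch, assuming the module is run directly as a script; the model,
# dataset, and example count below are illustrative. When hosted on the Hub, the metric
# would normally be obtained with evaluate.load() rather than instantiated directly.
if __name__ == "__main__":
    metric = DmxClipEval()
    results = metric.compute(
        model_name=["openai/clip-vit-base-patch32"],
        dataset_names=["mscoco"],
        n_examples=[100],
    )
    print(results)  # e.g. {"mscoco:image_recall@1": ..., "mscoco:text_recall@1": ..., ...}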