import evaluate
from evaluate.utils.file_utils import add_start_docstrings
import datasets
import torch
from transformers import CLIPProcessor, CLIPModel
from tqdm import tqdm

_DESCRIPTION = """
This metric evaluates CLIP models on image-text retrieval tasks using standard datasets.
It calculates Recall@K metrics for both text-to-image and image-to-text retrieval.
"""

_KWARGS_DESCRIPTION = """
Args:
    model_name: Name or path of the CLIP model to evaluate (e.g., "openai/clip-vit-base-patch32")
    dataset_names: List of dataset names to evaluate on (choices: "mscoco", "flickr")
    n_examples: Number of examples to use for evaluation (-1 for all)
Returns:
    Dictionary containing Recall@K metrics for each dataset and retrieval direction
"""

_CITATION = """
@inproceedings{radford2021learning,
    title={Learning transferable visual models from natural language supervision},
    author={Radford, Alec and Kim, Jong Wook and Hallacy, Chris and Ramesh, Aditya and others},
    booktitle={International Conference on Machine Learning},
    year={2021},
}
"""


@add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class DmxClipEval(evaluate.Metric):
    def _info(self):
        return evaluate.MetricInfo(
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=[
                datasets.Features(
                    {
                        "model_name": datasets.Value("string"),
                        "dataset_names": datasets.Value("string"),
                        "n_examples": datasets.Value("int32"),
                    }
                ),
            ],
        )

    def clip_dataset_evaluator(
        self, model, device, desc, dataset_name="mscoco", n_examples=-1
    ):
        processor = CLIPProcessor.from_pretrained(model.config._name_or_path)
        if dataset_name == "mscoco":
            ds = datasets.load_dataset(
                "clip-benchmark/wds_mscoco_captions", split="test"
            )
        elif dataset_name == "flickr":
            ds = datasets.load_dataset("clip-benchmark/wds_flickr8k", split="test")
        else:
            raise ValueError(f"Invalid dataset name: {dataset_name}")
        if n_examples != -1:
            ds = ds.select(range(min(n_examples, len(ds))))

        # Iterate over dataset indices in batches; rows are only materialized
        # (and images decoded) when their batch is fetched.
        dl = torch.utils.data.DataLoader(torch.arange(len(ds)), batch_size=8)
        all_image_embeds = []
        all_text_embeds = []
        for indices in tqdm(dl, desc=f"Processing {dataset_name}"):
            batch = ds[indices.tolist()]
            inputs = processor(
                text=batch["txt"],
                images=batch["jpg"],
                return_tensors="pt",
                padding=True,
            )
            # CLIP's text encoder accepts at most 77 tokens; truncate longer captions.
            inputs["input_ids"] = inputs["input_ids"][:, :77]
            inputs["attention_mask"] = inputs["attention_mask"][:, :77]
            inputs = {k: v.to(device) for k, v in inputs.items()}
            with torch.no_grad():
                output = model(**inputs)
            all_image_embeds.append(output.image_embeds.cpu())
            all_text_embeds.append(output.text_embeds.cpu())
        all_image_embeds = torch.cat(all_image_embeds, dim=0)
        all_text_embeds = torch.cat(all_text_embeds, dim=0)

        # Pairwise text-to-image similarity; row i's correct match is image i.
        text_img_sim = all_text_embeds @ all_image_embeds.t()

        def get_top_k(sim_mat, k_arr):
            # Recall@k: fraction of queries whose matching item (the diagonal entry
            # of the similarity matrix) appears among the k highest-scoring candidates.
            ordered_winners = torch.argsort(sim_mat, dim=-1, descending=True)
            correct_winner_mask = (
                ordered_winners
                == torch.arange(ordered_winners.shape[0])
                .unsqueeze(1)
                .to(ordered_winners.device)
            ).long()
            return [
                correct_winner_mask[:, :k].sum(-1).float().mean().item() for k in k_arr
            ]

        k_arr = [1, 5, 10]
        metrics = {
            # Text-to-image retrieval (query: caption, candidates: images).
            **{
                f"{dataset_name}:image_recall@{k}": val
                for k, val in zip(k_arr, get_top_k(text_img_sim, k_arr))
            },
            # Image-to-text retrieval (query: image, candidates: captions).
            **{
                f"{dataset_name}:text_recall@{k}": val
                for k, val in zip(k_arr, get_top_k(text_img_sim.t(), k_arr))
            },
        }
        return metrics

    def clip_evaluator(self, model, device, desc, n_examples=-1):
        metrics = {}
        for name in ["mscoco", "flickr"]:
            metrics.update(
                self.clip_dataset_evaluator(model, device, desc, name, n_examples)
            )
        return metrics
    def _compute(self, model_name, dataset_names, n_examples):
        # `evaluate` passes each feature as a list with one entry per added example;
        # this metric only uses the first entry of each.
        actual_model_name = model_name[0]
        actual_dataset_name_str = dataset_names[0]
        actual_n_examples = n_examples[0]
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = CLIPModel.from_pretrained(actual_model_name).to(device)
        datasets_to_evaluate = [actual_dataset_name_str]
        metrics = {}
        for ds_name_loop_var in datasets_to_evaluate:
            dataset_metrics = self.clip_dataset_evaluator(
                model=model,
                device=device,
                desc=actual_model_name,
                dataset_name=ds_name_loop_var,
                n_examples=actual_n_examples,
            )
            metrics.update(dataset_metrics)
        return metrics
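

# --- Usage sketch (illustrative, not part of the metric) ---
# A minimal example of how this metric could be invoked, assuming the module is
# saved locally and loadable as `evaluate.load("dmx_clip_eval")`; that module
# name/path is an assumption, as are the specific argument values below.
if __name__ == "__main__":
    metric = evaluate.load("dmx_clip_eval")  # hypothetical local module path
    results = metric.compute(
        model_name=["openai/clip-vit-base-patch32"],
        dataset_names=["mscoco"],
        n_examples=[1000],
    )
    # Expected keys: "mscoco:image_recall@1", "mscoco:text_recall@5", etc.
    print(results)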