wanzin committed
Commit ce0bd03 · Parent(s): 4f732a3

add clip evaluator on mscoco and flickr8k dataset

Files changed (3)
  1. app.py +77 -0
  2. clip_eval.py +149 -0
  3. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,77 @@
+ import gradio as gr
+ import evaluate
+
+ clip_metric = evaluate.load("d-matrix/dmx_clip_eval")
+ print("Successfully loaded CLIP evaluation metric")
+
+ AVAILABLE_MODELS = [
+     "openai/clip-vit-base-patch32",
+     "openai/clip-vit-large-patch14",
+     "openai/clip-vit-base-patch16",
+ ]
+
+ AVAILABLE_DATASETS = ["mscoco", "flickr"]
+
+ with gr.Blocks(title="CLIP Evaluation") as demo:
+     gr.Markdown("# CLIP Model Evaluation")
+     gr.Markdown(
+         """
+         This tool evaluates CLIP models on image-text retrieval tasks using standard datasets.
+         """
+     )
+
+     with gr.Row():
+         with gr.Column():
+             model_input = gr.Dropdown(
+                 choices=AVAILABLE_MODELS, value=AVAILABLE_MODELS[0], label="CLIP Model"
+             )
+
+             dataset_input = gr.Dropdown(
+                 choices=AVAILABLE_DATASETS, value="mscoco", label="Dataset"
+             )
+
+             samples_input = gr.Slider(
+                 minimum=1, maximum=10, value=1, step=1, label="Number of samples"
+             )
+
+             evaluate_button = gr.Button("Evaluate Model")
+
+         with gr.Column():
+             results_output = gr.Markdown("Results will appear here")
+
+     def evaluate_clip(model_name, dataset, num_samples, progress=gr.Progress()):
+         progress(0, desc="Evaluating CLIP model...")
+
+         results = clip_metric.compute(
+             model_name=[model_name],
+             dataset_names=[dataset],
+             n_examples=[int(num_samples)],
+         )
+
+         output = "## CLIP Evaluation Results\n\n"
+         output += f"**Model:** {model_name}\n"
+         output += f"**Dataset:** {dataset}\n"
+         output += f"**Samples:** {num_samples}\n\n"
+
+         output += "**Image Retrieval (Text→Image):**\n"
+         for k in [1, 5, 10]:
+             metric_name = f"{dataset}:image_recall@{k}"
+             if metric_name in results:
+                 output += f"* Recall@{k}: {results[metric_name]:.4f}\n"
+
+         output += "\n**Text Retrieval (Image→Text):**\n"
+         for k in [1, 5, 10]:
+             metric_name = f"{dataset}:text_recall@{k}"
+             if metric_name in results:
+                 output += f"* Recall@{k}: {results[metric_name]:.4f}\n"
+
+         return output
+
+     evaluate_button.click(
+         fn=evaluate_clip,
+         inputs=[model_input, dataset_input, samples_input],
+         outputs=results_output,
+     )
+
+ if __name__ == "__main__":
+     demo.launch()
clip_eval.py ADDED
@@ -0,0 +1,149 @@
+ import evaluate
+ from evaluate.utils.file_utils import add_start_docstrings
+ import datasets
+ import torch
+ from transformers import CLIPProcessor, CLIPModel
+ from tqdm import tqdm
+
+ _DESCRIPTION = """
+ This metric evaluates CLIP models on image-text retrieval tasks using standard datasets.
+ It calculates Recall@K metrics for both text-to-image and image-to-text retrieval.
+ """
+
+ _KWARGS_DESCRIPTION = """
+ Args:
+     model_name: Name or path of the CLIP model to evaluate (e.g., "openai/clip-vit-base-patch32")
+     dataset_names: List of dataset names to evaluate on (choices: "mscoco", "flickr")
+     n_examples: Number of examples to use for evaluation (-1 for all)
+
+ Returns:
+     Dictionary containing Recall@K metrics for each dataset and retrieval direction
+ """
+
+ _CITATION = """
+ @inproceedings{radford2021learning,
+     title={Learning transferable visual models from natural language supervision},
+     author={Radford, Alec and Kim, Jong Wook and Hallacy, Chris and Ramesh, Aditya and others},
+     booktitle={International Conference on Machine Learning},
+     year={2021},
+ }
+ """
+
+
+ @add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
+ class DmxClipEval(evaluate.Metric):
+     def _info(self):
+         return evaluate.MetricInfo(
+             module_type="metric",
+             description=_DESCRIPTION,
+             citation=_CITATION,
+             inputs_description=_KWARGS_DESCRIPTION,
+             features=[
+                 datasets.Features(
+                     {
+                         "model_name": datasets.Value("string"),
+                         "dataset_names": datasets.Value("string"),
+                         "n_examples": datasets.Value("int32"),
+                     }
+                 ),
+             ],
+         )
+
+     def clip_dataset_evaluator(
+         self, model, device, desc="", dataset_name="mscoco", n_examples=-1
+     ):
+         processor = CLIPProcessor.from_pretrained(model.config._name_or_path)
+         if dataset_name == "mscoco":
+             ds = datasets.load_dataset(
+                 "clip-benchmark/wds_mscoco_captions", split="test"
+             )
+         elif dataset_name == "flickr":
+             ds = datasets.load_dataset("clip-benchmark/wds_flickr8k", split="test")
+         else:
+             raise ValueError(f"invalid dataset name: {dataset_name}")
+
+         if n_examples != -1:
+             ds = ds.select(range(min(n_examples, len(ds))))
+
+         dl = torch.utils.data.DataLoader(torch.arange(len(ds)), batch_size=8)
+         all_image_embeds = []
+         all_text_embeds = []
+
+         for indices in tqdm(dl, desc=f"Processing {dataset_name} {desc}".strip()):
+             batch = ds[indices.tolist()]
+             inputs = processor(
+                 text=batch["txt"],
+                 images=batch["jpg"],
+                 return_tensors="pt",
+                 padding=True,
+             )
+             inputs["input_ids"] = inputs["input_ids"][:, :77]
+             inputs["attention_mask"] = inputs["attention_mask"][:, :77]
+             inputs = {k: v.to(device) for k, v in inputs.items()}
+
+             with torch.no_grad():
+                 output = model(**inputs)
+
+             all_image_embeds.append(output.image_embeds.cpu())
+             all_text_embeds.append(output.text_embeds.cpu())
+
+         all_image_embeds = torch.cat(all_image_embeds, dim=0)
+         all_text_embeds = torch.cat(all_text_embeds, dim=0)
+         text_img_sim = all_text_embeds @ all_image_embeds.t()
+
+         def get_top_k(sim_mat, k_arr):
+             ordered_winners = torch.argsort(sim_mat, dim=-1, descending=True)
+             correct_winner_mask = (
+                 ordered_winners
+                 == torch.arange(ordered_winners.shape[0])
+                 .unsqueeze(1)
+                 .to(ordered_winners.device)
+             ).long()
+             return [
+                 correct_winner_mask[:, :k].sum(-1).float().mean().item() for k in k_arr
+             ]
+
+         k_arr = [1, 5, 10]
+         metrics = {
+             **{
+                 f"{dataset_name}:image_recall@{k}": val
+                 for k, val in zip(k_arr, get_top_k(text_img_sim, k_arr))
+             },
+             **{
+                 f"{dataset_name}:text_recall@{k}": val
+                 for k, val in zip(k_arr, get_top_k(text_img_sim.t(), k_arr))
+             },
+         }
+         return metrics
+
+     def clip_evaluator(self, model, device, desc, n_examples=-1):
+         metrics = {}
+         for name in ["mscoco", "flickr"]:
+             metrics.update(
+                 self.clip_dataset_evaluator(model, device, desc, name, n_examples)
+             )
+         return metrics
+
+     def _compute(self, model_name, dataset_names, n_examples):
+         # `evaluate` passes each input column as a list of logged values; use the first entry.
+         actual_model_name = model_name[0]
+         actual_dataset_name_str = dataset_names[0]
+         actual_n_examples = n_examples[0]
+
+         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         model = CLIPModel.from_pretrained(actual_model_name).to(device)
+
+         datasets_to_evaluate = [actual_dataset_name_str]
+
+         metrics = {}
+         for ds_name_loop_var in datasets_to_evaluate:
+             dataset_metrics = self.clip_dataset_evaluator(
+                 model=model,
+                 device=device,
+                 desc=actual_model_name,
+                 dataset_name=ds_name_loop_var,
+                 n_examples=actual_n_examples,
+             )
+             metrics.update(dataset_metrics)
+
+         return metrics
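
For reference, a minimal sketch of calling the new metric directly, mirroring the compute call in app.py. It assumes the module is loadable under the "d-matrix/dmx_clip_eval" id that app.py uses (a local path to clip_eval.py passed to evaluate.load would also work); the model and sample count below are just illustrative values.

import evaluate

clip_metric = evaluate.load("d-matrix/dmx_clip_eval")

# Each argument is a single-element list; _compute reads only the first entry.
results = clip_metric.compute(
    model_name=["openai/clip-vit-base-patch32"],
    dataset_names=["mscoco"],
    n_examples=[100],
)

# Keys follow the "<dataset>:image_recall@<k>" / "<dataset>:text_recall@<k>"
# pattern that app.py formats for display.
print(results["mscoco:image_recall@1"], results["mscoco:text_recall@5"])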
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ gradio>=3.50.0
+ torch>=2.5.0
+ transformers>=4.48.0
+ datasets>=2.21.0
+ tqdm>=4.65.0
+ evaluate>=0.4.3