rohansampath committed (verified)
Commit 8fa9808 · 1 Parent(s): 4191f43

Create toy-dataset-eval.py

A simple file to verify that evals are working properly.

Files changed (1):
  1. toy-dataset-eval.py +153 -0

toy-dataset-eval.py ADDED
import torch
import evaluate
import re
import base64
import io
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer
import spaces  # Hugging Face Spaces package providing the ZeroGPU @spaces.GPU decorator

# ---------------------------------------------------------------------------
# 1. Simple Test Dataset to Run GPU Calls On
# ---------------------------------------------------------------------------
test_data = [
    {"question": "What is 2+2?", "answer": "4"},
    {"question": "What is 3*3?", "answer": "9"},
    {"question": "What is 10/2?", "answer": "5"},
]

# ---------------------------------------------------------------------------
# 2. Load metric
# ---------------------------------------------------------------------------
accuracy_metric = evaluate.load("accuracy")

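# ---------------------------------------------------------------------------
# 3. Model loading
# NOTE: load_model() is called by generate_answer() below but is not defined
# anywhere in this commit. The sketch here is an assumption, not part of the
# original file: it lazily loads a Mistral instruct checkpoint (the exact
# model name is a guess) so the script can run end to end.
# ---------------------------------------------------------------------------
_model = None
_tokenizer = None

def load_model():
    """Lazily load and cache the model and tokenizer (assumed checkpoint)."""
    global _model, _tokenizer
    if _model is None:
        model_name = "mistralai/Mistral-7B-Instruct-v0.2"  # assumed checkpoint
        _tokenizer = AutoTokenizer.from_pretrained(model_name)
        if _tokenizer.pad_token is None:
            # Mistral's tokenizer ships without a pad token; reuse EOS so the
            # pad_token_id passed to generate() below is not None.
            _tokenizer.pad_token = _tokenizer.eos_token
        _model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.float16
        ).to("cuda")
        _model.eval()
    return _model, _tokenizer
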
# ---------------------------------------------------------------------------
# 4. Inference helper functions
# ---------------------------------------------------------------------------
@spaces.GPU
def generate_answer(question):
    """
    Generates an answer using Mistral's instruction format.
    """
    model, tokenizer = load_model()

    # Mistral instruction format
    prompt = f"""<s>[INST] {question}. Provide only the numerical answer. [/INST]"""

    inputs = tokenizer(prompt, return_tensors="pt").to('cuda')
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=50,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id
        )
    text_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Remove the original question from the output
    return text_output.replace(question, "").strip()

def parse_answer(model_output):
    """
    Extract numeric answer from model's text output.
    """
    # Look for numbers (including decimals)
    match = re.search(r"(-?\d*\.?\d+)", model_output)
    if match:
        return match.group(1)
    return model_output.strip()

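# Illustrative examples (not part of the original file): the regex returns the
# first number found in the generated text, for instance
#   parse_answer("[/INST] The answer is 4.")    -> "4"
#   parse_answer("approximately -2.5 apples")   -> "-2.5"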

@spaces.GPU(duration=120)  # Allow up to 2 minutes for full evaluation
def run_evaluation():
    predictions = []
    references = []
    raw_outputs = []  # Store full model outputs for display

    for sample in test_data:
        question = sample["question"]
        reference_answer = sample["answer"]

        # Model inference
        model_output = generate_answer(question)
        predicted_answer = parse_answer(model_output)

        predictions.append(predicted_answer)
        references.append(reference_answer)
        raw_outputs.append({
            "question": question,
            "model_output": model_output,
            "parsed_answer": predicted_answer,
            "reference": reference_answer
        })

    # Normalize answers
    def normalize_answer(ans):
        return str(ans).lower().strip()

    norm_preds = [normalize_answer(p) for p in predictions]
    norm_refs = [normalize_answer(r) for r in references]

    # Compute accuracy
    results = accuracy_metric.compute(predictions=norm_preds, references=norm_refs)
    accuracy = results["accuracy"]

    # Create visualization
    fig, ax = plt.subplots(figsize=(8, 6))
    correct_count = sum(p == r for p, r in zip(norm_preds, norm_refs))
    incorrect_count = len(test_data) - correct_count

    bars = ax.bar(["Correct", "Incorrect"],
                  [correct_count, incorrect_count],
                  color=["#2ecc71", "#e74c3c"])

    # Add value labels on bars
    for bar in bars:
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
                f'{int(height)}',
                ha='center', va='bottom')

    ax.set_title("Evaluation Results")
    ax.set_ylabel("Count")
    ax.set_ylim([0, len(test_data) + 0.5])

    # Convert plot to base64
    buf = io.BytesIO()
    plt.savefig(buf, format="png", bbox_inches='tight', dpi=300)
    buf.seek(0)
    plt.close(fig)
    data = base64.b64encode(buf.read()).decode("utf-8")

    # Create detailed results HTML
    details_html = """
    <div style="margin-top: 20px;">
        <h3>Detailed Results:</h3>
        <table style="width:100%; border-collapse: collapse;">
            <tr style="background-color: #f5f5f5;">
                <th style="padding: 8px; border: 1px solid #ddd;">Question</th>
                <th style="padding: 8px; border: 1px solid #ddd;">Model Output</th>
                <th style="padding: 8px; border: 1px solid #ddd;">Parsed Answer</th>
                <th style="padding: 8px; border: 1px solid #ddd;">Reference</th>
            </tr>
    """

    for result in raw_outputs:
        details_html += f"""
            <tr>
                <td style="padding: 8px; border: 1px solid #ddd;">{result['question']}</td>
                <td style="padding: 8px; border: 1px solid #ddd;">{result['model_output']}</td>
                <td style="padding: 8px; border: 1px solid #ddd;">{result['parsed_answer']}</td>
                <td style="padding: 8px; border: 1px solid #ddd;">{result['reference']}</td>
            </tr>
        """

    details_html += "</table></div>"

    full_html = f"""
    <div>
        <img src="data:image/png;base64,{data}" style="width:100%; max-width:600px;">
        {details_html}
    </div>
    """

    return f"Accuracy: {accuracy:.2f}", full_html
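
# ---------------------------------------------------------------------------
# 5. Entry point (not in the original commit): a minimal sketch for invoking
# the evaluation directly, e.g. as a command-line smoke test. The original
# file defines run_evaluation() but never calls it, so how it is wired into
# an app (for example a Gradio UI) is an assumption left to the caller.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    summary, report_html = run_evaluation()
    print(summary)
    # Write the HTML report to disk so it can be inspected in a browser.
    with open("eval_results.html", "w") as f:
        f.write(report_html)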