rohansampath committed
Commit 0e843f9 · verified · 1 Parent(s): 1856ad2

Delete toy_dataset_eval.py

Files changed (1)
  1. toy_dataset_eval.py +0 -151
toy_dataset_eval.py DELETED
@@ -1,151 +0,0 @@
- import torch
- import evaluate
- import re
- import base64
- import io
- import matplotlib.pyplot as plt
- from transformers import AutoModelForCausalLM, AutoTokenizer
- import spaces  # Assuming this is a custom or predefined library for GPU handling
-
- # ---------------------------------------------------------------------------
- # 1. Simple Test Dataset to Run GPU Calls On
- # ---------------------------------------------------------------------------
- test_data = [
-     {"question": "What is 2+2?", "answer": "4"},
-     {"question": "What is 3*3?", "answer": "9"},
-     {"question": "What is 10/2?", "answer": "5"},
- ]
-
- # ---------------------------------------------------------------------------
- # 2. Load metric
- # ---------------------------------------------------------------------------
- accuracy_metric = evaluate.load("accuracy")
-
- # ---------------------------------------------------------------------------
- # 3. Inference helper functions
- # ---------------------------------------------------------------------------
- @spaces.GPU
- def generate_answer(question, model, tokenizer):
-     """
-     Generates an answer using Mistral's instruction format.
-     """
-     # Mistral instruction format
-     prompt = f"""<s>[INST] {question}. Provide only the numerical answer. [/INST]"""
-
-     inputs = tokenizer(prompt, return_tensors="pt").to('cuda')
-     with torch.no_grad():
-         outputs = model.generate(
-             **inputs,
-             max_new_tokens=50,
-             pad_token_id=tokenizer.pad_token_id,
-             eos_token_id=tokenizer.eos_token_id
-         )
-     text_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
-     # Remove the original question from the output
-     return text_output.replace(question, "").strip()
-
- def parse_answer(model_output):
-     """
-     Extract numeric answer from model's text output.
-     """
-     # Look for numbers (including decimals)
-     match = re.search(r"(-?\d*\.?\d+)", model_output)
-     if match:
-         return match.group(1)
-     return model_output.strip()
-
-
- @spaces.GPU(duration=120)  # Allow up to 2 minutes for full evaluation
- def evaluate_toy_dataset(model, tokenizer):
-     predictions = []
-     references = []
-     raw_outputs = []  # Store full model outputs for display
-
-     for sample in test_data:
-         question = sample["question"]
-         reference_answer = sample["answer"]
-
-         # Model inference
-         model_output = generate_answer(question, model, tokenizer)
-         predicted_answer = parse_answer(model_output)
-
-         predictions.append(predicted_answer)
-         references.append(reference_answer)
-         raw_outputs.append({
-             "question": question,
-             "model_output": model_output,
-             "parsed_answer": predicted_answer,
-             "reference": reference_answer
-         })
-
-     # Normalize answers
-     def normalize_answer(ans):
-         return str(ans).lower().strip()
-
-     norm_preds = [normalize_answer(p) for p in predictions]
-     norm_refs = [normalize_answer(r) for r in references]
-
-     # Compute accuracy
-     results = accuracy_metric.compute(predictions=norm_preds, references=norm_refs)
-     accuracy = results["accuracy"]
-
-     # Create visualization
-     fig, ax = plt.subplots(figsize=(8, 6))
-     correct_count = sum(p == r for p, r in zip(norm_preds, norm_refs))
-     incorrect_count = len(test_data) - correct_count
-
-     bars = ax.bar(["Correct", "Incorrect"],
-                   [correct_count, incorrect_count],
-                   color=["#2ecc71", "#e74c3c"])
-
-     # Add value labels on bars
-     for bar in bars:
-         height = bar.get_height()
-         ax.text(bar.get_x() + bar.get_width()/2., height,
-                 f'{int(height)}',
-                 ha='center', va='bottom')
-
-     ax.set_title("Evaluation Results")
-     ax.set_ylabel("Count")
-     ax.set_ylim([0, len(test_data) + 0.5])
-
-     # Convert plot to base64
-     buf = io.BytesIO()
-     plt.savefig(buf, format="png", bbox_inches='tight', dpi=300)
-     buf.seek(0)
-     plt.close(fig)
-     data = base64.b64encode(buf.read()).decode("utf-8")
-
-     # Create detailed results HTML
-     details_html = """
-     <div style="margin-top: 20px;">
-         <h3>Detailed Results:</h3>
-         <table style="width:100%; border-collapse: collapse;">
-             <tr style="background-color: #f5f5f5;">
-                 <th style="padding: 8px; border: 1px solid #ddd;">Question</th>
-                 <th style="padding: 8px; border: 1px solid #ddd;">Model Output</th>
-                 <th style="padding: 8px; border: 1px solid #ddd;">Parsed Answer</th>
-                 <th style="padding: 8px; border: 1px solid #ddd;">Reference</th>
-             </tr>
-     """
-
-     for result in raw_outputs:
-         details_html += f"""
-         <tr>
-             <td style="padding: 8px; border: 1px solid #ddd;">{result['question']}</td>
-             <td style="padding: 8px; border: 1px solid #ddd;">{result['model_output']}</td>
-             <td style="padding: 8px; border: 1px solid #ddd;">{result['parsed_answer']}</td>
-             <td style="padding: 8px; border: 1px solid #ddd;">{result['reference']}</td>
-         </tr>
-         """
-
-     details_html += "</table></div>"
-
-     full_html = f"""
-     <div>
-         <img src="data:image/png;base64,{data}" style="width:100%; max-width:600px;">
-         {details_html}
-     </div>
-     """
-
-     return f"Accuracy: {accuracy:.2f}", full_html
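
For context, a minimal sketch of how the deleted module was presumably driven before this commit. The checkpoint name and loading code below are assumptions, not part of this commit; the [INST] prompt template in generate_answer only suggests a Mistral instruct model.

# Hypothetical usage sketch -- checkpoint and loading details are assumptions.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from toy_dataset_eval import evaluate_toy_dataset  # the module removed by this commit

model_name = "mistralai/Mistral-7B-Instruct-v0.2"  # assumed; any Mistral instruct checkpoint matches the prompt format
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to("cuda")

# Returns an accuracy string and an HTML report (bar chart + per-question table).
accuracy_text, results_html = evaluate_toy_dataset(model, tokenizer)
print(accuracy_text)  # e.g. "Accuracy: 1.00"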