rohansampath committed
Commit e8d7a5b · verified · 1 Parent(s): 76926dc

Made some changes

Files changed (1)
  1. app.py +99 -40
app.py CHANGED
@@ -10,11 +10,15 @@ import io
 import base64
 import os
 from huggingface_hub import login
-from transformers import AutoTokenizer, AutoModel

-hf_token = os.getenv("HF_TOKEN_READ_WRITE") # Read the token from Secrets
-login(hf_token)
+# Read token and login
+hf_token = os.getenv("HF_TOKEN_READ_WRITE")
+if hf_token:
+    login(hf_token)
+else:
+    print("⚠️ No HF_TOKEN_READ_WRITE found in environment")

+# Check GPU availability
 if torch.cuda.is_available():
     print("✅ GPU is available")
     print("GPU Name:", torch.cuda.get_device_name(0))
@@ -24,18 +28,21 @@ else:
 # ---------------------------------------------------------------------------
 # 1. Define model name and load model/tokenizer
 # ---------------------------------------------------------------------------
-model_name = "mistralai/Mistral-7B-Instruct-v0.3" # fictional placeholder
+model_name = "mistralai/Mistral-7B-Instruct-v0.3"

-tokenizer = AutoTokenizer.from_pretrained(model_name)
+tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
 device = "cuda" if torch.cuda.is_available() else "cpu"
-model = AutoModelForCausalLM.from_pretrained(model_name, token=hf_token, torch_dtype=torch.float16, device_map="auto")
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    token=hf_token,
+    torch_dtype=torch.float16,
+    device_map="auto"
+)

 print(f"✅ Model loaded on {device}")
-#model = AutoModelForCausalLM.from_pretrained(model_name)

 # ---------------------------------------------------------------------------
-# 2. Define a tiny "dataset" for demonstration
-#    In reality, you'll load a real dataset from HF or custom code.
+# 2. Test dataset
 # ---------------------------------------------------------------------------
 test_data = [
     {"question": "What is 2+2?", "answer": "4"},
@@ -44,7 +51,7 @@ test_data = [
 ]

 # ---------------------------------------------------------------------------
-# 3. Load a metric (accuracy) from Hugging Face evaluate library
+# 3. Load metric
 # ---------------------------------------------------------------------------
 accuracy_metric = evaluate.load("accuracy")

@@ -53,31 +60,32 @@ accuracy_metric = evaluate.load("accuracy")
 # ---------------------------------------------------------------------------
 def generate_answer(question):
     """
-    Generates an answer to the given question using the loaded model.
+    Generates an answer using Mistral's instruction format.
     """
-    # Simple prompt
-    prompt = f"Question: {question}\nAnswer:"
+    # Mistral instruction format
+    prompt = f"""<s>[INST] {question} [/INST]"""

     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
     with torch.no_grad():
         outputs = model.generate(
             **inputs,
-            max_new_tokens=30,
+            max_new_tokens=50,
             temperature=0.0, # deterministic
+            pad_token_id=tokenizer.pad_token_id,
+            eos_token_id=tokenizer.eos_token_id
         )
     text_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    return text_output
+    # Remove the original question from the output
+    return text_output.replace(question, "").strip()

 def parse_answer(model_output):
     """
-    Heuristic to extract the final numeric answer from model's text.
-    You can customize this regex or logic as needed.
+    Extract numeric answer from model's text output.
     """
-    # Example: find digits (possibly multiple, but we keep the first match)
-    match = re.search(r"(\d+)", model_output)
+    # Look for numbers (including decimals)
+    match = re.search(r"(-?\d*\.?\d+)", model_output)
     if match:
         return match.group(1)
-    # fallback to entire text if no digits found
     return model_output.strip()

 # ---------------------------------------------------------------------------
@@ -86,6 +94,7 @@ def parse_answer(model_output):
 def run_evaluation():
     predictions = []
     references = []
+    raw_outputs = [] # Store full model outputs for display

     for sample in test_data:
         question = sample["question"]
@@ -97,54 +106,104 @@ def run_evaluation():

         predictions.append(predicted_answer)
         references.append(reference_answer)
+        raw_outputs.append({
+            "question": question,
+            "model_output": model_output,
+            "parsed_answer": predicted_answer,
+            "reference": reference_answer
+        })

-    # Normalize answers (simple: just remove spaces/punctuation, lower case)
+    # Normalize answers
     def normalize_answer(ans):
-        return ans.lower().strip()
+        return str(ans).lower().strip()

     norm_preds = [normalize_answer(p) for p in predictions]
-    norm_refs = [normalize_answer(r) for r in references]
+    norm_refs = [normalize_answer(r) for r in references]

     # Compute accuracy
     results = accuracy_metric.compute(predictions=norm_preds, references=norm_refs)
     accuracy = results["accuracy"]

-    # Create a simple bar chart: correct vs. incorrect
+    # Create visualization
     correct_count = sum(p == r for p, r in zip(norm_preds, norm_refs))
     incorrect_count = len(test_data) - correct_count

-    fig, ax = plt.subplots()
-    ax.bar(["Correct", "Incorrect"], [correct_count, incorrect_count], color=["green", "red"])
+    fig, ax = plt.subplots(figsize=(8, 6))
+    bars = ax.bar(["Correct", "Incorrect"],
+                  [correct_count, incorrect_count],
+                  color=["#2ecc71", "#e74c3c"])
+
+    # Add value labels on bars
+    for bar in bars:
+        height = bar.get_height()
+        ax.text(bar.get_x() + bar.get_width()/2., height,
+                f'{int(height)}',
+                ha='center', va='bottom')
+
     ax.set_title("Evaluation Results")
     ax.set_ylabel("Count")
-    ax.set_ylim([0, len(test_data)])
+    ax.set_ylim([0, len(test_data) + 0.5]) # Add some padding at top

-    # Convert the plot to a base64-encoded PNG for Gradio display
+    # Convert plot to base64
     buf = io.BytesIO()
-    plt.savefig(buf, format="png")
+    plt.savefig(buf, format="png", bbox_inches='tight', dpi=300)
     buf.seek(0)
     plt.close(fig)
     data = base64.b64encode(buf.read()).decode("utf-8")
-    image_url = f"data:image/png;base64,{data}"

-    # Return text and the plot
-    return f"Accuracy: {accuracy:.2f}", image_url
+    # Create detailed results HTML
+    details_html = """
+    <div style="margin-top: 20px;">
+        <h3>Detailed Results:</h3>
+        <table style="width:100%; border-collapse: collapse;">
+            <tr style="background-color: #f5f5f5;">
+                <th style="padding: 8px; border: 1px solid #ddd;">Question</th>
+                <th style="padding: 8px; border: 1px solid #ddd;">Model Output</th>
+                <th style="padding: 8px; border: 1px solid #ddd;">Parsed Answer</th>
+                <th style="padding: 8px; border: 1px solid #ddd;">Reference</th>
+            </tr>
+    """
+
+    for result in raw_outputs:
+        details_html += f"""
+            <tr>
+                <td style="padding: 8px; border: 1px solid #ddd;">{result['question']}</td>
+                <td style="padding: 8px; border: 1px solid #ddd;">{result['model_output']}</td>
+                <td style="padding: 8px; border: 1px solid #ddd;">{result['parsed_answer']}</td>
+                <td style="padding: 8px; border: 1px solid #ddd;">{result['reference']}</td>
+            </tr>
+        """
+
+    details_html += "</table></div>"
+
+    # Combine plot and details
+    full_html = f"""
+    <div>
+        <img src="data:image/png;base64,{data}" style="width:100%; max-width:600px;">
+        {details_html}
+    </div>
+    """
+
+    return f"Accuracy: {accuracy:.2f}", full_html

 # ---------------------------------------------------------------------------
-# 6. Gradio App
+# 6. Gradio Interface
 # ---------------------------------------------------------------------------
 with gr.Blocks() as demo:
-    gr.Markdown("# Simple Math Evaluation with 'Llama 3.2'")
-
-    eval_button = gr.Button("Run Evaluation")
+    gr.Markdown("# Mistral-7B Math Evaluation Demo")
+    gr.Markdown("""
+    This demo evaluates Mistral-7B on basic math problems.
+    Press the button below to run the evaluation.
+    """)
+
+    eval_button = gr.Button("Run Evaluation", variant="primary")
     output_text = gr.Textbox(label="Results")
-    output_plot = gr.HTML(label="Plot")
-
+    output_plot = gr.HTML(label="Visualization and Details")
+
     eval_button.click(
         fn=run_evaluation,
        inputs=None,
        outputs=[output_text, output_plot]
     )

-demo.launch()
-
+demo.launch()
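
Side note on the prompt format (not part of the commit): the hardcoded "<s>[INST] {question} [/INST]" string works, but tokenizer() already inserts the BOS token, so spelling out "<s>" in the prompt can yield a doubled BOS. A minimal sketch of the chat-template alternative, assuming the stock Mistral-7B-Instruct-v0.3 tokenizer:

# Sketch only, not code from this commit.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
messages = [{"role": "user", "content": "What is 2+2?"}]

# apply_chat_template builds the [INST] ... [/INST] wrapper defined by the model's tokenizer.
prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)  # exact string depends on the tokenizer's bundled template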
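On the metric: evaluate.load("accuracy") is defined over class labels, so answers that are not integer-like may fail to cast, while run_evaluation() already computes an exact-match count via correct_count. A standalone sketch of that exact-match idea (illustration only; the helper name is hypothetical):

# Sketch only: exact-match accuracy over normalized strings.
def exact_match_accuracy(predictions, references):
    normalize = lambda s: str(s).lower().strip()
    matches = sum(normalize(p) == normalize(r) for p, r in zip(predictions, references))
    return matches / len(references) if references else 0.0

print(exact_match_accuracy(["4", " Paris"], ["4", "paris"]))  # 1.0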
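And a compact sketch of the figure-to-HTML step that run_evaluation() performs before handing the result to the gr.HTML output (the fig_to_img_tag helper is hypothetical, not from the commit):

# Sketch only: encode a matplotlib figure as a base64 <img> tag for a gr.HTML component.
import base64
import io

import matplotlib.pyplot as plt

def fig_to_img_tag(fig):
    buf = io.BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight")
    buf.seek(0)
    data = base64.b64encode(buf.read()).decode("utf-8")
    return f'<img src="data:image/png;base64,{data}" style="max-width:600px;">'

fig, ax = plt.subplots()
ax.bar(["Correct", "Incorrect"], [2, 1], color=["#2ecc71", "#e74c3c"])
img_tag = fig_to_img_tag(fig)
plt.close(fig)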