import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import evaluate
import re
import matplotlib
matplotlib.use('Agg') # for non-interactive envs
import matplotlib.pyplot as plt
import io
import base64
import os
from huggingface_hub import login
# Read token and login
hf_token = os.getenv("HF_TOKEN_READ_WRITE")
if hf_token:
    login(hf_token)
else:
    print("⚠️ No HF_TOKEN_READ_WRITE found in environment")
# Check GPU availability
if torch.cuda.is_available():
    print("✅ GPU is available")
    print("GPU Name:", torch.cuda.get_device_name(0))
else:
    print("❌ No GPU available")
# ---------------------------------------------------------------------------
# 1. Define model name and load model/tokenizer
# ---------------------------------------------------------------------------
model_name = "mistralai/Mistral-7B-Instruct-v0.3"
tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    token=hf_token,
    torch_dtype=torch.float16,
    device_map="auto"
)
print(f"✅ Model loaded on {device}")
# ---------------------------------------------------------------------------
# 2. Test dataset
# ---------------------------------------------------------------------------
test_data = [
    {"question": "What is 2+2?", "answer": "4"},
    {"question": "What is 3*3?", "answer": "9"},
    {"question": "What is 10/2?", "answer": "5"},
]
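# The three items above are only a smoke test. A sketch for swapping in a real
# benchmark instead (assumption: the `datasets` library is installed and GSM8K's
# "#### <number>" answer format is parsed separately; not done in this Space):
#   from datasets import load_dataset
#   gsm8k = load_dataset("gsm8k", "main", split="test")
#   test_data = [{"question": ex["question"], "answer": ex["answer"]} for ex in gsm8k]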
# ---------------------------------------------------------------------------
# 3. Load metric
# ---------------------------------------------------------------------------
# Exact string match between predicted and reference answers. (The `accuracy`
# metric expects integer class labels, so `exact_match` is the right fit for the
# string answers used here.)
exact_match_metric = evaluate.load("exact_match")
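# For example, exact_match_metric.compute(predictions=["4", "9"], references=["4", "8"])
# returns {"exact_match": 0.5}.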
# ---------------------------------------------------------------------------
# 4. Inference helper functions
# ---------------------------------------------------------------------------
def generate_answer(question):
    """
    Generates an answer using Mistral's instruction format.
    """
    # Mistral instruction format
    prompt = f"<s>[INST] {question} [/INST]"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=50,
            do_sample=False,  # greedy decoding, i.e. deterministic
            pad_token_id=tokenizer.eos_token_id,  # Mistral has no dedicated pad token
            eos_token_id=tokenizer.eos_token_id
        )
    # Decode only the newly generated tokens so the prompt is not echoed back
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
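# Illustrative only (the exact wording depends on the model and transformers version):
#   generate_answer("What is 2+2?")  ->  "2 + 2 = 4"
# Note: tokenizer.apply_chat_template([{"role": "user", "content": question}],
# tokenize=False, add_generation_prompt=True) would build this prompt from the
# model's own chat template instead of hand-writing the [INST] tags.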
def parse_answer(model_output):
    """
    Extract numeric answer from model's text output.
    """
    # Look for numbers (including decimals)
    match = re.search(r"(-?\d*\.?\d+)", model_output)
    if match:
        return match.group(1)
    return model_output.strip()
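# Examples of what the regex extracts:
#   parse_answer("The answer is 4.")  ->  "4"
#   parse_answer("10 / 2 = 5")        ->  "10"  (only the first number is kept, so it
#                                                matters that generate_answer returns
#                                                just the newly generated text)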
# ---------------------------------------------------------------------------
# 5. Evaluation routine
# ---------------------------------------------------------------------------
def run_evaluation():
    predictions = []
    references = []
    raw_outputs = []  # Store full model outputs for display

    for sample in test_data:
        question = sample["question"]
        reference_answer = sample["answer"]

        # Model inference
        model_output = generate_answer(question)
        predicted_answer = parse_answer(model_output)

        predictions.append(predicted_answer)
        references.append(reference_answer)
        raw_outputs.append({
            "question": question,
            "model_output": model_output,
            "parsed_answer": predicted_answer,
            "reference": reference_answer
        })

    # Normalize answers
    def normalize_answer(ans):
        return str(ans).lower().strip()

    norm_preds = [normalize_answer(p) for p in predictions]
    norm_refs = [normalize_answer(r) for r in references]
    # Compute exact-match accuracy over the normalized answer strings
    results = exact_match_metric.compute(predictions=norm_preds, references=norm_refs)
    accuracy = results["exact_match"]
    # Create visualization
    correct_count = sum(p == r for p, r in zip(norm_preds, norm_refs))
    incorrect_count = len(test_data) - correct_count

    fig, ax = plt.subplots(figsize=(8, 6))
    bars = ax.bar(["Correct", "Incorrect"],
                  [correct_count, incorrect_count],
                  color=["#2ecc71", "#e74c3c"])

    # Add value labels on bars
    for bar in bars:
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
                f'{int(height)}',
                ha='center', va='bottom')

    ax.set_title("Evaluation Results")
    ax.set_ylabel("Count")
    ax.set_ylim([0, len(test_data) + 0.5])  # Add some padding at top

    # Convert plot to base64
    buf = io.BytesIO()
    plt.savefig(buf, format="png", bbox_inches='tight', dpi=300)
    buf.seek(0)
    plt.close(fig)
    data = base64.b64encode(buf.read()).decode("utf-8")
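    # (Alternative: a gr.Plot output could take the Matplotlib figure directly; the
    # base64 route is used here so the chart and the HTML table below can be rendered
    # together in a single gr.HTML component.)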
    # Create detailed results HTML
    details_html = """
    <div style="margin-top: 20px;">
        <h3>Detailed Results:</h3>
        <table style="width:100%; border-collapse: collapse;">
            <tr style="background-color: #f5f5f5;">
                <th style="padding: 8px; border: 1px solid #ddd;">Question</th>
                <th style="padding: 8px; border: 1px solid #ddd;">Model Output</th>
                <th style="padding: 8px; border: 1px solid #ddd;">Parsed Answer</th>
                <th style="padding: 8px; border: 1px solid #ddd;">Reference</th>
            </tr>
    """

    for result in raw_outputs:
        details_html += f"""
            <tr>
                <td style="padding: 8px; border: 1px solid #ddd;">{result['question']}</td>
                <td style="padding: 8px; border: 1px solid #ddd;">{result['model_output']}</td>
                <td style="padding: 8px; border: 1px solid #ddd;">{result['parsed_answer']}</td>
                <td style="padding: 8px; border: 1px solid #ddd;">{result['reference']}</td>
            </tr>
        """

    details_html += "</table></div>"
    # Combine plot and details
    full_html = f"""
    <div>
        <img src="data:image/png;base64,{data}" style="width:100%; max-width:600px;">
        {details_html}
    </div>
    """

    return f"Accuracy: {accuracy:.2f}", full_html
# ---------------------------------------------------------------------------
# 6. Gradio Interface
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Mistral-7B Math Evaluation Demo")
    gr.Markdown("""
This demo evaluates Mistral-7B on basic math problems.
Press the button below to run the evaluation.
""")

    eval_button = gr.Button("Run Evaluation", variant="primary")
    output_text = gr.Textbox(label="Results")
    output_plot = gr.HTML(label="Visualization and Details")

    eval_button.click(
        fn=run_evaluation,
        inputs=None,
        outputs=[output_text, output_plot]
    )
demo.launch()