rohansampath committed
Commit 3195f7f · verified · 1 Parent(s): c5224d3

Update app.py with a basic demonstration of loading Llama-3.2-1B-Instruct and running a simple eval on some math questions

Files changed (1)
  app.py  +130 -3
app.py CHANGED
@@ -1,7 +1,134 @@
  import gradio as gr
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import evaluate
+ import re
+ import matplotlib
+ matplotlib.use('Agg')  # for non-interactive envs
+ import matplotlib.pyplot as plt
+ import io
+ import base64

- def greet(name):
-     return "Hello " + name + "!!"
+ # ---------------------------------------------------------------------------
+ # 1. Define the model name and load the model/tokenizer
+ # ---------------------------------------------------------------------------
+ model_name = "meta-llama/Llama-3.2-1B-Instruct"  # small instruct model; swap in any causal LM you have access to
+
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model = AutoModelForCausalLM.from_pretrained(model_name)
+
+ # ---------------------------------------------------------------------------
+ # 2. Define a tiny "dataset" for demonstration.
+ #    In reality, you'll load a real dataset from the Hub or custom code.
+ # ---------------------------------------------------------------------------
+ test_data = [
+     {"question": "What is 2+2?", "answer": "4"},
+     {"question": "What is 3*3?", "answer": "9"},
+     {"question": "What is 10/2?", "answer": "5"},
+ ]
+
+ # ---------------------------------------------------------------------------
+ # 3. Load a metric from the Hugging Face evaluate library ("exact_match" handles string answers)
+ # ---------------------------------------------------------------------------
+ metric = evaluate.load("exact_match")
+
+ # ---------------------------------------------------------------------------
+ # 4. Inference helper functions
+ # ---------------------------------------------------------------------------
+ def generate_answer(question):
+     """
+     Generates an answer to the given question using the loaded model.
+     """
+     # Simple prompt
+     prompt = f"Question: {question}\nAnswer:"
+
+     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+     with torch.no_grad():
+         outputs = model.generate(
+             **inputs,
+             max_new_tokens=30,
+             do_sample=False,  # greedy decoding (deterministic)
+         )
+     text_output = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)  # decode only the newly generated tokens
+     return text_output
+
+ def parse_answer(model_output):
+     """
+     Heuristic to extract the final numeric answer from the model's text.
+     You can customize this regex or logic as needed.
+     """
+     # Find the first run of digits in the generated text
+     match = re.search(r"(\d+)", model_output)
+     if match:
+         return match.group(1)
+     # fall back to the whole text if no digits are found
+     return model_output.strip()
+
+ # ---------------------------------------------------------------------------
+ # 5. Evaluation routine
+ # ---------------------------------------------------------------------------
+ def run_evaluation():
+     predictions = []
+     references = []
+
+     for sample in test_data:
+         question = sample["question"]
+         reference_answer = sample["answer"]
+
+         # Model inference
+         model_output = generate_answer(question)
+         predicted_answer = parse_answer(model_output)
+
+         predictions.append(predicted_answer)
+         references.append(reference_answer)
+
+     # Normalize answers (lower-case and strip surrounding whitespace)
+     def normalize_answer(ans):
+         return ans.lower().strip()
+
+     norm_preds = [normalize_answer(p) for p in predictions]
+     norm_refs = [normalize_answer(r) for r in references]
+
+     # Compute exact-match accuracy
+     results = metric.compute(predictions=norm_preds, references=norm_refs)
+     accuracy = results["exact_match"]
+
+     # Create a simple bar chart: correct vs. incorrect
+     correct_count = sum(p == r for p, r in zip(norm_preds, norm_refs))
+     incorrect_count = len(test_data) - correct_count
+
+     fig, ax = plt.subplots()
+     ax.bar(["Correct", "Incorrect"], [correct_count, incorrect_count], color=["green", "red"])
+     ax.set_title("Evaluation Results")
+     ax.set_ylabel("Count")
+     ax.set_ylim([0, len(test_data)])
+
+     # Convert the plot to a base64-encoded PNG for Gradio display
+     buf = io.BytesIO()
+     fig.savefig(buf, format="png")
+     buf.seek(0)
+     plt.close(fig)
+     data = base64.b64encode(buf.read()).decode("utf-8")
+     image_url = f"data:image/png;base64,{data}"
+
+     # Return the text summary and the plot wrapped in an <img> tag for the HTML component
+     return f"Accuracy: {accuracy:.2f}", f'<img src="{image_url}" alt="evaluation results"/>'
+
+ # ---------------------------------------------------------------------------
+ # 6. Gradio app
+ # ---------------------------------------------------------------------------
+ with gr.Blocks() as demo:
+     gr.Markdown("# Simple Math Evaluation with 'Llama 3.2'")
+
+     eval_button = gr.Button("Run Evaluation")
+     output_text = gr.Textbox(label="Results")
+     output_plot = gr.HTML(label="Plot")
+
+     eval_button.click(
+         fn=run_evaluation,
+         inputs=None,
+         outputs=[output_text, output_plot]
+     )

- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
  demo.launch()
+
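
The comment in section 2 of the new app.py notes that the three hard-coded questions stand in for a real dataset. A minimal sketch of that swap, assuming the GSM8K dataset on the Hub (openai/gsm8k, "main" config), whose "answer" field ends with a "#### <final answer>" marker; the helper name load_math_test_data and the sample count are illustrative, not part of the app:

    from datasets import load_dataset

    def load_math_test_data(n_samples=25):
        # GSM8K rows carry a "question" plus an "answer" whose last line is "#### <number>"
        ds = load_dataset("openai/gsm8k", "main", split="test")
        samples = []
        for row in ds.select(range(n_samples)):
            final_answer = row["answer"].split("####")[-1].strip()
            samples.append({"question": row["question"], "answer": final_answer})
        return samples

    # test_data = load_math_test_data()  # drop-in replacement for the toy list in app.py

Similarly, the app prompts an instruct-tuned checkpoint with a bare "Question:/Answer:" completion prompt. A variant of generate_answer could route the question through the tokenizer's chat template instead; this sketch assumes the loaded tokenizer ships a chat template (the meta-llama instruct tokenizers do), and generate_answer_chat is a hypothetical alternative, not code from this commit:

    def generate_answer_chat(question):
        # Build a chat-formatted prompt rather than a raw completion prompt
        messages = [{"role": "user", "content": f"Answer with just the number. {question}"}]
        input_ids = tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt"
        ).to(model.device)
        with torch.no_grad():
            outputs = model.generate(input_ids, max_new_tokens=30, do_sample=False)
        # Decode only the tokens generated after the prompt
        return tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)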