Update app.py

app.py CHANGED
@@ -23,14 +23,13 @@ def load_model(model_id):
     return generator
 
 def format_prompt(item):
-    #
-    system_instruction = "Respond ONLY with a single capital letter: A, B, C, or D. No other text."
+    # Simplified prompt: rely on max_new_tokens=1 and model's understanding for single-letter answer
     prompt = f"""{item['question']}
 A. {item['choices'][0]}
 B. {item['choices'][1]}
 C. {item['choices'][2]}
 D. {item['choices'][3]}
-Answer:"""
+Answer:""" # Removed direct instruction from here
     return prompt, item['answer']
 
 def extract_choice_letter(output):
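Note: with this change, format_prompt produces a bare question-plus-choices prompt ending in "Answer:". For illustration, a minimal sketch with an invented item (real items come from cais/mmlu):

item = {
    "question": "What is the capital of France?",
    "choices": ["Berlin", "Madrid", "Paris", "Rome"],
    "answer": 2,
}
prompt, answer_idx = format_prompt(item)
print(prompt)
# What is the capital of France?
# A. Berlin
# B. Madrid
# C. Paris
# D. Rome
# Answer:
print(answer_idx)  # 2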
@@ -38,6 +37,10 @@ def extract_choice_letter(output):
     match = re.search(r"\b([ABCD])\b", output.strip())
     return match.group(1) if match else None
 
+def get_choice_letter(index):
+    """Converts a numerical choice index (0-3) to a capital letter (A-D)."""
+    return chr(ord('A') + index)
+
 def evaluate(model_id, sample_count, config_name, progress=gr.Progress()):
     if config_name == "ALL":
         # Dynamically get all MMLU subjects
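The two helpers are easy to sanity-check; these asserts just restate the definitions above (extract_choice_letter relies on the re module, which app.py already imports given the re.search call):

assert get_choice_letter(0) == "A" and get_choice_letter(3) == "D"
assert extract_choice_letter(" B") == "B"                # a single generated token
assert extract_choice_letter("The answer is C.") == "C"  # letter embedded in text
assert extract_choice_letter("no letter here") is None   # no match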
@@ -50,15 +53,19 @@ def evaluate(model_id, sample_count, config_name, progress=gr.Progress()):
         for i, subject in enumerate(progress.tqdm(subjects, desc="Evaluating subjects")):
             dataset = load_dataset("cais/mmlu", subject, token=HF_TOKEN)["test"]
             dataset = dataset.shuffle(seed=42).select(range(min(sample_count, len(dataset))))
-
+            correct_subject = 0
             for j, item in enumerate(progress.tqdm(dataset, desc=f"Processing {subject} samples")):
-                prompt,
+                prompt, answer_idx = format_prompt(item) # answer_idx is 0, 1, 2, or 3
+                expected_letter = get_choice_letter(answer_idx) # Convert to 'A', 'B', 'C', 'D'
+
                 # Crucial change: Limit generation to 1 new token
                 output = gen(prompt, max_new_tokens=1, do_sample=False)[0]["generated_text"]
-                output_letter = extract_choice_letter(output)
-
-
-
+                output_letter = extract_choice_letter(output) # Extract the letter from model's output
+
+                is_correct = output_letter == expected_letter
+                correct_subject += is_correct
+                all_results.append((prompt, output.strip(), expected_letter, output_letter, is_correct)) # Store expected_letter
+            total_correct += correct_subject
             total_samples += len(dataset)
         avg_accuracy = total_correct / total_samples * 100
         return avg_accuracy, all_results
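One caveat on this hunk: if gen is a standard transformers text-generation pipeline, generated_text includes the prompt itself by default (return_full_text=True), and the prompt already contains standalone letters in its "A. ..." choice lines, so extract_choice_letter(output) can match a letter from the prompt rather than the single generated token. A hedged sketch of two possible fixes, untested against this Space:

# Option 1: ask the pipeline for only the completion
output = gen(prompt, max_new_tokens=1, do_sample=False,
             return_full_text=False)[0]["generated_text"]

# Option 2: strip the prompt before extracting the letter
completion = output[len(prompt):] if output.startswith(prompt) else output
output_letter = extract_choice_letter(completion)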
@@ -71,13 +78,16 @@ def evaluate(model_id, sample_count, config_name, progress=gr.Progress()):
     results = []
 
     for i, item in enumerate(progress.tqdm(dataset, desc=f"Processing {config_name} samples")):
-        prompt,
+        prompt, answer_idx = format_prompt(item) # answer_idx is 0, 1, 2, or 3
+        expected_letter = get_choice_letter(answer_idx) # Convert to 'A', 'B', 'C', 'D'
+
        # Crucial change: Limit generation to 1 new token
         output = gen(prompt, max_new_tokens=1, do_sample=False)[0]["generated_text"]
-        output_letter = extract_choice_letter(output)
-
+        output_letter = extract_choice_letter(output) # Extract the letter from model's output
+
+        is_correct = output_letter == expected_letter
         correct += is_correct
-        results.append((prompt, output.strip(),
+        results.append((prompt, output.strip(), expected_letter, output_letter, is_correct)) # Store expected_letter
 
     accuracy = correct / len(dataset) * 100
     return accuracy, results
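For a quick smoke test outside the Gradio UI, something like the following should work (a sketch: the model id, subject, and sample count are illustrative, and evaluate is assumed to obtain gen via load_model in code not shown in this diff):

acc, results = evaluate("gpt2", sample_count=5, config_name="abstract_algebra")
print(f"accuracy: {acc:.1f}%")
for prompt, raw_output, expected, predicted, ok in results[:3]:
    print(f"expected {expected}, got {predicted}, correct={ok}")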