Enderchef committed · verified
Commit 0a040f1 · 1 Parent(s): 903eadb

Update app.py

Files changed (1): app.py +23 -13
app.py CHANGED
@@ -23,14 +23,13 @@ def load_model(model_id):
     return generator
 
 def format_prompt(item):
-    # Emphasize the single letter answer instruction to encourage concise output
-    system_instruction = "Respond ONLY with a single capital letter: A, B, C, or D. No other text."
+    # Simplified prompt: rely on max_new_tokens=1 and model's understanding for single-letter answer
     prompt = f"""{item['question']}
 A. {item['choices'][0]}
 B. {item['choices'][1]}
 C. {item['choices'][2]}
 D. {item['choices'][3]}
-Answer: {system_instruction}""" # Place instruction after 'Answer:' with a space
+Answer:""" # Removed direct instruction from here
     return prompt, item['answer']
 
 def extract_choice_letter(output):
@@ -38,6 +37,10 @@ def extract_choice_letter(output):
     match = re.search(r"\b([ABCD])\b", output.strip())
     return match.group(1) if match else None
 
+def get_choice_letter(index):
+    """Converts a numerical choice index (0-3) to a capital letter (A-D)."""
+    return chr(ord('A') + index)
+
 def evaluate(model_id, sample_count, config_name, progress=gr.Progress()):
     if config_name == "ALL":
         # Dynamically get all MMLU subjects
@@ -50,15 +53,19 @@ def evaluate(model_id, sample_count, config_name, progress=gr.Progress()):
         for i, subject in enumerate(progress.tqdm(subjects, desc="Evaluating subjects")):
             dataset = load_dataset("cais/mmlu", subject, token=HF_TOKEN)["test"]
             dataset = dataset.shuffle(seed=42).select(range(min(sample_count, len(dataset))))
-            correct = 0
+            correct_subject = 0
             for j, item in enumerate(progress.tqdm(dataset, desc=f"Processing {subject} samples")):
-                prompt, answer = format_prompt(item)
+                prompt, answer_idx = format_prompt(item) # answer_idx is 0, 1, 2, or 3
+                expected_letter = get_choice_letter(answer_idx) # Convert to 'A', 'B', 'C', 'D'
+
                 # Crucial change: Limit generation to 1 new token
                 output = gen(prompt, max_new_tokens=1, do_sample=False)[0]["generated_text"]
-                output_letter = extract_choice_letter(output)
-                correct += output_letter == answer
-                all_results.append((prompt, output.strip(), answer, output_letter, output_letter == answer))
-            total_correct += correct
+                output_letter = extract_choice_letter(output) # Extract the letter from model's output
+
+                is_correct = output_letter == expected_letter
+                correct_subject += is_correct
+                all_results.append((prompt, output.strip(), expected_letter, output_letter, is_correct)) # Store expected_letter
+            total_correct += correct_subject
             total_samples += len(dataset)
         avg_accuracy = total_correct / total_samples * 100
         return avg_accuracy, all_results
@@ -71,13 +78,16 @@ def evaluate(model_id, sample_count, config_name, progress=gr.Progress()):
         results = []
 
         for i, item in enumerate(progress.tqdm(dataset, desc=f"Processing {config_name} samples")):
-            prompt, answer = format_prompt(item)
+            prompt, answer_idx = format_prompt(item) # answer_idx is 0, 1, 2, or 3
+            expected_letter = get_choice_letter(answer_idx) # Convert to 'A', 'B', 'C', 'D'
+
             # Crucial change: Limit generation to 1 new token
             output = gen(prompt, max_new_tokens=1, do_sample=False)[0]["generated_text"]
-            output_letter = extract_choice_letter(output)
-            is_correct = output_letter == answer
+            output_letter = extract_choice_letter(output) # Extract the letter from model's output
+
+            is_correct = output_letter == expected_letter
             correct += is_correct
-            results.append((prompt, output.strip(), answer, output_letter, is_correct))
+            results.append((prompt, output.strip(), expected_letter, output_letter, is_correct)) # Store expected_letter
 
         accuracy = correct / len(dataset) * 100
         return accuracy, results
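For context, here is a minimal, self-contained sketch of how the new get_choice_letter helper pairs with the existing extract_choice_letter to score one item. The two functions are copied from the diff; the sample item is hypothetical. MMLU stores the gold answer as an integer index into choices, which is why the index-to-letter conversion is needed before comparing against the model's output:

import re

def get_choice_letter(index):
    """Converts a numerical choice index (0-3) to a capital letter (A-D)."""
    return chr(ord('A') + index)

def extract_choice_letter(output):
    # First standalone A/B/C/D found in the model's output, else None
    match = re.search(r"\b([ABCD])\b", output.strip())
    return match.group(1) if match else None

# Hypothetical MMLU-style item: 'answer' is an int index into 'choices'
item = {"question": "What is 2 + 2?", "choices": ["3", "4", "5", "6"], "answer": 1}

expected_letter = get_choice_letter(item["answer"])  # 'B'
output_letter = extract_choice_letter(" B")          # model emitted the single token ' B'
print(expected_letter, output_letter, output_letter == expected_letter)  # B B True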
 
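One caveat on the generation step, sketched below under assumptions the commit does not spell out: with do_sample=False the pipeline decodes greedily, and max_new_tokens=1 caps output at one token, which with this prompt format should be the answer letter. But by default a transformers text-generation pipeline returns the prompt plus the new token in generated_text, and extract_choice_letter's regex would then match the standalone "A" of the "A." choices line before it ever reaches the model's answer. Passing return_full_text=False (or slicing the prompt off) sidesteps that; gpt2 is used purely as a placeholder model here:

from transformers import pipeline

# Placeholder model for illustration; the app loads the user-supplied model_id
gen = pipeline("text-generation", model="gpt2")

prompt = """What is 2 + 2?
A. 3
B. 4
C. 5
D. 6
Answer:"""

# return_full_text=False returns only the newly generated token, so a regex
# over the output cannot accidentally match the "A."-"D." lines of the prompt.
out = gen(prompt, max_new_tokens=1, do_sample=False, return_full_text=False)
print(out[0]["generated_text"])  # e.g. ' B' (model-dependent)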