Spaces:
Sleeping
Sleeping
Update mmlu_pro_eval_adapted.py
Browse files- mmlu_pro_eval_adapted.py +3 -3
mmlu_pro_eval_adapted.py
CHANGED
@@ -14,12 +14,13 @@ import logging
|
|
14 |
import sys
|
15 |
from datasets import load_dataset
|
16 |
import pandas as pd
|
|
|
17 |
|
18 |
logging.basicConfig(level=logging.INFO)
|
19 |
logger = logging.getLogger(__name__)
|
20 |
|
21 |
# Can be found at https://github.com/TIGER-AI-Lab/MMLU-Pro/blob/main/cot_prompt_lib/initial_prompt.txt
|
22 |
-
initial_prompt = "The following are multiple choice questions (with answers) about {$}. Think step by step and then finish your answer with "the answer is (X)" where X is the correct letter choice."
|
23 |
|
24 |
choices = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P"]
|
25 |
max_model_length = 4096
|
@@ -127,7 +128,7 @@ def batch_inference(llm, sampling_params, inference_batch):
|
|
127 |
response_batch.append(generated_text)
|
128 |
pred = extract_answer(generated_text)
|
129 |
pred_batch.append(pred)
|
130 |
-
logging.info("PRED BATCH:
|
131 |
return pred_batch, response_batch
|
132 |
|
133 |
|
@@ -235,7 +236,6 @@ def evaluate_mmlu_pro(model_name, num_subjects=-1, num_questions=10, num_shots=5
|
|
235 |
'Accuracy': acc
|
236 |
})
|
237 |
|
238 |
-
import numpy as np # Added: missing import
|
239 |
weighted_acc = np.mean(all_correctness)
|
240 |
|
241 |
min_acc_subject = min(results.items(), key=lambda x: x[1])[0]
|
|
|
14 |
import sys
|
15 |
from datasets import load_dataset
|
16 |
import pandas as pd
|
17 |
+
import numpy as np
|
18 |
|
19 |
logging.basicConfig(level=logging.INFO)
|
20 |
logger = logging.getLogger(__name__)
|
21 |
|
22 |
# Can be found at https://github.com/TIGER-AI-Lab/MMLU-Pro/blob/main/cot_prompt_lib/initial_prompt.txt
|
23 |
+
initial_prompt = "The following are multiple choice questions (with answers) about {$}. Think step by step and then finish your answer with \"the answer is (X)\" where X is the correct letter choice."
|
24 |
|
25 |
choices = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P"]
|
26 |
max_model_length = 4096
|
|
|
128 |
response_batch.append(generated_text)
|
129 |
pred = extract_answer(generated_text)
|
130 |
pred_batch.append(pred)
|
131 |
+
logging.info("PRED BATCH: %s, RESPONSE BATCH: %s", pred_batch, response_batch)
|
132 |
return pred_batch, response_batch
|
133 |
|
134 |
|
|
|
236 |
'Accuracy': acc
|
237 |
})
|
238 |
|
|
|
239 |
weighted_acc = np.mean(all_correctness)
|
240 |
|
241 |
min_acc_subject = min(results.items(), key=lambda x: x[1])[0]
|