rohansampath commited on
Commit
5d2ee20
·
verified ·
1 Parent(s): 3d567ab

Update mmlu_eval_original.py

Browse files
Files changed (1) hide show
  1. mmlu_eval_original.py +11 -6
mmlu_eval_original.py CHANGED
@@ -78,6 +78,8 @@ def gen_prompt(df, subject, k=-1):
78
  def eval (subject, model, tokenizer, dev_df, test_df, num_questions_per_subject=5, train_shots=5):
79
  assert all(dev_df['subject'] == subject), f"Not all items in dev_df match subject {subject}"
80
  assert all(test_df['subject'] == subject), f"Not all items in test_df match subject {subject}"
 
 
81
 
82
  cors = []
83
  all_probs = []
@@ -101,7 +103,10 @@ def eval (subject, model, tokenizer, dev_df, test_df, num_questions_per_subject=
101
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(
102
  model.device
103
  )
104
- logger.info (f"Prompt: {prompt}")
 
 
 
105
 
106
  label = test_df.iloc[i, test_df.shape[1] - 1]
107
 
@@ -128,11 +133,11 @@ def eval (subject, model, tokenizer, dev_df, test_df, num_questions_per_subject=
128
 
129
  cor = pred == label
130
 
131
- logger.info(f"Label: {label}")
132
- logger.info(f"Logits: {logits}")
133
- logger.info(f"Probabilities: {probs}")
134
- logger.info(f"Prediction: {pred}")
135
- logger.info(f"Correct: {cor}")
136
 
137
  cors.append(cor)
138
  all_probs.append(probs)
 
78
  def eval (subject, model, tokenizer, dev_df, test_df, num_questions_per_subject=5, train_shots=5):
79
  assert all(dev_df['subject'] == subject), f"Not all items in dev_df match subject {subject}"
80
  assert all(test_df['subject'] == subject), f"Not all items in test_df match subject {subject}"
81
+
82
+ logger.info(f"Subject: {subject}")
83
 
84
  cors = []
85
  all_probs = []
 
103
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(
104
  model.device
105
  )
106
+
107
+ logger.info (f"Sample: {i}")
108
+
109
+ #logger.info (f"Prompt: {prompt}")
110
 
111
  label = test_df.iloc[i, test_df.shape[1] - 1]
112
 
 
133
 
134
  cor = pred == label
135
 
136
+ #logger.info(f"Label: {label}")
137
+ #logger.info(f"Logits: {logits}")
138
+ #logger.info(f"Probabilities: {probs}")
139
+ #logger.info(f"Prediction: {pred}")
140
+ #logger.info(f"Correct: {cor}")
141
 
142
  cors.append(cor)
143
  all_probs.append(probs)