Varun Wadhwa committed on
Commit
6f56edb
·
unverified ·
1 Parent(s): e3f5712
Files changed (1) hide show
  1. app.py +7 -9
app.py CHANGED
@@ -133,23 +133,21 @@ def evaluate_model(model, dataloader, device):
133
  logits = outputs.logits
134
 
135
  # Get predictions
136
- preds = torch.argmax(logits, dim=-1)
137
-
138
- for p, l in zip(preds.cpu().numpy(), labels.cpu().numpy()):
139
- # Filter out `-100` labels and align predictions with valid tokens
140
- valid_indices = l != -100
141
- all_preds.extend(p[valid_indices])
142
- all_labels.extend(l[valid_indices])
143
 
144
  # Calculate evaluation metrics
145
  print("evaluate_model sizes")
146
  print(len(all_preds[0]))
147
  print(len(all_labels[0]))
148
- all_preds = np.array(all_preds)
149
- all_labels = np.array(all_labels)
150
  print("Flattened sizes")
151
  print(all_preds.size)
152
  print(all_labels.size)
 
 
153
  accuracy = accuracy_score(all_labels, all_preds)
154
  precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='micro')
155
 
 
133
  logits = outputs.logits
134
 
135
  # Get predictions
136
+ preds = torch.argmax(logits, dim=-1).cpu().numpy()
137
+ all_preds.extend(preds)
138
+ all_labels.extend(labels.cpu().numpy())
 
 
 
 
139
 
140
  # Calculate evaluation metrics
141
  print("evaluate_model sizes")
142
  print(len(all_preds[0]))
143
  print(len(all_labels[0]))
144
+ all_preds = np.asarray(all_preds, dtype=np.float32)
145
+ all_labels = np.asarray(all_labels, dtype=np.float32)
146
  print("Flattened sizes")
147
  print(all_preds.size)
148
  print(all_labels.size)
149
+ all_preds = all_preds.flatten()
150
+ all_labels = all_labels.flatten()
151
  accuracy = accuracy_score(all_labels, all_preds)
152
  precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='micro')
153