Spaces:
Sleeping
Sleeping
Varun Wadhwa
committed on
Fixing accuracy
Browse files
app.py
CHANGED
@@ -133,21 +133,23 @@ def evaluate_model(model, dataloader, device):
|
|
133 |
logits = outputs.logits
|
134 |
|
135 |
# Get predictions
|
136 |
-
preds = torch.argmax(logits, dim=-1)
|
137 |
-
|
138 |
-
|
|
|
|
|
|
|
|
|
139 |
|
140 |
# Calculate evaluation metrics
|
141 |
print("evaluate_model sizes")
|
142 |
print(len(all_preds[0]))
|
143 |
print(len(all_labels[0]))
|
144 |
-
all_preds = np.array(all_preds)
|
145 |
-
all_labels = np.array(all_labels)
|
146 |
print("Flattened sizes")
|
147 |
print(all_preds.size)
|
148 |
print(all_labels.size)
|
149 |
-
all_preds = all_preds.flatten()
|
150 |
-
all_labels = all_labels.flatten()
|
151 |
accuracy = accuracy_score(all_labels, all_preds)
|
152 |
precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='micro')
|
153 |
|
|
|
133 |
logits = outputs.logits
|
134 |
|
135 |
# Get predictions
|
136 |
+
preds = torch.argmax(logits, dim=-1)
|
137 |
+
|
138 |
+
for p, l in zip(preds.cpu().numpy(), labels.cpu().numpy()):
|
139 |
+
# Filter out `-100` labels and align predictions with valid tokens
|
140 |
+
valid_indices = l != -100
|
141 |
+
all_preds.extend(p[valid_indices])
|
142 |
+
all_labels.extend(l[valid_indices])
|
143 |
|
144 |
# Calculate evaluation metrics
|
145 |
print("evaluate_model sizes")
|
146 |
print(len(all_preds[0]))
|
147 |
print(len(all_labels[0]))
|
148 |
+
all_preds = np.array(all_preds)
|
149 |
+
all_labels = np.array(all_labels)
|
150 |
print("Flattened sizes")
|
151 |
print(all_preds.size)
|
152 |
print(all_labels.size)
|
|
|
|
|
153 |
accuracy = accuracy_score(all_labels, all_preds)
|
154 |
precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='micro')
|
155 |
|