Varun Wadhwa commited on
Commit
f3c2885
·
unverified ·
1 Parent(s): 0444fba
Files changed (1) hide show
  1. app.py +5 -2
app.py CHANGED
@@ -134,14 +134,17 @@ def evaluate_model(model, dataloader, device):
134
 
135
  # Get predictions
136
  preds = torch.argmax(logits, dim=-1).cpu().numpy()
 
 
 
137
 
138
  all_preds.extend(preds)
139
  all_labels.extend(labels)
140
 
141
  # Calculate evaluation metrics
142
  print("evaluate_model sizes")
143
- print("Shape of preds:", all_preds.shape)
144
- print("Shape of labels:", all_labels.shape)
145
  all_preds = np.asarray(all_preds, dtype=np.float32)
146
  all_labels = np.asarray(all_labels, dtype=np.float32)
147
  print("Flattened sizes")
 
134
 
135
  # Get predictions
136
  preds = torch.argmax(logits, dim=-1).cpu().numpy()
137
+
138
+ print("Shape of preds:", preds.shape)
139
+ print("Shape of labels:", labels.shape)
140
 
141
  all_preds.extend(preds)
142
  all_labels.extend(labels)
143
 
144
  # Calculate evaluation metrics
145
  print("evaluate_model sizes")
146
+ print(len(all_preds[0]))
147
+ print(len(all_labels[0]))
148
  all_preds = np.asarray(all_preds, dtype=np.float32)
149
  all_labels = np.asarray(all_labels, dtype=np.float32)
150
  print("Flattened sizes")