Varun Wadhwa committed
Commit 35114e6 · unverified · 1 parent: e4a2227
Files changed (1): app.py (+5 -4)
app.py CHANGED
@@ -121,6 +121,7 @@ def evaluate_model(model, dataloader, device):
     all_preds = []
     all_labels = []
 
+    test = True
     # Disable gradient calculations
     with torch.no_grad():
         for batch in dataloader:
@@ -134,6 +135,10 @@ def evaluate_model(model, dataloader, device):
 
             # Get predictions
             preds = torch.argmax(logits, dim=-1).cpu().numpy()
+            if test:
+                test = False
+                print(preds)
+                print(labels)
 
             all_preds.extend(preds)
             all_labels.extend(labels.cpu().numpy())
@@ -142,10 +147,6 @@ def evaluate_model(model, dataloader, device):
     print("evaluate_model sizes")
     print(len(all_preds[0]))
     print(len(all_labels[0]))
-    for p in all_preds:
-        if len(p) != len(all_preds[0]):
-            print(len(p))
-            print(p)
     all_preds = np.asarray(all_preds, dtype=np.float32)
     all_labels = np.asarray(all_labels, dtype=np.float32)
     print("Flattened sizes")