Varun Wadhwa commited on
Commit
6582c50
·
unverified ·
1 Parent(s): 35114e6
Files changed (1) hide show
  1. app.py +6 -1
app.py CHANGED
@@ -122,6 +122,7 @@ def evaluate_model(model, dataloader, device):
122
  all_labels = []
123
 
124
  test = True
 
125
  # Disable gradient calculations
126
  with torch.no_grad():
127
  for batch in dataloader:
@@ -135,8 +136,12 @@ def evaluate_model(model, dataloader, device):
135
 
136
  # Get predictions
137
  preds = torch.argmax(logits, dim=-1).cpu().numpy()
138
- if test:
 
 
139
  test = False
 
 
140
  print(preds)
141
  print(labels)
142
 
 
122
  all_labels = []
123
 
124
  test = True
125
+ test2 = True
126
  # Disable gradient calculations
127
  with torch.no_grad():
128
  for batch in dataloader:
 
136
 
137
  # Get predictions
138
  preds = torch.argmax(logits, dim=-1).cpu().numpy()
139
+ if test or test2:
140
+ if test:
141
+ test2 = False
142
  test = False
143
+ print(len(preds))
144
+ print(len(labels))
145
  print(preds)
146
  print(labels)
147