Varun Wadhwa committed
Commit 663c988 · unverified · 1 Parent(s): 6582c50
Files changed (1)
  1. app.py +9 -11
app.py CHANGED
@@ -121,14 +121,20 @@ def evaluate_model(model, dataloader, device):
     all_preds = []
     all_labels = []
 
-    test = True
-    test2 = True
     # Disable gradient calculations
     with torch.no_grad():
         for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
+            x = len(labels[0])
+            print(labels[0])
+            for l in labels:
+                if len(l) != x:
+                    print(len(l))
+                    print(l)
+                    break
+
 
            # Forward pass to get logits
            outputs = model(input_ids, attention_mask=attention_mask)
@@ -136,17 +142,9 @@ def evaluate_model(model, dataloader, device):
 
            # Get predictions
            preds = torch.argmax(logits, dim=-1).cpu().numpy()
-            if test or test2:
-                if test:
-                    test2 = False
-                test = False
-                print(len(preds))
-                print(len(labels))
-                print(preds)
-                print(labels)
 
            all_preds.extend(preds)
-            all_labels.extend(labels.cpu().numpy())
+            all_labels.extend(labels)
 
    # Calculate evaluation metrics
    print("evaluate_model sizes")