JasonTPhillipsJr commited on
Commit
9edc447
·
verified ·
1 Parent(s): b28bcdc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -1
app.py CHANGED
@@ -179,9 +179,16 @@ discriminator.eval()
179
 
180
  def get_prediction(embeddings):
181
  with torch.no_grad():
 
182
  last_rep, logits, probs = discriminator(embeddings)
 
 
 
 
 
 
183
 
184
- predicted_labels = torch.argmax(probs,dim=-1)
185
  predicted_labels = predicted_labels.cpu().numpy()
186
  return predicted_labels
187
 
 
179
 
180
  def get_prediction(embeddings):
181
  with torch.no_grad():
182
+ # Forward pass through the discriminator to get the logits and probabilities
183
  last_rep, logits, probs = discriminator(embeddings)
184
+
185
+ # Filter logits to ignore the last dimension (assuming you only care about the first two)
186
+ filtered_logits = logits[:, 0:-1]
187
+
188
+ # Get the predicted labels using the filtered logits
189
+ _, predicted_labels = torch.max(filtered_logits, dim=-1)
190
 
191
+ # Convert to numpy array if needed
192
  predicted_labels = predicted_labels.cpu().numpy()
193
  return predicted_labels
194