Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -179,9 +179,16 @@ discriminator.eval()
|
|
179 |
|
180 |
def get_prediction(embeddings):
|
181 |
with torch.no_grad():
|
|
|
182 |
last_rep, logits, probs = discriminator(embeddings)
|
|
|
|
|
|
|
|
|
|
|
|
|
183 |
|
184 |
-
|
185 |
predicted_labels = predicted_labels.cpu().numpy()
|
186 |
return predicted_labels
|
187 |
|
|
|
179 |
|
180 |
def get_prediction(embeddings):
|
181 |
with torch.no_grad():
|
182 |
+
# Forward pass through the discriminator to get the logits and probabilities
|
183 |
last_rep, logits, probs = discriminator(embeddings)
|
184 |
+
|
185 |
+
# Filter logits to ignore the last dimension (assuming you only care about the first two)
|
186 |
+
filtered_logits = logits[:, 0:-1]
|
187 |
+
|
188 |
+
# Get the predicted labels using the filtered logits
|
189 |
+
_, predicted_labels = torch.max(filtered_logits, dim=-1)
|
190 |
|
191 |
+
# Convert to numpy array if needed
|
192 |
predicted_labels = predicted_labels.cpu().numpy()
|
193 |
return predicted_labels
|
194 |
|