Spaces:
Sleeping
Sleeping
Update tasks/text.py
Browse files- tasks/text.py +5 -1
tasks/text.py
CHANGED
@@ -61,6 +61,10 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
61 |
#--------------------------------------------------------------------------------------------
|
62 |
# YOUR MODEL INFERENCE CODE HERE
|
63 |
|
|
|
|
|
|
|
|
|
64 |
#Load the embedding model
|
65 |
#model = SentenceTransformer("dunzhang/stella_en_400M_v5",trust_remote_code=True)
|
66 |
model_name = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2" # You can use other Sentence Transformers models as needed
|
@@ -91,7 +95,7 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
91 |
|
92 |
# Make predictions
|
93 |
with torch.no_grad():
|
94 |
-
outputs = model_nn(text_embeddings)
|
95 |
_, predicted = torch.max(outputs, 1) # Get the class with the highest score
|
96 |
|
97 |
# Decode the predictions back to original labels using label_encoder
|
|
|
61 |
#--------------------------------------------------------------------------------------------
|
62 |
# YOUR MODEL INFERENCE CODE HERE
|
63 |
|
64 |
+
# Set the device to MPS (if available)
|
65 |
+
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
|
66 |
+
print(f"Using device: {device}")
|
67 |
+
|
68 |
#Load the embedding model
|
69 |
#model = SentenceTransformer("dunzhang/stella_en_400M_v5",trust_remote_code=True)
|
70 |
model_name = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2" # You can use other Sentence Transformers models as needed
|
|
|
95 |
|
96 |
# Make predictions
|
97 |
with torch.no_grad():
|
98 |
+
outputs = model_nn(text_embeddings.to(device))
|
99 |
_, predicted = torch.max(outputs, 1) # Get the class with the highest score
|
100 |
|
101 |
# Decode the predictions back to original labels using label_encoder
|