sumesh4C commited on
Commit
dfa54b4
·
verified ·
1 Parent(s): 7e15097

Update tasks/text.py

Browse files
Files changed (1) hide show
  1. tasks/text.py +5 -1
tasks/text.py CHANGED
@@ -61,6 +61,10 @@ async def evaluate_text(request: TextEvaluationRequest):
61
  #--------------------------------------------------------------------------------------------
62
  # YOUR MODEL INFERENCE CODE HERE
63
 
 
 
 
 
64
  #Load the embedding model
65
  #model = SentenceTransformer("dunzhang/stella_en_400M_v5",trust_remote_code=True)
66
  model_name = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2" # You can use other Sentence Transformers models as needed
@@ -91,7 +95,7 @@ async def evaluate_text(request: TextEvaluationRequest):
91
 
92
  # Make predictions
93
  with torch.no_grad():
94
- outputs = model_nn(text_embeddings)
95
  _, predicted = torch.max(outputs, 1) # Get the class with the highest score
96
 
97
  # Decode the predictions back to original labels using label_encoder
 
61
  #--------------------------------------------------------------------------------------------
62
  # YOUR MODEL INFERENCE CODE HERE
63
 
64
+ # Set the device to MPS (if available)
65
+ device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
66
+ print(f"Using device: {device}")
67
+
68
  #Load the embedding model
69
  #model = SentenceTransformer("dunzhang/stella_en_400M_v5",trust_remote_code=True)
70
  model_name = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2" # You can use other Sentence Transformers models as needed
 
95
 
96
  # Make predictions
97
  with torch.no_grad():
98
+ outputs = model_nn(text_embeddings.to(device))
99
  _, predicted = torch.max(outputs, 1) # Get the class with the highest score
100
 
101
  # Decode the predictions back to original labels using label_encoder