aur-beau committed
Commit 42a0f5f · 1 Parent(s): 7e15097

fixed errors

Files changed (1):
  1. tasks/text.py +4 -5
tasks/text.py CHANGED
@@ -67,8 +67,7 @@ async def evaluate_text(request: TextEvaluationRequest):
     sentence_model = SentenceTransformer(model_name)
 
     # Convert each sentence into a vector representation (embedding)
-    embeddings = model.encode(test_dataset['quote'].tolist())
-    embeddings = embeddings.cpu()
+    embeddings = sentence_model.encode(test_dataset['quote'])
 
     # Update the code below to replace the random baseline by your model inference within the inference pass where the energy consumption and emissions are tracked.
     #--------------------------------------------------------------------------------------------
@@ -91,11 +90,11 @@ async def evaluate_text(request: TextEvaluationRequest):
 
     # Make predictions
     with torch.no_grad():
-        outputs = model_nn(text_embeddings)
+        outputs = model_nn(embeddings)
         _, predicted = torch.max(outputs, 1)  # Get the class with the highest score
-
+
     # Decode the predictions back to original labels using label_encoder
-    predicted_labels = label_encoder.inverse_transform(predicted.cpu().numpy())
+    predictions = predicted.cpu().numpy()
 
     #--------------------------------------------------------------------------------------------
     # YOUR MODEL INFERENCE STOPS HERE
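For context, below is a minimal sketch of the inference block as it stands after this commit. Only the lines shown in the two hunks come from the commit itself; the model name, the tiny test_dataset, the 384-dimensional / 8-class classifier head, and the explicit numpy-to-tensor conversion are placeholders assumed here, since SentenceTransformer.encode returns a numpy array by default while model_nn is called as a torch module.

import torch
from sentence_transformers import SentenceTransformer

# Placeholder stand-ins for objects defined elsewhere in tasks/text.py;
# the real model_name, test_dataset and model_nn are not part of this commit.
model_name = "sentence-transformers/all-MiniLM-L6-v2"   # hypothetical choice
test_dataset = {"quote": ["Example quote one.", "Example quote two."]}
model_nn = torch.nn.Linear(384, 8)  # untrained stand-in classifier head

sentence_model = SentenceTransformer(model_name)

# Convert each sentence into a vector representation (embedding).
# encode() returns a numpy array by default, so it is converted to a tensor
# before the forward pass (this conversion is assumed, not shown in the diff).
embeddings = torch.tensor(sentence_model.encode(test_dataset["quote"]))

# Make predictions
with torch.no_grad():
    outputs = model_nn(embeddings)
    _, predicted = torch.max(outputs, 1)  # class with the highest score

# The commit drops the label_encoder.inverse_transform call, so the
# predictions are the raw class indices produced by the classifier.
predictions = predicted.cpu().numpy()

Note that with inverse_transform removed, predictions holds integer class indices rather than the original label strings, so downstream code presumably compares them against the encoded labels.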