pedro-thenewsroom commited on
Commit
0a7a34d
·
verified ·
1 Parent(s): 9bb3e73

Update tasks/text.py

Browse files
Files changed (1) hide show
  1. tasks/text.py +2 -2
tasks/text.py CHANGED
@@ -64,7 +64,7 @@ async def evaluate_text(request: TextEvaluationRequest):
64
  train_test = dataset["train"].train_test_split(test_size=request.test_size, seed=request.test_seed)
65
  test_dataset = train_test["test"]
66
  # Batch embed all test dataset quotes
67
- test_embeddings = embedding_model.encode(test_dataset["quote"], batch_size=128, convert_to_numpy=True, normalize_embeddings=True)
68
 
69
  # Start tracking emissions
70
  tracker.start()
@@ -81,7 +81,7 @@ async def evaluate_text(request: TextEvaluationRequest):
81
 
82
  # Apply threshold (0.9) for classification
83
  predictions = [
84
- LABEL_MAPPING[class_labels[idx]] if sim > 0.9 else LABEL_MAPPING["0_not_relevant"]
85
  for idx, sim in zip(best_indices, best_similarities)
86
  ]
87
 
 
64
  train_test = dataset["train"].train_test_split(test_size=request.test_size, seed=request.test_seed)
65
  test_dataset = train_test["test"]
66
  # Batch embed all test dataset quotes
67
+ test_embeddings = embedding_model.encode(test_dataset["quote"], batch_size=128)
68
 
69
  # Start tracking emissions
70
  tracker.start()
 
81
 
82
  # Apply threshold (0.9) for classification
83
  predictions = [
84
+ LABEL_MAPPING[class_labels[idx]] if sim > 0.8 else LABEL_MAPPING["0_not_relevant"]
85
  for idx, sim in zip(best_indices, best_similarities)
86
  ]
87