Update tasks/text.py
Browse files- tasks/text.py +2 -2
tasks/text.py
CHANGED
@@ -64,7 +64,7 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
64 |
train_test = dataset["train"].train_test_split(test_size=request.test_size, seed=request.test_seed)
|
65 |
test_dataset = train_test["test"]
|
66 |
# Batch embed all test dataset quotes
|
67 |
-
test_embeddings = embedding_model.encode(test_dataset["quote"], batch_size=128
|
68 |
|
69 |
# Start tracking emissions
|
70 |
tracker.start()
|
@@ -81,7 +81,7 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
81 |
|
82 |
# Apply threshold (0.9) for classification
|
83 |
predictions = [
|
84 |
-
LABEL_MAPPING[class_labels[idx]] if sim > 0.
|
85 |
for idx, sim in zip(best_indices, best_similarities)
|
86 |
]
|
87 |
|
|
|
64 |
train_test = dataset["train"].train_test_split(test_size=request.test_size, seed=request.test_seed)
|
65 |
test_dataset = train_test["test"]
|
66 |
# Batch embed all test dataset quotes
|
67 |
+
test_embeddings = embedding_model.encode(test_dataset["quote"], batch_size=128)
|
68 |
|
69 |
# Start tracking emissions
|
70 |
tracker.start()
|
|
|
81 |
|
82 |
# Apply threshold (0.9) for classification
|
83 |
predictions = [
|
84 |
+
LABEL_MAPPING[class_labels[idx]] if sim > 0.8 else LABEL_MAPPING["0_not_relevant"]
|
85 |
for idx, sim in zip(best_indices, best_similarities)
|
86 |
]
|
87 |
|