Update tasks/text.py
tasks/text.py  +11 -5
@@ -45,7 +45,6 @@ class_labels = list(class_descriptions.keys())
 class_sentences = list(class_descriptions.values())
 class_embeddings = embedding_model.encode(class_sentences, batch_size=8, convert_to_numpy=True, normalize_embeddings=True)
 
-
 @router.post(ROUTE, tags=["Text Task"], description=DESCRIPTION)
 async def evaluate_text(request: TextEvaluationRequest):
     """
@@ -71,16 +70,23 @@ async def evaluate_text(request: TextEvaluationRequest):
     # --------------------------------------------------------------------------------------------
     # Optimized cosine similarity-based classification with threshold
     # --------------------------------------------------------------------------------------------
-
-    # Batch embed all test dataset quotes
-    test_embeddings = embedding_model.encode(test_dataset["quote"], batch_size=128)
 
+    # Convert "quote" key into embeddings
+    def embed_quote(example):
+        example["quote_embedding"] = embedding_model.encode(example["quote"]).tolist()
+        return example
+
+    test_dataset = test_dataset.map(embed_quote, batched=True)
+
+    # Convert test embeddings to numpy array
+    test_embeddings = np.array(test_dataset["quote_embedding"])
+
     # Compute cosine similarity in a single operation
     similarity_matrix = np.dot(test_embeddings, class_embeddings.T)  # Efficient matrix multiplication
     best_indices = similarity_matrix.argmax(axis=1)  # Get index of highest similarity for each test sample
     best_similarities = similarity_matrix.max(axis=1)  # Get max similarity values
 
-    # Apply threshold (0.
+    # Apply threshold (0.8) for classification
     predictions = [
         LABEL_MAPPING[class_labels[idx]] if sim > 0.8 else LABEL_MAPPING["0_not_relevant"]
         for idx, sim in zip(best_indices, best_similarities)
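The new embed_quote helper relies on the batched form of datasets' map: with batched=True the function receives a dict of column lists rather than a single row, so embedding_model.encode() gets a list of quotes, returns a 2-D array, and tolist() stores one embedding list per row that np.array later reassembles. A toy illustration of that round trip (the model name and quotes below are placeholders, not the project's actual data):

# Toy illustration of the batched map used above; contents are made up.
import numpy as np
from datasets import Dataset
from sentence_transformers import SentenceTransformer

embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")  # placeholder model
test_dataset = Dataset.from_dict({"quote": ["quote one", "quote two", "quote three"]})

def embed_quote(batch):
    # With batched=True, batch["quote"] is a list of strings, so encode()
    # returns a 2-D array; tolist() stores one embedding list per row.
    batch["quote_embedding"] = embedding_model.encode(batch["quote"]).tolist()
    return batch

test_dataset = test_dataset.map(embed_quote, batched=True)
test_embeddings = np.array(test_dataset["quote_embedding"])  # shape: (3, embedding_dim)
print(test_embeddings.shape)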
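For context, here is a minimal standalone sketch of the classification path this file implements: embed the class descriptions, embed the quotes, take the dot product as cosine similarity, and fall back to the "0_not_relevant" label when the best match is below 0.8. The model name, class descriptions, and LABEL_MAPPING values below are stand-ins, not the ones defined elsewhere in tasks/text.py:

# Standalone sketch of cosine-similarity classification with a 0.8 threshold.
# Model, class descriptions, and label mapping are illustrative placeholders.
import numpy as np
from sentence_transformers import SentenceTransformer

embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

class_descriptions = {
    "0_not_relevant": "The text is unrelated to climate change.",
    "1_not_happening": "The text claims climate change is not happening.",
}
LABEL_MAPPING = {"0_not_relevant": 0, "1_not_happening": 1}

class_labels = list(class_descriptions.keys())
class_sentences = list(class_descriptions.values())

# Normalized embeddings make the dot product equal to cosine similarity.
class_embeddings = embedding_model.encode(
    class_sentences, convert_to_numpy=True, normalize_embeddings=True
)

quotes = ["Global warming is a hoax invented by scientists.", "I had pasta for dinner."]
test_embeddings = embedding_model.encode(
    quotes, convert_to_numpy=True, normalize_embeddings=True
)

similarity_matrix = test_embeddings @ class_embeddings.T   # (n_quotes, n_classes)
best_indices = similarity_matrix.argmax(axis=1)
best_similarities = similarity_matrix.max(axis=1)

# Quotes whose best match is below the 0.8 threshold fall back to "0_not_relevant".
predictions = [
    LABEL_MAPPING[class_labels[idx]] if sim > 0.8 else LABEL_MAPPING["0_not_relevant"]
    for idx, sim in zip(best_indices, best_similarities)
]
print(predictions)

One detail worth noting: in the diff, the class embeddings are encoded with normalize_embeddings=True, while embed_quote does not pass that flag, so unless the model already emits unit-norm vectors the dot product is only proportional to cosine similarity; normalizing the quote embeddings as well (as in the sketch above) would keep the two sides consistent with the 0.8 threshold.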