Update tasks/text.py
tasks/text.py  CHANGED  (+7 -13)
@@ -2,8 +2,7 @@ from fastapi import APIRouter
 from datetime import datetime
 from datasets import load_dataset
 from sklearn.metrics import accuracy_score
-from sentence_transformers import SentenceTransformer
-import faiss
+from sentence_transformers import SentenceTransformer
 import numpy as np
 
 from .utils.emissions import clean_emissions_data, get_space_info, tracker
@@ -11,7 +10,7 @@ from .utils.evaluation import TextEvaluationRequest
 
 router = APIRouter()
 
-DESCRIPTION = "
+DESCRIPTION = "Efficient embedding-based classification with similarity threshold"
 ROUTE = "/text"
 
 # Load custom embedding model
@@ -41,16 +40,11 @@ class_descriptions = {
     "7_fossil_fuels_needed": "Fossil fuels have powered centuries of progress, lifted billions out of poverty, and remain the backbone of global energy, while alternatives, though promising, cannot yet match their scale, reliability, or affordability.",
 }
 
-# Precompute class embeddings
+# Precompute class embeddings (normalized for cosine similarity)
 class_labels = list(class_descriptions.keys())
 class_sentences = list(class_descriptions.values())
 class_embeddings = embedding_model.encode(class_sentences, batch_size=8, convert_to_numpy=True, normalize_embeddings=True)
 
-# Build FAISS index for efficient similarity search
-dimension = class_embeddings.shape[1]
-faiss_index = faiss.IndexFlatIP(dimension)  # Inner product = cosine similarity for normalized vectors
-faiss_index.add(class_embeddings)
-
 
 @router.post(ROUTE, tags=["Text Task"], description=DESCRIPTION)
 async def evaluate_text(request: TextEvaluationRequest):
@@ -81,10 +75,10 @@ async def evaluate_text(request: TextEvaluationRequest):
     # Batch embed all test dataset quotes
     test_embeddings = embedding_model.encode(test_dataset["quote"], batch_size=32, convert_to_numpy=True, normalize_embeddings=True)
 
-    #
-
-
-
+    # Compute cosine similarity in a single operation
+    similarity_matrix = np.dot(test_embeddings, class_embeddings.T)  # Efficient matrix multiplication
+    best_indices = similarity_matrix.argmax(axis=1)  # Get index of highest similarity for each test sample
+    best_similarities = similarity_matrix.max(axis=1)  # Get max similarity values
 
     # Apply threshold (0.9) for classification
     predictions = [
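For reference, here is a minimal self-contained sketch of the classification logic this commit moves to: because the embeddings are encoded with normalize_embeddings=True, cosine similarity reduces to a plain matrix product, which is why the FAISS IndexFlatIP and its build step could be dropped. The toy vectors, the label subset, and the FALLBACK_LABEL name below are illustrative assumptions; only the np.dot / argmax / max pattern and the 0.9 threshold come from the diff.

import numpy as np

# Stand-ins for the precomputed, L2-normalized embeddings
# (shapes mirror embedding_model.encode(..., normalize_embeddings=True)).
class_labels = ["0_not_relevant", "7_fossil_fuels_needed"]  # illustrative subset of class_descriptions keys
class_embeddings = np.array([[1.0, 0.0],
                             [0.0, 1.0]])                   # (num_classes, dim)
test_embeddings = np.array([[0.28, 0.96],
                            [0.80, 0.60]])                  # (num_quotes, dim), rows are unit-length

# Same operation the commit adds: one matrix multiplication replaces the FAISS search.
similarity_matrix = np.dot(test_embeddings, class_embeddings.T)  # (num_quotes, num_classes) cosine similarities
best_indices = similarity_matrix.argmax(axis=1)                  # best-matching class per quote
best_similarities = similarity_matrix.max(axis=1)                # its similarity score

# Apply the 0.9 threshold: keep the best class only when the match is confident,
# otherwise fall back to a default label (assumed here, not shown in the diff).
FALLBACK_LABEL = "0_not_relevant"
predictions = [
    class_labels[idx] if sim >= 0.9 else FALLBACK_LABEL
    for idx, sim in zip(best_indices, best_similarities)
]
print(predictions)  # ['7_fossil_fuels_needed', '0_not_relevant']

With only a handful of class vectors (the keys shown run up to 7_fossil_fuels_needed), the dense matrix product is at least as fast as an IndexFlatIP lookup and removes the faiss dependency and the index-building step entirely.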