pedro-thenewsroom committed on
Commit
90194b0
·
verified ·
1 Parent(s): 40ac593

Update tasks/text.py

Browse files
Files changed (1) hide show
  1. tasks/text.py +7 -13
tasks/text.py CHANGED
@@ -2,8 +2,7 @@ from fastapi import APIRouter
2
  from datetime import datetime
3
  from datasets import load_dataset
4
  from sklearn.metrics import accuracy_score
5
- from sentence_transformers import SentenceTransformer, util
6
- import faiss
7
  import numpy as np
8
 
9
  from .utils.emissions import clean_emissions_data, get_space_info, tracker
@@ -11,7 +10,7 @@ from .utils.evaluation import TextEvaluationRequest
11
 
12
  router = APIRouter()
13
 
14
- DESCRIPTION = "Embedding-based classification with similarity threshold"
15
  ROUTE = "/text"
16
 
17
  # Load custom embedding model
@@ -41,16 +40,11 @@ class_descriptions = {
41
  "7_fossil_fuels_needed": "Fossil fuels have powered centuries of progress, lifted billions out of poverty, and remain the backbone of global energy, while alternatives, though promising, cannot yet match their scale, reliability, or affordability.",
42
  }
43
 
44
- # Precompute class embeddings
45
  class_labels = list(class_descriptions.keys())
46
  class_sentences = list(class_descriptions.values())
47
  class_embeddings = embedding_model.encode(class_sentences, batch_size=8, convert_to_numpy=True, normalize_embeddings=True)
48
 
49
- # Build FAISS index for efficient similarity search
50
- dimension = class_embeddings.shape[1]
51
- faiss_index = faiss.IndexFlatIP(dimension) # Inner product = cosine similarity for normalized vectors
52
- faiss_index.add(class_embeddings)
53
-
54
 
55
  @router.post(ROUTE, tags=["Text Task"], description=DESCRIPTION)
56
  async def evaluate_text(request: TextEvaluationRequest):
@@ -81,10 +75,10 @@ async def evaluate_text(request: TextEvaluationRequest):
81
  # Batch embed all test dataset quotes
82
  test_embeddings = embedding_model.encode(test_dataset["quote"], batch_size=32, convert_to_numpy=True, normalize_embeddings=True)
83
 
84
- # Use FAISS to find the nearest class for each embedding
85
- similarities, indices = faiss_index.search(test_embeddings, 1) # Top-1 match for each input
86
- best_similarities = similarities.flatten()
87
- best_indices = indices.flatten()
88
 
89
  # Apply threshold (0.9) for classification
90
  predictions = [
 
2
  from datetime import datetime
3
  from datasets import load_dataset
4
  from sklearn.metrics import accuracy_score
5
+ from sentence_transformers import SentenceTransformer
 
6
  import numpy as np
7
 
8
  from .utils.emissions import clean_emissions_data, get_space_info, tracker
 
10
 
11
  router = APIRouter()
12
 
13
+ DESCRIPTION = "Efficient embedding-based classification with similarity threshold"
14
  ROUTE = "/text"
15
 
16
  # Load custom embedding model
 
40
  "7_fossil_fuels_needed": "Fossil fuels have powered centuries of progress, lifted billions out of poverty, and remain the backbone of global energy, while alternatives, though promising, cannot yet match their scale, reliability, or affordability.",
41
  }
42
 
43
+ # Precompute class embeddings (normalized for cosine similarity)
44
  class_labels = list(class_descriptions.keys())
45
  class_sentences = list(class_descriptions.values())
46
  class_embeddings = embedding_model.encode(class_sentences, batch_size=8, convert_to_numpy=True, normalize_embeddings=True)
47
 
 
 
 
 
 
48
 
49
  @router.post(ROUTE, tags=["Text Task"], description=DESCRIPTION)
50
  async def evaluate_text(request: TextEvaluationRequest):
 
75
  # Batch embed all test dataset quotes
76
  test_embeddings = embedding_model.encode(test_dataset["quote"], batch_size=32, convert_to_numpy=True, normalize_embeddings=True)
77
 
78
+ # Compute cosine similarity in a single operation
79
+ similarity_matrix = np.dot(test_embeddings, class_embeddings.T) # Efficient matrix multiplication
80
+ best_indices = similarity_matrix.argmax(axis=1) # Get index of highest similarity for each test sample
81
+ best_similarities = similarity_matrix.max(axis=1) # Get max similarity values
82
 
83
  # Apply threshold (0.9) for classification
84
  predictions = [