pedro-thenewsroom commited on
Commit
941eb28
·
verified ·
1 Parent(s): aa18df0

Update tasks/text.py

Browse files
Files changed (1) hide show
  1. tasks/text.py +11 -5
tasks/text.py CHANGED
@@ -45,7 +45,6 @@ class_labels = list(class_descriptions.keys())
45
  class_sentences = list(class_descriptions.values())
46
  class_embeddings = embedding_model.encode(class_sentences, batch_size=8, convert_to_numpy=True, normalize_embeddings=True)
47
 
48
-
49
  @router.post(ROUTE, tags=["Text Task"], description=DESCRIPTION)
50
  async def evaluate_text(request: TextEvaluationRequest):
51
  """
@@ -71,16 +70,23 @@ async def evaluate_text(request: TextEvaluationRequest):
71
  # --------------------------------------------------------------------------------------------
72
  # Optimized cosine similarity-based classification with threshold
73
  # --------------------------------------------------------------------------------------------
74
-
75
- # Batch embed all test dataset quotes
76
- test_embeddings = embedding_model.encode(test_dataset["quote"], batch_size=128)
77
 
 
 
 
 
 
 
 
 
 
 
78
  # Compute cosine similarity in a single operation
79
  similarity_matrix = np.dot(test_embeddings, class_embeddings.T) # Efficient matrix multiplication
80
  best_indices = similarity_matrix.argmax(axis=1) # Get index of highest similarity for each test sample
81
  best_similarities = similarity_matrix.max(axis=1) # Get max similarity values
82
 
83
- # Apply threshold (0.9) for classification
84
  predictions = [
85
  LABEL_MAPPING[class_labels[idx]] if sim > 0.8 else LABEL_MAPPING["0_not_relevant"]
86
  for idx, sim in zip(best_indices, best_similarities)
 
45
  class_sentences = list(class_descriptions.values())
46
  class_embeddings = embedding_model.encode(class_sentences, batch_size=8, convert_to_numpy=True, normalize_embeddings=True)
47
 
 
48
  @router.post(ROUTE, tags=["Text Task"], description=DESCRIPTION)
49
  async def evaluate_text(request: TextEvaluationRequest):
50
  """
 
70
  # --------------------------------------------------------------------------------------------
71
  # Optimized cosine similarity-based classification with threshold
72
  # --------------------------------------------------------------------------------------------
 
 
 
73
 
74
+ # Convert "quote" key into embeddings
75
+ def embed_quote(example):
76
+ example["quote_embedding"] = embedding_model.encode(example["quote"]).tolist()
77
+ return example
78
+
79
+ test_dataset = test_dataset.map(embed_quote, batched=True)
80
+
81
+ # Convert test embeddings to numpy array
82
+ test_embeddings = np.array(test_dataset["quote_embedding"])
83
+
84
  # Compute cosine similarity in a single operation
85
  similarity_matrix = np.dot(test_embeddings, class_embeddings.T) # Efficient matrix multiplication
86
  best_indices = similarity_matrix.argmax(axis=1) # Get index of highest similarity for each test sample
87
  best_similarities = similarity_matrix.max(axis=1) # Get max similarity values
88
 
89
+ # Apply threshold (0.8) for classification
90
  predictions = [
91
  LABEL_MAPPING[class_labels[idx]] if sim > 0.8 else LABEL_MAPPING["0_not_relevant"]
92
  for idx, sim in zip(best_indices, best_similarities)