pedro-thenewsroom committed
Commit a1a5fb1 · verified · 1 Parent(s): 923778a

Update tasks/text.py

Files changed (1): tasks/text.py (+23, -1)
tasks/text.py CHANGED
@@ -5,6 +5,9 @@ from sklearn.metrics import accuracy_score
 from sentence_transformers import SentenceTransformer
 import numpy as np
 
+from .utils.emissions import clean_emissions_data, get_space_info, tracker
+from .utils.evaluation import TextEvaluationRequest
+
 router = APIRouter()
 
 DESCRIPTION = "Class embeddings with cosine similarity using batching and thresholding"
@@ -42,10 +45,13 @@ class_embeddings = embedding_model.encode(class_labels, batch_size=8, convert_to
 
 
 @router.post(ROUTE, tags=["Text Task"], description=DESCRIPTION)
-async def evaluate_text(request: dict):
+async def evaluate_text(request: TextEvaluationRequest):
     """
     Evaluate text classification using precomputed embeddings and cosine similarity.
     """
+    # Get space info
+    username, space_url = get_space_info()
+
     # Load dataset
     dataset = load_dataset("QuotaClimat/frugalaichallenge-text-train")
     df_train = dataset["train"].to_pandas()
@@ -56,6 +62,14 @@ async def evaluate_text(request: dict):
     quotes = df["quote"].tolist()
     true_labels = df["label"].apply(lambda x: int(x.split("_")[0]) if isinstance(x, str) else 0).tolist()
 
+    # Start tracking emissions
+    tracker.start()
+    tracker.start_task("inference")
+
+    # --------------------------------------------------------------------------------------------
+    # Optimized cosine similarity-based classification
+    # --------------------------------------------------------------------------------------------
+
     # Encode dataset quotes in batches
     batch_size = 32
     quote_embeddings = embedding_model.encode(quotes, batch_size=batch_size, convert_to_numpy=True, normalize_embeddings=True)
@@ -70,14 +84,22 @@ async def evaluate_text(request: dict):
     # Apply threshold (0.9) for classification
     predicted_labels = [best_idx if best_sim > 0.9 else 0 for best_idx, best_sim in zip(best_indices, best_similarities)]
 
+    # Stop tracking emissions
+    emissions_data = tracker.stop_task()
+
     # Calculate accuracy
     accuracy = accuracy_score(true_labels, predicted_labels)
 
     # Prepare results dictionary
     results = {
+        "username": username,
+        "space_url": space_url,
         "submission_timestamp": datetime.now().isoformat(),
         "model_description": DESCRIPTION,
         "accuracy": float(accuracy),
+        "energy_consumed_wh": emissions_data.energy_consumed * 1000,
+        "emissions_gco2eq": emissions_data.emissions * 1000,
+        "emissions_data": clean_emissions_data(emissions_data),
         "api_route": ROUTE,
         "dataset_config": {
             "dataset_name": "QuotaClimat/frugalaichallenge-text-train",