Update tasks/text.py
Browse files- tasks/text.py +23 -1
tasks/text.py
CHANGED
@@ -5,6 +5,9 @@ from sklearn.metrics import accuracy_score
|
|
5 |
from sentence_transformers import SentenceTransformer
|
6 |
import numpy as np
|
7 |
|
|
|
|
|
|
|
8 |
router = APIRouter()
|
9 |
|
10 |
DESCRIPTION = "Class embeddings with cosine similarity using batching and thresholding"
|
@@ -42,10 +45,13 @@ class_embeddings = embedding_model.encode(class_labels, batch_size=8, convert_to
|
|
42 |
|
43 |
|
44 |
@router.post(ROUTE, tags=["Text Task"], description=DESCRIPTION)
|
45 |
-
async def evaluate_text(request:
|
46 |
"""
|
47 |
Evaluate text classification using precomputed embeddings and cosine similarity.
|
48 |
"""
|
|
|
|
|
|
|
49 |
# Load dataset
|
50 |
dataset = load_dataset("QuotaClimat/frugalaichallenge-text-train")
|
51 |
df_train = dataset["train"].to_pandas()
|
@@ -56,6 +62,14 @@ async def evaluate_text(request: dict):
|
|
56 |
quotes = df["quote"].tolist()
|
57 |
true_labels = df["label"].apply(lambda x: int(x.split("_")[0]) if isinstance(x, str) else 0).tolist()
|
58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
# Encode dataset quotes in batches
|
60 |
batch_size = 32
|
61 |
quote_embeddings = embedding_model.encode(quotes, batch_size=batch_size, convert_to_numpy=True, normalize_embeddings=True)
|
@@ -70,14 +84,22 @@ async def evaluate_text(request: dict):
|
|
70 |
# Apply threshold (0.9) for classification
|
71 |
predicted_labels = [best_idx if best_sim > 0.9 else 0 for best_idx, best_sim in zip(best_indices, best_similarities)]
|
72 |
|
|
|
|
|
|
|
73 |
# Calculate accuracy
|
74 |
accuracy = accuracy_score(true_labels, predicted_labels)
|
75 |
|
76 |
# Prepare results dictionary
|
77 |
results = {
|
|
|
|
|
78 |
"submission_timestamp": datetime.now().isoformat(),
|
79 |
"model_description": DESCRIPTION,
|
80 |
"accuracy": float(accuracy),
|
|
|
|
|
|
|
81 |
"api_route": ROUTE,
|
82 |
"dataset_config": {
|
83 |
"dataset_name": "QuotaClimat/frugalaichallenge-text-train",
|
|
|
5 |
from sentence_transformers import SentenceTransformer
|
6 |
import numpy as np
|
7 |
|
8 |
+
from .utils.emissions import clean_emissions_data, get_space_info, tracker
|
9 |
+
from .utils.evaluation import TextEvaluationRequest
|
10 |
+
|
11 |
router = APIRouter()
|
12 |
|
13 |
DESCRIPTION = "Class embeddings with cosine similarity using batching and thresholding"
|
|
|
45 |
|
46 |
|
47 |
@router.post(ROUTE, tags=["Text Task"], description=DESCRIPTION)
|
48 |
+
async def evaluate_text(request: TextEvaluationRequest):
|
49 |
"""
|
50 |
Evaluate text classification using precomputed embeddings and cosine similarity.
|
51 |
"""
|
52 |
+
# Get space info
|
53 |
+
username, space_url = get_space_info()
|
54 |
+
|
55 |
# Load dataset
|
56 |
dataset = load_dataset("QuotaClimat/frugalaichallenge-text-train")
|
57 |
df_train = dataset["train"].to_pandas()
|
|
|
62 |
quotes = df["quote"].tolist()
|
63 |
true_labels = df["label"].apply(lambda x: int(x.split("_")[0]) if isinstance(x, str) else 0).tolist()
|
64 |
|
65 |
+
# Start tracking emissions
|
66 |
+
tracker.start()
|
67 |
+
tracker.start_task("inference")
|
68 |
+
|
69 |
+
# --------------------------------------------------------------------------------------------
|
70 |
+
# Optimized cosine similarity-based classification
|
71 |
+
# --------------------------------------------------------------------------------------------
|
72 |
+
|
73 |
# Encode dataset quotes in batches
|
74 |
batch_size = 32
|
75 |
quote_embeddings = embedding_model.encode(quotes, batch_size=batch_size, convert_to_numpy=True, normalize_embeddings=True)
|
|
|
84 |
# Apply threshold (0.9) for classification
|
85 |
predicted_labels = [best_idx if best_sim > 0.9 else 0 for best_idx, best_sim in zip(best_indices, best_similarities)]
|
86 |
|
87 |
+
# Stop tracking emissions
|
88 |
+
emissions_data = tracker.stop_task()
|
89 |
+
|
90 |
# Calculate accuracy
|
91 |
accuracy = accuracy_score(true_labels, predicted_labels)
|
92 |
|
93 |
# Prepare results dictionary
|
94 |
results = {
|
95 |
+
"username": username,
|
96 |
+
"space_url": space_url,
|
97 |
"submission_timestamp": datetime.now().isoformat(),
|
98 |
"model_description": DESCRIPTION,
|
99 |
"accuracy": float(accuracy),
|
100 |
+
"energy_consumed_wh": emissions_data.energy_consumed * 1000,
|
101 |
+
"emissions_gco2eq": emissions_data.emissions * 1000,
|
102 |
+
"emissions_data": clean_emissions_data(emissions_data),
|
103 |
"api_route": ROUTE,
|
104 |
"dataset_config": {
|
105 |
"dataset_name": "QuotaClimat/frugalaichallenge-text-train",
|