Update tasks/text.py
Browse files- tasks/text.py +67 -71
tasks/text.py
CHANGED
@@ -2,91 +2,87 @@ from fastapi import APIRouter
|
|
2 |
from datetime import datetime
|
3 |
from datasets import load_dataset
|
4 |
from sklearn.metrics import accuracy_score
|
5 |
-
import
|
6 |
-
|
7 |
-
from .utils.evaluation import TextEvaluationRequest
|
8 |
-
from .utils.emissions import tracker, clean_emissions_data, get_space_info
|
9 |
|
10 |
router = APIRouter()
|
11 |
|
12 |
-
DESCRIPTION = "
|
13 |
ROUTE = "/text"
|
14 |
|
15 |
-
|
16 |
-
|
17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
"""
|
19 |
-
Evaluate text classification
|
20 |
-
|
21 |
-
Current Model: Random Baseline
|
22 |
-
- Makes random predictions from the label space (0-7)
|
23 |
-
- Used as a baseline for comparison
|
24 |
"""
|
25 |
-
#
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
|
40 |
-
# Load and prepare the dataset
|
41 |
-
dataset = load_dataset(request.dataset_name)
|
42 |
-
|
43 |
-
# Convert string labels to integers
|
44 |
-
dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})
|
45 |
-
|
46 |
-
# Split dataset
|
47 |
-
train_test = dataset["train"]
|
48 |
-
test_dataset = dataset["test"]
|
49 |
-
|
50 |
-
# Start tracking emissions
|
51 |
-
tracker.start()
|
52 |
-
tracker.start_task("inference")
|
53 |
-
|
54 |
-
#--------------------------------------------------------------------------------------------
|
55 |
-
# YOUR MODEL INFERENCE CODE HERE
|
56 |
-
# Update the code below to replace the random baseline by your model inference within the inference pass where the energy consumption and emissions are tracked.
|
57 |
-
#--------------------------------------------------------------------------------------------
|
58 |
-
|
59 |
-
# Make random predictions (placeholder for actual model inference)
|
60 |
-
true_labels = test_dataset["label"]
|
61 |
-
predictions = [random.randint(0, 7) for _ in range(len(true_labels))]
|
62 |
-
|
63 |
-
#--------------------------------------------------------------------------------------------
|
64 |
-
# YOUR MODEL INFERENCE STOPS HERE
|
65 |
-
#--------------------------------------------------------------------------------------------
|
66 |
-
|
67 |
-
|
68 |
-
# Stop tracking emissions
|
69 |
-
emissions_data = tracker.stop_task()
|
70 |
-
|
71 |
# Calculate accuracy
|
72 |
-
accuracy = accuracy_score(true_labels,
|
73 |
-
|
74 |
# Prepare results dictionary
|
75 |
results = {
|
76 |
-
"username": username,
|
77 |
-
"space_url": space_url,
|
78 |
"submission_timestamp": datetime.now().isoformat(),
|
79 |
"model_description": DESCRIPTION,
|
80 |
"accuracy": float(accuracy),
|
81 |
-
"energy_consumed_wh": emissions_data.energy_consumed * 1000,
|
82 |
-
"emissions_gco2eq": emissions_data.emissions * 1000,
|
83 |
-
"emissions_data": clean_emissions_data(emissions_data),
|
84 |
"api_route": ROUTE,
|
85 |
"dataset_config": {
|
86 |
-
"dataset_name":
|
87 |
-
"
|
88 |
-
|
89 |
-
}
|
90 |
}
|
91 |
-
|
92 |
-
return results
|
|
|
2 |
from datetime import datetime
|
3 |
from datasets import load_dataset
|
4 |
from sklearn.metrics import accuracy_score
|
5 |
+
from sentence_transformers import SentenceTransformer
|
6 |
+
import numpy as np
|
|
|
|
|
7 |
|
8 |
router = APIRouter()
|
9 |
|
10 |
+
DESCRIPTION = "Class embeddings with cosine similarity using batching and thresholding"
|
11 |
ROUTE = "/text"
|
12 |
|
13 |
+
# Load fine-tuned model
|
14 |
+
model_name = "pedro-thenewsroom/climate-misinfo-embed"
|
15 |
+
embedding_model = SentenceTransformer(model_name)
|
16 |
+
|
17 |
+
# Define label mapping
|
18 |
+
LABEL_MAPPING = {
|
19 |
+
"0_not_relevant": 0,
|
20 |
+
"1_not_happening": 1,
|
21 |
+
"2_not_human": 2,
|
22 |
+
"3_not_bad": 3,
|
23 |
+
"4_solutions_harmful_unnecessary": 4,
|
24 |
+
"5_science_unreliable": 5,
|
25 |
+
"6_proponents_biased": 6,
|
26 |
+
"7_fossil_fuels_needed": 7,
|
27 |
+
}
|
28 |
+
|
29 |
+
# Class descriptions for embedding comparison
|
30 |
+
class_labels = [
|
31 |
+
"Despite the alarmists' claims of global warming, temperatures have remained steady or even dropped in many areas, proving that climate change is nothing more than natural variability.",
|
32 |
+
"The Earth's climate has always changed due to natural cycles and external factors, and the role of human activity or CO2 emissions in driving these changes is negligible or unsupported by evidence.",
|
33 |
+
"Contrary to the alarmist rhetoric, rising CO2 levels and modest warming are fostering a greener planet, boosting crop yields, and enhancing global prosperity while posing no significant threat to human or environmental health.",
|
34 |
+
"Global climate policies are costly, ineffective, and fail to address the unchecked emissions from developing nations, rendering efforts by industrialized countries futile and economically damaging.",
|
35 |
+
"The so-called consensus on climate change relies on flawed models, manipulated data, and a refusal to address legitimate scientific uncertainties, all to serve a predetermined political narrative.",
|
36 |
+
"Climate change is nothing more than a fabricated agenda pushed by corrupt elites, politicians, and scientists to control the masses, gain wealth, and suppress freedom.",
|
37 |
+
"Fossil fuels have powered centuries of progress, lifted billions out of poverty, and remain the backbone of global energy, while alternatives, though promising, cannot yet match their scale, reliability, or affordability.",
|
38 |
+
]
|
39 |
+
|
40 |
+
# Precompute normalized class embeddings
|
41 |
+
class_embeddings = embedding_model.encode(class_labels, batch_size=8, convert_to_numpy=True, normalize_embeddings=True)
|
42 |
+
|
43 |
+
|
44 |
+
@router.post(ROUTE, tags=["Text Task"], description=DESCRIPTION)
|
45 |
+
async def evaluate_text(request: dict):
|
46 |
"""
|
47 |
+
Evaluate text classification using precomputed embeddings and cosine similarity.
|
|
|
|
|
|
|
|
|
48 |
"""
|
49 |
+
# Load dataset
|
50 |
+
dataset = load_dataset("QuotaClimat/frugalaichallenge-text-train")
|
51 |
+
df_train = dataset["train"].to_pandas()
|
52 |
+
df_test = dataset["test"].to_pandas()
|
53 |
+
df = pd.concat([df_train, df_test], ignore_index=True)
|
54 |
+
|
55 |
+
# Extract quotes and their true labels
|
56 |
+
quotes = df["quote"].tolist()
|
57 |
+
true_labels = df["label"].apply(lambda x: int(x.split("_")[0]) if isinstance(x, str) else 0).tolist()
|
58 |
+
|
59 |
+
# Encode dataset quotes in batches
|
60 |
+
batch_size = 32
|
61 |
+
quote_embeddings = embedding_model.encode(quotes, batch_size=batch_size, convert_to_numpy=True, normalize_embeddings=True)
|
62 |
+
|
63 |
+
# Compute cosine similarity using matrix multiplication (efficient)
|
64 |
+
cosine_similarities = np.dot(quote_embeddings, class_embeddings.T)
|
65 |
+
|
66 |
+
# Get the best match for each sentence
|
67 |
+
best_indices = np.argmax(cosine_similarities, axis=1)
|
68 |
+
best_similarities = np.max(cosine_similarities, axis=1)
|
69 |
+
|
70 |
+
# Apply threshold (0.9) for classification
|
71 |
+
predicted_labels = [best_idx if best_sim > 0.9 else 0 for best_idx, best_sim in zip(best_indices, best_similarities)]
|
72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
# Calculate accuracy
|
74 |
+
accuracy = accuracy_score(true_labels, predicted_labels)
|
75 |
+
|
76 |
# Prepare results dictionary
|
77 |
results = {
|
|
|
|
|
78 |
"submission_timestamp": datetime.now().isoformat(),
|
79 |
"model_description": DESCRIPTION,
|
80 |
"accuracy": float(accuracy),
|
|
|
|
|
|
|
81 |
"api_route": ROUTE,
|
82 |
"dataset_config": {
|
83 |
+
"dataset_name": "QuotaClimat/frugalaichallenge-text-train",
|
84 |
+
"total_samples": len(df),
|
85 |
+
},
|
|
|
86 |
}
|
87 |
+
|
88 |
+
return results
|