from fastapi import APIRouter
from datetime import datetime
from datasets import load_dataset
from sklearn.metrics import accuracy_score
from sentence_transformers import SentenceTransformer
import numpy as np
from .utils.emissions import clean_emissions_data, get_space_info, tracker
from .utils.evaluation import TextEvaluationRequest
router = APIRouter()
DESCRIPTION = "Efficient embedding-based classification with similarity threshold"
ROUTE = "/text"
# Load custom embedding model
model_name = "pedro-thenewsroom/climate-misinfo-embed"
# model_name = "pedro-thenewsroom/climate-misinfo-embed-8bit"
embedding_model = SentenceTransformer(model_name)
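# The encoder is instantiated once at import time, so its weights are loaded
# a single time and reused across all requests handled by this route.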
# Define label mapping
LABEL_MAPPING = {
    "0_not_relevant": 0,
    "1_not_happening": 1,
    "2_not_human": 2,
    "3_not_bad": 3,
    "4_solutions_harmful_unnecessary": 4,
    "5_science_unreliable": 5,
    "6_proponents_biased": 6,
    "7_fossil_fuels_needed": 7,
}
# Class descriptions for embedding comparison
class_descriptions = {
    "1_not_happening": "Despite the alarmists' claims of global warming, temperatures have remained steady or even dropped in many areas, proving that climate change is nothing more than natural variability.",
    "2_not_human": "The Earth's climate has always changed due to natural cycles and external factors, and the role of human activity or CO2 emissions in driving these changes is negligible or unsupported by evidence.",
    "3_not_bad": "Contrary to the alarmist rhetoric, rising CO2 levels and modest warming are fostering a greener planet, boosting crop yields, and enhancing global prosperity while posing no significant threat to human or environmental health.",
    "4_solutions_harmful_unnecessary": "Global climate policies are costly, ineffective, and fail to address the unchecked emissions from developing nations, rendering efforts by industrialized countries futile and economically damaging.",
    "5_science_unreliable": "The so-called consensus on climate change relies on flawed models, manipulated data, and a refusal to address legitimate scientific uncertainties, all to serve a predetermined political narrative.",
    "6_proponents_biased": "Climate change is nothing more than a fabricated agenda pushed by corrupt elites, politicians, and scientists to control the masses, gain wealth, and suppress freedom.",
    "7_fossil_fuels_needed": "Fossil fuels have powered centuries of progress, lifted billions out of poverty, and remain the backbone of global energy, while alternatives, though promising, cannot yet match their scale, reliability, or affordability.",
}
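# Note: "0_not_relevant" deliberately has no description here; it is assigned
# as the fallback whenever no class clears the similarity threshold below.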
# Precompute class embeddings (normalized for cosine similarity)
class_labels = list(class_descriptions.keys())
class_sentences = list(class_descriptions.values())
class_embeddings = embedding_model.encode(
    class_sentences,
    batch_size=8,
    convert_to_numpy=True,
    normalize_embeddings=True,
)
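# Because both these class embeddings and the quote embeddings computed at
# inference time are L2-normalized, cosine similarity reduces to a plain dot
# product, which the single matrix multiplication in evaluate_text exploits.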
@router.post(ROUTE, tags=["Text Task"], description=DESCRIPTION)
async def evaluate_text(request: TextEvaluationRequest):
"""
Evaluate text classification for climate disinformation detection using cosine similarity.
"""
# Get space info
username, space_url = get_space_info()
# Load and prepare the dataset
dataset = load_dataset(request.dataset_name)
# Convert dataset labels to integers based on LABEL_MAPPING
dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})
# Split dataset into train and test
train_test = dataset["train"].train_test_split(test_size=request.test_size, seed=request.test_seed)
test_dataset = train_test["test"]
# Start tracking emissions
tracker.start()
tracker.start_task("inference")
    # --------------------------------------------------------------------------------------------
    # Optimized cosine similarity-based classification with threshold
    # --------------------------------------------------------------------------------------------

    # Embed the "quote" column; normalize so the dot product below is a true
    # cosine similarity (the class embeddings are normalized the same way).
    # With batched=True, batch["quote"] is a list of strings.
    def embed_quote(batch):
        batch["quote_embedding"] = embedding_model.encode(
            batch["quote"], convert_to_numpy=True, normalize_embeddings=True
        ).tolist()
        return batch

    test_dataset = test_dataset.map(embed_quote, batched=True)

    # Convert test embeddings to a numpy array
    test_embeddings = np.array(test_dataset["quote_embedding"])

    # Compute all cosine similarities in a single matrix multiplication
    similarity_matrix = np.dot(test_embeddings, class_embeddings.T)
    best_indices = similarity_matrix.argmax(axis=1)  # index of the most similar class per sample
    best_similarities = similarity_matrix.max(axis=1)  # corresponding similarity values

    # Predict the best-matching class when its similarity clears the 0.8
    # threshold; otherwise fall back to "0_not_relevant"
    predictions = [
        LABEL_MAPPING[class_labels[idx]] if sim > 0.8 else LABEL_MAPPING["0_not_relevant"]
        for idx, sim in zip(best_indices, best_similarities)
    ]
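    # An equivalent fully vectorized formulation (a sketch with the same
    # semantics as the list comprehension above; `class_ids` is a helper name
    # introduced here, not part of the original code):
    #   class_ids = np.array([LABEL_MAPPING[lbl] for lbl in class_labels])
    #   predictions = np.where(
    #       best_similarities > 0.8,
    #       class_ids[best_indices],
    #       LABEL_MAPPING["0_not_relevant"],
    #   ).tolist()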
    # Get ground truth labels
    true_labels = test_dataset["label"]

    # --------------------------------------------------------------------------------------------

    # Stop tracking emissions
    emissions_data = tracker.stop_task()

    # Calculate accuracy
    accuracy = accuracy_score(true_labels, predictions)

    # Prepare results dictionary
    results = {
        "username": username,
        "space_url": space_url,
        "submission_timestamp": datetime.now().isoformat(),
        "model_description": DESCRIPTION,
        "accuracy": float(accuracy),
        "energy_consumed_wh": emissions_data.energy_consumed * 1000,
        "emissions_gco2eq": emissions_data.emissions * 1000,
        "emissions_data": clean_emissions_data(emissions_data),
        "api_route": ROUTE,
        "dataset_config": {
            "dataset_name": request.dataset_name,
            "test_size": request.test_size,
            "test_seed": request.test_seed,
        },
    }

    return results
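

# --------------------------------------------------------------------------------------------
# Minimal local smoke test (a sketch, not part of the deployed route). The
# dataset id and the test_size/test_seed values below are assumptions; replace
# them with the configuration your deployment actually uses. Run with
# `python -m tasks.text` so the relative imports above resolve.
if __name__ == "__main__":
    import asyncio

    sample_request = TextEvaluationRequest(
        dataset_name="QuotaClimat/frugalaichallenge-text-train",  # assumed dataset id
        test_size=0.2,  # assumed split fraction
        test_seed=42,  # assumed seed
    )
    print(asyncio.run(evaluate_text(sample_request)))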