File size: 5,893 Bytes
4d6e8c2 90194b0 73febd1 4d6e8c2 a1a5fb1 4d6e8c2 73febd1 1c33274 70f5f26 48ed843 f193c7f 5ba80f7 48ed843 73febd1 48ed843 73febd1 5ba80f7 a1a5fb1 4d6e8c2 48ed843 4d6e8c2 a1a5fb1 48ed843 5ba80f7 48ed843 5ba80f7 a1a5fb1 73febd1 a1a5fb1 aa18df0 941eb28 73febd1 941eb28 73febd1 0a7a34d 73febd1 48ed843 4d6e8c2 48ed843 a1a5fb1 4d6e8c2 48ed843 5ba80f7 4d6e8c2 a1a5fb1 4d6e8c2 70f5f26 4d6e8c2 a1a5fb1 1c33274 4d6e8c2 48ed843 5ba80f7 4d6e8c2 5ba80f7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
from fastapi import APIRouter
from datetime import datetime
from datasets import load_dataset
from sklearn.metrics import accuracy_score
from sentence_transformers import SentenceTransformer
import numpy as np
from .utils.emissions import clean_emissions_data, get_space_info, tracker
from .utils.evaluation import TextEvaluationRequest
# Router this module's endpoint is registered on (mounted by the app elsewhere).
router = APIRouter()
# Human-readable description reported back in the results payload.
DESCRIPTION = "Efficient embedding-based classification with similarity threshold"
# Endpoint path for the text task.
ROUTE = "/text"
# Load custom embedding model
# NOTE: loaded once at import time; presumably fine-tuned for climate-misinfo
# quotes (inferred from the repo name — confirm with the model card).
model_name = "pedro-thenewsroom/climate-misinfo-embed"
# Alternative 8-bit quantized variant, kept for reference:
# model_name = "pedro-thenewsroom/climate-misinfo-embed-8bit"
embedding_model = SentenceTransformer(model_name)
# Map dataset label strings to integer class ids. The numeric prefix of each
# label matches its id, so the mapping is just the enumeration order.
LABEL_MAPPING = {
    label_name: class_id
    for class_id, label_name in enumerate(
        [
            "0_not_relevant",
            "1_not_happening",
            "2_not_human",
            "3_not_bad",
            "4_solutions_harmful_unnecessary",
            "5_science_unreliable",
            "6_proponents_biased",
            "7_fossil_fuels_needed",
        ]
    )
}
# Class descriptions for embedding comparison: one prototypical claim per
# misinformation class. Each test quote is embedded and matched against these
# by cosine similarity.
# NOTE: "0_not_relevant" deliberately has no description — it is the fallback
# assigned when no class clears the similarity threshold (see evaluate_text).
class_descriptions = {
    "1_not_happening": "Despite the alarmists' claims of global warming, temperatures have remained steady or even dropped in many areas, proving that climate change is nothing more than natural variability.",
    "2_not_human": "The Earth's climate has always changed due to natural cycles and external factors, and the role of human activity or CO2 emissions in driving these changes is negligible or unsupported by evidence.",
    "3_not_bad": "Contrary to the alarmist rhetoric, rising CO2 levels and modest warming are fostering a greener planet, boosting crop yields, and enhancing global prosperity while posing no significant threat to human or environmental health.",
    "4_solutions_harmful_unnecessary": "Global climate policies are costly, ineffective, and fail to address the unchecked emissions from developing nations, rendering efforts by industrialized countries futile and economically damaging.",
    "5_science_unreliable": "The so-called consensus on climate change relies on flawed models, manipulated data, and a refusal to address legitimate scientific uncertainties, all to serve a predetermined political narrative.",
    "6_proponents_biased": "Climate change is nothing more than a fabricated agenda pushed by corrupt elites, politicians, and scientists to control the masses, gain wealth, and suppress freedom.",
    "7_fossil_fuels_needed": "Fossil fuels have powered centuries of progress, lifted billions out of poverty, and remain the backbone of global energy, while alternatives, though promising, cannot yet match their scale, reliability, or affordability.",
}
# Precompute the class prototype embeddings once at import time. They are
# L2-normalized so that a plain dot product against other unit vectors
# yields cosine similarity.
class_labels, class_sentences = (
    list(seq) for seq in zip(*class_descriptions.items())
)
class_embeddings = embedding_model.encode(
    class_sentences,
    batch_size=8,
    convert_to_numpy=True,
    normalize_embeddings=True,
)
@router.post(ROUTE, tags=["Text Task"], description=DESCRIPTION)
async def evaluate_text(request: TextEvaluationRequest):
    """
    Evaluate text classification for climate disinformation detection using
    cosine similarity.

    Each test quote is embedded, compared against the precomputed class
    embeddings, and assigned the best-matching class when the similarity
    clears SIMILARITY_THRESHOLD; otherwise it falls back to "0_not_relevant".
    Emissions/energy are tracked for the inference phase only.

    Args:
        request: carries ``dataset_name``, ``test_size`` and ``test_seed``.

    Returns:
        dict: accuracy, emissions/energy figures, space info and the dataset
        configuration used for the split.
    """
    # Minimum cosine similarity required to assign a misinformation class;
    # anything below is treated as "0_not_relevant".
    SIMILARITY_THRESHOLD = 0.8

    # Get space info
    username, space_url = get_space_info()

    # Load the dataset and convert string labels to integer ids.
    dataset = load_dataset(request.dataset_name)
    dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})

    # Reproducible train/test split; only the test split is evaluated.
    train_test = dataset["train"].train_test_split(
        test_size=request.test_size, seed=request.test_seed
    )
    test_dataset = train_test["test"]

    # Start tracking emissions for the inference phase.
    tracker.start()
    tracker.start_task("inference")

    # --------------------------------------------------------------------------------------------
    # Optimized cosine similarity-based classification with threshold
    # --------------------------------------------------------------------------------------------
    def embed_quote(batch):
        """Embed a batch of quotes, L2-normalized to unit length."""
        # BUGFIX: normalize_embeddings=True is required here. The class
        # embeddings are normalized (see their encode call), so without
        # normalizing the quote embeddings too, the dot product below is not
        # a cosine similarity and the 0.8 threshold is applied to unscaled
        # values.
        batch["quote_embedding"] = embedding_model.encode(
            batch["quote"], convert_to_numpy=True, normalize_embeddings=True
        ).tolist()
        return batch

    test_dataset = test_dataset.map(embed_quote, batched=True)

    # Stack the per-example embeddings into a (num_examples, dim) array.
    test_embeddings = np.array(test_dataset["quote_embedding"])

    # Both operands are unit vectors, so one matmul gives all cosine
    # similarities at once: (num_examples, num_classes).
    similarity_matrix = np.dot(test_embeddings, class_embeddings.T)
    best_indices = similarity_matrix.argmax(axis=1)   # best class per example
    best_similarities = similarity_matrix.max(axis=1)  # its similarity value

    # Apply the threshold: confident matches keep the best class, everything
    # else falls back to "0_not_relevant".
    predictions = [
        LABEL_MAPPING[class_labels[idx]]
        if sim > SIMILARITY_THRESHOLD
        else LABEL_MAPPING["0_not_relevant"]
        for idx, sim in zip(best_indices, best_similarities)
    ]

    # Ground truth labels (already mapped to ints above).
    true_labels = test_dataset["label"]

    # --------------------------------------------------------------------------------------------
    # Stop tracking emissions for the inference task.
    emissions_data = tracker.stop_task()

    accuracy = accuracy_score(true_labels, predictions)

    # Prepare results dictionary.
    results = {
        "username": username,
        "space_url": space_url,
        "submission_timestamp": datetime.now().isoformat(),
        "model_description": DESCRIPTION,
        "accuracy": float(accuracy),
        # x1000: presumably kWh -> Wh and kgCO2eq -> gCO2eq (codecarbon
        # units) — confirm against the tracker's documentation.
        "energy_consumed_wh": emissions_data.energy_consumed * 1000,
        "emissions_gco2eq": emissions_data.emissions * 1000,
        "emissions_data": clean_emissions_data(emissions_data),
        "api_route": ROUTE,
        "dataset_config": {
            "dataset_name": request.dataset_name,
            "test_size": request.test_size,
            "test_seed": request.test_seed,
        },
    }
    return results
|