from fastapi import APIRouter
from datetime import datetime

from datasets import load_dataset
from sklearn.metrics import accuracy_score
from sentence_transformers import SentenceTransformer
import numpy as np

from .utils.emissions import clean_emissions_data, get_space_info, tracker
from .utils.evaluation import TextEvaluationRequest

router = APIRouter()

DESCRIPTION = "Efficient embedding-based classification with similarity threshold"
ROUTE = "/text"

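# Load the sentence-embedding model once at module import so every request reuses it.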
model_name = "pedro-thenewsroom/climate-misinfo-embed"
embedding_model = SentenceTransformer(model_name)

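# Map the dataset's string labels to the integer class ids used for scoring.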
LABEL_MAPPING = {
    "0_not_relevant": 0,
    "1_not_happening": 1,
    "2_not_human": 2,
    "3_not_bad": 3,
    "4_solutions_harmful_unnecessary": 4,
    "5_science_unreliable": 5,
    "6_proponents_biased": 6,
    "7_fossil_fuels_needed": 7,
}

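# One representative claim per misinformation category. Class 0 ("not relevant") has no
# anchor sentence; it is assigned whenever no class description is similar enough.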
class_descriptions = {
    "1_not_happening": "Despite the alarmists' claims of global warming, temperatures have remained steady or even dropped in many areas, proving that climate change is nothing more than natural variability.",
    "2_not_human": "The Earth's climate has always changed due to natural cycles and external factors, and the role of human activity or CO2 emissions in driving these changes is negligible or unsupported by evidence.",
    "3_not_bad": "Contrary to the alarmist rhetoric, rising CO2 levels and modest warming are fostering a greener planet, boosting crop yields, and enhancing global prosperity while posing no significant threat to human or environmental health.",
    "4_solutions_harmful_unnecessary": "Global climate policies are costly, ineffective, and fail to address the unchecked emissions from developing nations, rendering efforts by industrialized countries futile and economically damaging.",
    "5_science_unreliable": "The so-called consensus on climate change relies on flawed models, manipulated data, and a refusal to address legitimate scientific uncertainties, all to serve a predetermined political narrative.",
    "6_proponents_biased": "Climate change is nothing more than a fabricated agenda pushed by corrupt elites, politicians, and scientists to control the masses, gain wealth, and suppress freedom.",
    "7_fossil_fuels_needed": "Fossil fuels have powered centuries of progress, lifted billions out of poverty, and remain the backbone of global energy, while alternatives, though promising, cannot yet match their scale, reliability, or affordability.",
}

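# Pre-compute normalized embeddings for the class descriptions; with unit-length vectors,
# a dot product is exactly the cosine similarity.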
class_labels = list(class_descriptions.keys())
class_sentences = list(class_descriptions.values())
class_embeddings = embedding_model.encode(
    class_sentences, batch_size=8, convert_to_numpy=True, normalize_embeddings=True
)

@router.post(ROUTE, tags=["Text Task"], description=DESCRIPTION)
async def evaluate_text(request: TextEvaluationRequest):
    """
    Evaluate text classification for climate disinformation detection using cosine similarity.
    """
    username, space_url = get_space_info()

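    # Load the evaluation dataset named in the request and convert its labels to integer ids.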
    dataset = load_dataset(request.dataset_name)
    dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})

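    # Reproduce the held-out test split with the size and seed supplied by the request.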
    train_test = dataset["train"].train_test_split(test_size=request.test_size, seed=request.test_seed)
    test_dataset = train_test["test"]

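    # Track energy use and emissions for the inference phase only.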
    tracker.start()
    tracker.start_task("inference")

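    # Embed each test quote; normalizing keeps the dot products below equal to cosine similarities.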
    def embed_quote(example):
        example["quote_embedding"] = embedding_model.encode(
            example["quote"], normalize_embeddings=True
        ).tolist()
        return example

    test_dataset = test_dataset.map(embed_quote, batched=True)
    test_embeddings = np.array(test_dataset["quote_embedding"])

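    # Cosine similarity of every quote against every class description (all vectors are unit length).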
    similarity_matrix = np.dot(test_embeddings, class_embeddings.T)
    best_indices = similarity_matrix.argmax(axis=1)
    best_similarities = similarity_matrix.max(axis=1)

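    # Predict the most similar class; fall back to "not relevant" when the best similarity is 0.8 or lower.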
    predictions = [
        LABEL_MAPPING[class_labels[idx]] if sim > 0.8 else LABEL_MAPPING["0_not_relevant"]
        for idx, sim in zip(best_indices, best_similarities)
    ]

    true_labels = test_dataset["label"]

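    # Stop emissions tracking for the inference task before scoring.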
    emissions_data = tracker.stop_task()

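    # Score the predictions against the held-out labels.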
    accuracy = accuracy_score(true_labels, predictions)

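    # Assemble the results payload returned by the endpoint.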
    results = {
        "username": username,
        "space_url": space_url,
        "submission_timestamp": datetime.now().isoformat(),
        "model_description": DESCRIPTION,
        "accuracy": float(accuracy),
        "energy_consumed_wh": emissions_data.energy_consumed * 1000,
        "emissions_gco2eq": emissions_data.emissions * 1000,
        "emissions_data": clean_emissions_data(emissions_data),
        "api_route": ROUTE,
        "dataset_config": {
            "dataset_name": request.dataset_name,
            "test_size": request.test_size,
            "test_seed": request.test_seed,
        },
    }

    return results