from datetime import datetime

import numpy as np
from datasets import load_dataset
from fastapi import APIRouter
from sentence_transformers import SentenceTransformer
from sklearn.metrics import accuracy_score

from .utils.emissions import clean_emissions_data, get_space_info, tracker
from .utils.evaluation import TextEvaluationRequest

router = APIRouter()

DESCRIPTION = "Efficient embedding-based classification with similarity threshold"
ROUTE = "/text"

# Minimum cosine similarity a quote must exceed (strictly) against its closest
# class description to be assigned that class; otherwise it is labelled
# "0_not_relevant".
SIMILARITY_THRESHOLD = 0.8

# Load custom embedding model (downloaded/loaded once, at import time).
model_name = "pedro-thenewsroom/climate-misinfo-embed"
# Alternative 8-bit variant of the same model:
# model_name = "pedro-thenewsroom/climate-misinfo-embed-8bit"
embedding_model = SentenceTransformer(model_name)

# Map the dataset's string labels to integer class ids.
LABEL_MAPPING = {
    "0_not_relevant": 0,
    "1_not_happening": 1,
    "2_not_human": 2,
    "3_not_bad": 3,
    "4_solutions_harmful_unnecessary": 4,
    "5_science_unreliable": 5,
    "6_proponents_biased": 6,
    "7_fossil_fuels_needed": 7,
}

# One prototype sentence per misinformation class. "0_not_relevant" has no
# prototype: it is the fallback when no class is similar enough.
class_descriptions = {
    "1_not_happening": "Despite the alarmists' claims of global warming, temperatures have remained steady or even dropped in many areas, proving that climate change is nothing more than natural variability.",
    "2_not_human": "The Earth's climate has always changed due to natural cycles and external factors, and the role of human activity or CO2 emissions in driving these changes is negligible or unsupported by evidence.",
    "3_not_bad": "Contrary to the alarmist rhetoric, rising CO2 levels and modest warming are fostering a greener planet, boosting crop yields, and enhancing global prosperity while posing no significant threat to human or environmental health.",
    "4_solutions_harmful_unnecessary": "Global climate policies are costly, ineffective, and fail to address the unchecked emissions from developing nations, rendering efforts by industrialized countries futile and economically damaging.",
    "5_science_unreliable": "The so-called consensus on climate change relies on flawed models, manipulated data, and a refusal to address legitimate scientific uncertainties, all to serve a predetermined political narrative.",
    "6_proponents_biased": "Climate change is nothing more than a fabricated agenda pushed by corrupt elites, politicians, and scientists to control the masses, gain wealth, and suppress freedom.",
    "7_fossil_fuels_needed": "Fossil fuels have powered centuries of progress, lifted billions out of poverty, and remain the backbone of global energy, while alternatives, though promising, cannot yet match their scale, reliability, or affordability.",
}

# Precompute class embeddings (L2-normalized, so a dot product is a cosine similarity).
class_labels = list(class_descriptions.keys())
class_sentences = list(class_descriptions.values())
class_embeddings = embedding_model.encode(
    class_sentences, batch_size=8, convert_to_numpy=True, normalize_embeddings=True
)

# Integer label id for each row of class_embeddings (same order as class_labels).
class_label_ids = np.array([LABEL_MAPPING[label] for label in class_labels])


def _load_test_split(request):
    """Load the requested dataset, encode its labels as ints, and return the test split."""
    dataset = load_dataset(request.dataset_name)
    # Convert dataset labels to integers based on LABEL_MAPPING.
    dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})
    train_test = dataset["train"].train_test_split(
        test_size=request.test_size, seed=request.test_seed
    )
    return train_test["test"]


def _predict_labels(quotes, threshold=SIMILARITY_THRESHOLD):
    """Classify quotes by cosine similarity to the class description embeddings.

    Args:
        quotes: Iterable of quote strings.
        threshold: A quote is assigned its most similar class only if that
            similarity is strictly greater than this value; otherwise it gets
            LABEL_MAPPING["0_not_relevant"].

    Returns:
        1-D integer NumPy array of predicted label ids, one per quote, in input order.
    """
    quotes = list(quotes)
    if not quotes:
        return np.empty(0, dtype=class_label_ids.dtype)

    # Normalize the quote embeddings as well, so that the dot product below is
    # a true cosine similarity and the threshold is on a consistent scale.
    quote_embeddings = embedding_model.encode(
        quotes, convert_to_numpy=True, normalize_embeddings=True
    )

    similarity_matrix = quote_embeddings @ class_embeddings.T  # (n_quotes, n_classes)
    best_indices = similarity_matrix.argmax(axis=1)
    best_similarities = similarity_matrix.max(axis=1)

    return np.where(
        best_similarities > threshold,
        class_label_ids[best_indices],
        LABEL_MAPPING["0_not_relevant"],
    )


@router.post(ROUTE, tags=["Text Task"], description=DESCRIPTION)
async def evaluate_text(request: TextEvaluationRequest):
    """
    Evaluate text classification for climate disinformation detection using cosine similarity.

    Loads ``request.dataset_name`` and splits off a test set using
    ``request.test_size`` / ``request.test_seed``. It then classifies every
    quote in the test set while the emissions tracker records the "inference"
    task, and returns accuracy and emissions figures.

    NOTE(review): inference is synchronous CPU/GPU work inside an ``async``
    handler, so it blocks the event loop for its duration — confirm this is
    acceptable (it also keeps use of the shared ``tracker`` serialized).
    """
    username, space_url = get_space_info()

    test_dataset = _load_test_split(request)
    true_labels = list(test_dataset["label"])

    # Start tracking emissions
    tracker.start()
    tracker.start_task("inference")
    try:
        predictions = _predict_labels(test_dataset["quote"])
    finally:
        # Close the "inference" task even if inference raises, so the shared
        # tracker is presumably not left with an open task for the next request.
        emissions_data = tracker.stop_task()

    accuracy = accuracy_score(true_labels, predictions)

    return {
        "username": username,
        "space_url": space_url,
        "submission_timestamp": datetime.now().isoformat(),
        "model_description": DESCRIPTION,
        "accuracy": float(accuracy),
        # Scale by 1000 to match the Wh / gCO2eq units named in the keys.
        "energy_consumed_wh": emissions_data.energy_consumed * 1000,
        "emissions_gco2eq": emissions_data.emissions * 1000,
        "emissions_data": clean_emissions_data(emissions_data),
        "api_route": ROUTE,
        "dataset_config": {
            "dataset_name": request.dataset_name,
            "test_size": request.test_size,
            "test_seed": request.test_seed,
        },
    }