from datetime import datetime

import torch
from datasets import load_dataset
from fastapi import APIRouter
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    ModernBertConfig,
)

from .utils.emissions import clean_emissions_data, get_space_info, tracker
from .utils.evaluation import TextEvaluationRequest

router = APIRouter()

DESCRIPTION = "Climate Guard Toxic Agent is a ModernBERT for Climate Disinformation Detection"
ROUTE = "/text"


@router.post(ROUTE, tags=["Text Task"], description=DESCRIPTION)
async def evaluate_text(request: TextEvaluationRequest):
    """
    Evaluate text classification for climate disinformation detection using ModernBERT.
    """
    # Get space info
    username, space_url = get_space_info()

    # Define the label mapping
    LABEL_MAPPING = {
        "0_not_relevant": 0,
        "1_not_happening": 1,
        "2_not_human": 2,
        "3_not_bad": 3,
        "4_solutions_harmful_unnecessary": 4,
        "5_science_unreliable": 5,
        "6_proponents_biased": 6,
        "7_fossil_fuels_needed": 7,
    }

    # Load and prepare the dataset
    dataset = load_dataset(request.dataset_name)

    # Convert string labels to integers
    dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})

    # Get test dataset
    test_dataset = dataset["test"]

    # Start tracking emissions
    tracker.start()
    tracker.start_task("inference")

    # --------------------------------------------------------------------------------------------
    # MODEL INFERENCE CODE
    # --------------------------------------------------------------------------------------------
    try:
        # Set device
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Model and tokenizer paths
        model_name = "Tonic/climate-guard-toxic-agent"
        tokenizer_name = "Tonic/climate-guard-toxic-agent"

        # Create config
        config = ModernBertConfig(
            vocab_size=50368,
            hidden_size=768,
            num_hidden_layers=22,
            num_attention_heads=12,
            intermediate_size=1152,
            max_position_embeddings=8192,
            layer_norm_eps=1e-5,
            position_embedding_type="absolute",
            pad_token_id=50283,
            bos_token_id=50281,
            eos_token_id=50282,
            sep_token_id=50282,
            cls_token_id=50281,
            hidden_activation="gelu",
            classifier_activation="gelu",
            classifier_pooling="mean",
            num_labels=8,
            id2label={str(i): label for i, label in enumerate(LABEL_MAPPING.keys())},
            label2id=LABEL_MAPPING,
            problem_type="single_label_classification",
            architectures=["ModernBertForSequenceClassification"],
            model_type="modernbert",
        )

        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

        # Load model
        model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            config=config,
            trust_remote_code=True,
            ignore_mismatched_sizes=True,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        ).to(device)

        # Set model to evaluation mode
        model.eval()

        # Preprocess function
        def preprocess_function(examples):
            return tokenizer(
                examples["quote"],
                padding=False,
                truncation=True,
                max_length=512,
                return_tensors=None,
            )

        # Tokenize dataset
        tokenized_test = test_dataset.map(
            preprocess_function,
            batched=True,
            remove_columns=test_dataset.column_names,
        )

        # Set format for PyTorch
        tokenized_test.set_format("torch")

        # Create DataLoader
        data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
        test_loader = DataLoader(
            tokenized_test,
            batch_size=16,
            collate_fn=data_collator,
            shuffle=False,
        )

        # Get predictions
        predictions = []
        with torch.no_grad():
            for batch in test_loader:
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(**batch)
                preds = torch.argmax(outputs.logits, dim=-1)
                predictions.extend(preds.cpu().numpy().tolist())

        # Clean up GPU memory
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    except Exception as e:
        print(f"Error during model inference: {str(e)}")
        raise

    # --------------------------------------------------------------------------------------------
    # MODEL INFERENCE ENDS HERE
    # --------------------------------------------------------------------------------------------

    # Stop tracking emissions
    emissions_data = tracker.stop_task()

    # Calculate accuracy
    accuracy = accuracy_score(test_dataset["label"], predictions)

    # Prepare results dictionary
    results = {
        "username": username,
        "space_url": space_url,
        "submission_timestamp": datetime.now().isoformat(),
        "model_description": DESCRIPTION,
        "accuracy": float(accuracy),
        "energy_consumed_wh": emissions_data.energy_consumed * 1000,
        "emissions_gco2eq": emissions_data.emissions * 1000,
        "emissions_data": clean_emissions_data(emissions_data),
        "api_route": ROUTE,
        "dataset_config": {
            "dataset_name": request.dataset_name,
            "test_size": request.test_size,
            "test_seed": request.test_seed,
        },
    }

    return results
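
# Usage sketch (assumption, not part of the original module): this router is
# typically mounted from the application entrypoint, e.g. an app.py along the
# lines of the following. The import path `tasks.text` is hypothetical.
#
#     from fastapi import FastAPI
#     from tasks.text import router as text_router
#
#     app = FastAPI()
#     app.include_router(text_router)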