"""FastAPI route that scores a sequence-classification model on climate quotes.

The module loads a Hugging Face model once per request, classifies the
``test`` split of the challenge dataset in thread-pooled batches, and reports
accuracy together with the emissions measured by the project tracker.
"""
from datetime import datetime
import os
import time
from concurrent.futures import ThreadPoolExecutor
from typing import List, Dict, Tuple

from datasets import load_dataset
from dotenv import load_dotenv
from fastapi import APIRouter
from huggingface_hub import login
from sklearn.metrics import accuracy_score
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline

from .utils.evaluation import TextEvaluationRequest
from .utils.emissions import tracker, clean_emissions_data, get_space_info

# Load environment variables
load_dotenv()

# Authenticate with Hugging Face (skipped when no token is configured)
HF_TOKEN = os.getenv('HUGGINGFACE_TOKEN')
if HF_TOKEN:
    login(token=HF_TOKEN)

# Disable torch compile
os.environ["TORCH_COMPILE_DISABLE"] = "1"

router = APIRouter()

DESCRIPTION = "ModernBERT fine-tuned for climate disinformation detection"
ROUTE = "/text"
MODEL_NAME = "Tonic/climate-guard-toxic-agent"

# Dataset string labels -> integer class ids used for scoring.
LABEL_MAPPING = {
    "0_not_relevant": 0,
    "1_not_happening": 1,
    "2_not_human": 2,
    "3_not_bad": 3,
    "4_solutions_harmful_unnecessary": 4,
    "5_science_unreliable": 5,
    "6_proponents_biased": 6,
    "7_fossil_fuels_needed": 7,
}


class TextClassifier:
    """Wrap the tokenizer, model and text-classification pipeline for MODEL_NAME."""

    def __init__(self):
        """Load tokenizer, model and pipeline, retrying up to three times.

        Raises:
            Exception: if every attempt fails; the last underlying error is
                attached as the cause.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        max_retries = 3

        for attempt in range(max_retries):
            try:
                # Initialize tokenizer
                self.tokenizer = AutoTokenizer.from_pretrained(
                    MODEL_NAME,
                    model_max_length=512,
                    padding_side='right',
                    truncation_side='right'
                )

                # Initialize model with specific configuration
                self.model = AutoModelForSequenceClassification.from_pretrained(
                    MODEL_NAME,
                    num_labels=8,
                    problem_type="single_label_classification"
                )

                # Move model to appropriate device
                self.model = self.model.to(self.device)

                # Initialize pipeline with the model and tokenizer
                self.classifier = pipeline(
                    "text-classification",
                    model=self.model,
                    tokenizer=self.tokenizer,
                    device=self.device,
                    max_length=512,
                    truncation=True,
                    batch_size=16
                )

                print("Model initialized successfully")
                break

            except Exception as e:
                if attempt == max_retries - 1:
                    # Keep the original error as the cause so the traceback
                    # shows why loading failed.
                    raise Exception(
                        f"Failed to initialize model after {max_retries} attempts: {str(e)}"
                    ) from e
                print(f"Attempt {attempt + 1} failed, retrying... Error: {str(e)}")
                time.sleep(1)

    def process_batch(self, batch: List[str], batch_idx: int) -> Tuple[List[int], int]:
        """Classify every text in ``batch`` and return ``(predictions, batch_idx)``.

        Each text is classified on its own; a text whose inference or label
        parsing fails is logged and given the default prediction ``0``, so the
        returned list always has one entry per input text.

        Args:
            batch: texts to classify.
            batch_idx: position of this batch, echoed back to the caller.

        Returns:
            The list of integer predictions and ``batch_idx``.
        """
        print(f"Processing batch {batch_idx} with {len(batch)} items")

        predictions = []
        for text in batch:
            try:
                result = self.classifier(text)
                # The class id is the integer before the first underscore;
                # assumes the model emits labels shaped like the keys of
                # LABEL_MAPPING (e.g. "3_not_bad") -- TODO confirm.
                pred_label = int(result[0]['label'].split('_')[0])
            except Exception as e:
                print(f"Error processing text in batch {batch_idx}: {str(e)}")
                pred_label = 0  # Default prediction
            predictions.append(pred_label)

        print(f"Completed batch {batch_idx} with {len(predictions)} predictions")
        return predictions, batch_idx

    def __del__(self):
        """Drop the model and pipeline references and empty the CUDA cache."""
        # Clean up CUDA memory
        if hasattr(self, 'model'):
            del self.model
        if hasattr(self, 'classifier'):
            del self.classifier
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


@router.post(ROUTE, tags=["Text Task"], description=DESCRIPTION)
async def evaluate_text(request: TextEvaluationRequest):
    """Evaluate text classification for climate disinformation detection.

    Loads the challenge dataset, classifies the quotes of its ``test`` split
    in batches on a thread pool while the emissions tracker is running, and
    returns a dict with accuracy, energy/emissions figures and request
    metadata. ``request.dataset_name``, ``request.test_size`` and
    ``request.test_seed`` are only echoed back under ``dataset_config``.

    Raises:
        Exception: wrapping any error raised while processing the request.
    """
    # Get space info
    username, space_url = get_space_info()

    try:
        # Load and prepare the dataset
        dataset = load_dataset("QuotaClimat/frugalaichallenge-text-train", token=HF_TOKEN)

        # Convert string labels to integers with error handling
        def convert_label(example):
            try:
                return {"label": LABEL_MAPPING[example["label"]]}
            except KeyError:
                print(f"Warning: Unknown label {example['label']}")
                # Unknown labels fall back to class 0 instead of failing.
                return {"label": 0}

        dataset = dataset.map(convert_label)

        # Split dataset
        test_dataset = dataset["test"]

        # Start tracking emissions
        tracker.start()
        tracker.start_task("inference")

        true_labels = test_dataset["label"]

        # Initialize the model once
        classifier = TextClassifier()

        # Prepare batches (the last one may be shorter than batch_size)
        batch_size = 24
        quotes = test_dataset["quote"]
        batches = [
            quotes[start:start + batch_size]
            for start in range(0, len(quotes), batch_size)
        ]

        # One result slot per batch so predictions stay in dataset order
        batch_results = [[] for _ in batches]

        # Process batches in parallel; os.cpu_count() may return None, in
        # which case a single worker is used.
        max_workers = min(os.cpu_count() or 1, 4)
        print(f"Processing with {max_workers} workers")

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_to_batch = {
                executor.submit(classifier.process_batch, batch, idx): idx
                for idx, batch in enumerate(batches)
            }

            for future, batch_idx in future_to_batch.items():
                try:
                    predictions, idx = future.result()
                    if predictions:
                        batch_results[idx] = predictions
                        print(f"Stored results for batch {idx} ({len(predictions)} predictions)")
                except Exception as e:
                    # A failed batch is scored as all-default predictions so
                    # the prediction count still matches the label count.
                    print(f"Failed to get results for batch {batch_idx}: {e}")
                    batch_results[batch_idx] = [0] * len(batches[batch_idx])

        # Flatten predictions
        predictions = [pred for batch_preds in batch_results for pred in batch_preds]

        # Stop tracking emissions
        emissions_data = tracker.stop_task()

        # Calculate accuracy
        accuracy = accuracy_score(true_labels, predictions)
        print("accuracy:", accuracy)

        # Prepare results
        results = {
            "username": username,
            "space_url": space_url,
            "submission_timestamp": datetime.now().isoformat(),
            "model_description": DESCRIPTION,
            "accuracy": float(accuracy),
            # Scaled by 1000; presumably kWh -> Wh and kg -> g, verify
            # against the tracker's units.
            "energy_consumed_wh": emissions_data.energy_consumed * 1000,
            "emissions_gco2eq": emissions_data.emissions * 1000,
            "emissions_data": clean_emissions_data(emissions_data),
            "api_route": ROUTE,
            "dataset_config": {
                "dataset_name": request.dataset_name,
                "test_size": request.test_size,
                "test_seed": request.test_seed
            }
        }

        print("results:", results)
        return results

    except Exception as e:
        print(f"Error in evaluate_text: {str(e)}")
        raise Exception(f"Failed to process request: {str(e)}") from e