import os

# Disable torch compile. Set before torch/transformers are imported so the
# flag is already in the environment even if they read it at import time.
os.environ["TORCH_COMPILE_DISABLE"] = "1"

import threading
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import Dict, List, Tuple

import torch
from datasets import load_dataset
from fastapi import APIRouter
from sklearn.metrics import accuracy_score
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from .utils.evaluation import TextEvaluationRequest
from .utils.emissions import tracker, clean_emissions_data, get_space_info, start_tracking, stop_tracking

router = APIRouter()

DESCRIPTION = "Climate Guard Toxic Agent Classifier"
ROUTE = "/text"

MODEL_NAME = "Tonic/climate-guard-toxic-agent"

# Maps the dataset's string labels to integer class ids.
# NOTE(review): assumes these ids match the model's output indices — confirm
# against the model config's id2label.
LABEL_MAPPING = {
    "0_not_relevant": 0,
    "1_not_happening": 1,
    "2_not_human": 2,
    "3_not_bad": 3,
    "4_solutions_harmful_unnecessary": 4,
    "5_science_unreliable": 5,
    "6_proponents_biased": 6,
    "7_fossil_fuels_needed": 7,
}


class TextClassifier:
    """Wraps the climate-guard sequence-classification model and its tokenizer.

    The model and tokenizer are loaded once (with retries) and moved to CUDA
    when available. Predictions are the argmax class ids of the model logits.
    """

    def __init__(self):
        """Load model and tokenizer, retrying up to 3 times.

        Raises:
            RuntimeError: if loading still fails after the last attempt
                (chained to the underlying error).
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # process_batch is run from a thread pool; calls into the shared
        # (Rust-backed) fast tokenizer are serialized because concurrent use
        # of one tokenizer instance can fail with "Already borrowed" errors.
        self._tokenizer_lock = threading.Lock()
        max_retries = 3
        for attempt in range(max_retries):
            try:
                # Load model and tokenizer directly instead of using pipeline
                self.model = AutoModelForSequenceClassification.from_pretrained(
                    MODEL_NAME
                ).to(self.device)
                self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
                self.model.eval()  # Set to evaluation mode
                print("Model initialized successfully")
                break
            except Exception as e:
                if attempt == max_retries - 1:
                    raise RuntimeError(
                        f"Failed to initialize model after {max_retries} attempts: {str(e)}"
                    ) from e
                print(f"Attempt {attempt + 1} failed, retrying...")
                time.sleep(1)

    def _predict_texts(self, texts: List[str]) -> List[int]:
        """Tokenize ``texts`` as one padded batch and return argmax class ids.

        Exceptions from tokenization or the forward pass propagate to the caller.
        """
        with self._tokenizer_lock:
            inputs = self.tokenizer(
                texts,
                return_tensors="pt",
                truncation=True,
                max_length=512,
                padding=True,
            )
        inputs = inputs.to(self.device)
        with torch.no_grad():
            logits = self.model(**inputs).logits
        return logits.argmax(-1).tolist()

    def predict_single(self, text: str) -> int:
        """Predict the class id of a single text.

        Returns 0 ("0_not_relevant") on any error so a single bad input never
        aborts an evaluation run; the error is printed.
        """
        try:
            return self._predict_texts([text])[0]
        except Exception as e:
            print(f"Error in single prediction: {str(e)}")
            return 0  # Return default prediction on error

    def process_batch(self, batch: List[str], batch_idx: int) -> Tuple[List[int], int]:
        """Classify a batch of texts with one forward pass.

        The batched pass is retried up to 3 times. If it still fails, each
        text is predicted individually, so only texts that fail on their own
        receive the default label 0.

        Returns:
            ``(predictions, batch_idx)``, with one prediction per input text.
        """
        if not batch:
            return [], batch_idx
        max_retries = 3
        for attempt in range(max_retries):
            try:
                print(f"Processing batch {batch_idx} with {len(batch)} items (attempt {attempt + 1})")
                predictions = self._predict_texts(batch)
                print(f"Completed batch {batch_idx} with {len(predictions)} predictions")
                return predictions, batch_idx
            except Exception as e:
                if attempt == max_retries - 1:
                    print(f"Final error in batch {batch_idx}: {str(e)}")
                    break
                print(f"Error in batch {batch_idx} (attempt {attempt + 1}): {str(e)}")
                time.sleep(1)
        # Batched inference kept failing (e.g. one problematic text or memory
        # pressure from the padded batch): fall back to per-text prediction.
        return [self.predict_single(text) for text in batch], batch_idx


def _predict_dataset(
    classifier: TextClassifier, quotes: List[str], batch_size: int = 16
) -> List[int]:
    """Classify ``quotes`` in fixed-size batches on a thread pool.

    Returns predictions in the same order as ``quotes``. A batch whose future
    raises gets the default label 0 for each of its texts.
    """
    # Reduced batch size (16 by default) for better memory management
    batches = [quotes[i:i + batch_size] for i in range(0, len(quotes), batch_size)]
    batch_results: List[List[int]] = [[] for _ in batches]

    # os.cpu_count() may return None when the count is undeterminable.
    max_workers = min(os.cpu_count() or 1, 4)
    print(f"Processing with {max_workers} workers")
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_batch = {
            executor.submit(classifier.process_batch, batch, idx): idx
            for idx, batch in enumerate(batches)
        }
        for future, batch_idx in future_to_batch.items():
            try:
                predictions, idx = future.result()
                if predictions:
                    batch_results[idx] = predictions
                    print(f"Stored results for batch {idx} ({len(predictions)} predictions)")
            except Exception as e:
                print(f"Failed to get results for batch {batch_idx}: {e}")
                batch_results[batch_idx] = [0] * len(batches[batch_idx])

    # Flatten predictions, preserving batch order.
    return [pred for batch_preds in batch_results for pred in batch_preds]


@router.post(ROUTE, tags=["Text Task"], description=DESCRIPTION)
async def evaluate_text(request: TextEvaluationRequest):
    """Evaluate text classification for climate disinformation detection.

    Loads ``request.dataset_name`` and maps its string labels to class ids.
    While emissions are tracked, it loads the classifier and predicts the
    "test" split's "quote" column. Returns accuracy plus energy/emissions
    figures and submission metadata.

    NOTE(review): all work here is blocking and runs directly on the event
    loop. That also prevents concurrent requests from interleaving on the
    shared emissions tracker — keep that in mind before offloading it to a
    worker thread.
    """
    # Get space info
    username, space_url = get_space_info()

    # Load and prepare the dataset
    dataset = load_dataset(request.dataset_name)
    dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})
    test_dataset = dataset["test"]

    # Start tracking emissions
    start_tracking()
    try:
        # Materialize columns as plain lists so slicing/len() are predictable.
        true_labels = list(test_dataset["label"])
        # Model loading stays inside the tracked region, as before.
        classifier = TextClassifier()
        predictions = _predict_dataset(classifier, list(test_dataset["quote"]))
    finally:
        # Always stop tracking, even if loading or inference failed, so the
        # tracker is not left running.
        emissions_data = stop_tracking()

    # Calculate accuracy
    accuracy = accuracy_score(true_labels, predictions)
    print("accuracy:", accuracy)

    # Prepare results
    results = {
        "username": username,
        "space_url": space_url,
        "submission_timestamp": datetime.now().isoformat(),
        "model_description": DESCRIPTION,
        "accuracy": float(accuracy),
        "energy_consumed_wh": emissions_data.energy_consumed * 1000,
        "emissions_gco2eq": emissions_data.emissions * 1000,
        "emissions_data": clean_emissions_data(emissions_data),
        "api_route": ROUTE,
        "dataset_config": {
            "dataset_name": request.dataset_name,
            "test_size": request.test_size,
            "test_seed": request.test_seed,
        },
    }

    print("results:", results)
    return results