Spaces:
Sleeping
Sleeping
File size: 5,638 Bytes
e6d07cd 4d6e8c2 4477f42 4d6e8c2 e6d07cd ece5856 6af9c73 1f08781 21262c6 4d6e8c2 1a885c6 4d6e8c2 21262c6 1f08781 21262c6 4d6e8c2 1f08781 1c33274 de4e4d7 1f08781 70f5f26 e6d07cd 1f08781 7abed63 c3f000b 89f8be4 c3f000b 89f8be4 c3f000b 89f8be4 c3f000b 89f8be4 c3f000b 7abed63 c3f000b e6d07cd c3f000b 4357468 c3f000b 4357468 c3f000b 1f08781 4357468 1f08781 c3f000b e6d07cd 7eb6153 4357468 7eb6153 4477f42 4d6e8c2 4477f42 e6d07cd 85c5204 e6d07cd 6f0e9af 1f08781 6f0e9af 1f08781 6f0e9af f3f30d7 1f08781 6f0e9af c3f000b 6f0e9af 1f08781 c3f000b 1f08781 c3f000b 1f08781 c3f000b 1f08781 6f0e9af c3f000b 6f0e9af e6d07cd 6f0e9af 85c5204 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 |
from fastapi import APIRouter
from datetime import datetime
import time
from datasets import load_dataset
from sklearn.metrics import accuracy_score
import os
from concurrent.futures import ThreadPoolExecutor
from typing import List, Dict, Tuple
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding
from huggingface_hub import login
from dotenv import load_dotenv
from .utils.evaluation import TextEvaluationRequest
from .utils.emissions import tracker, clean_emissions_data, get_space_info
# Identifiers for this route and for the model/tokenizer checkpoints it serves.
DESCRIPTION = "Climate Guard Toxic Agent is a ModernBERT for Climate Disinformation Detection"
ROUTE = "/text"
MODEL_NAME = "Tonic/climate-guard-toxic-agent"
TOKENIZER_NAME = "answerdotai/ModernBERT-base"

# Populate os.environ from a .env file, then authenticate against the
# Hugging Face Hub only when an HF_TOKEN value is actually available.
load_dotenv()
HF_TOKEN = os.getenv('HF_TOKEN')
if HF_TOKEN:
    login(token=HF_TOKEN)

# Router onto which this module's endpoint is registered via @router.post.
router = APIRouter()
class TextClassifier:
    """Batched inference wrapper around a sequence-classification checkpoint.

    The tokenizer is loaded from ``TOKENIZER_NAME`` and the model from
    ``MODEL_NAME`` (module-level constants). The model runs on CUDA when it is
    available, otherwise on CPU, and is always put in eval mode.
    """

    def __init__(self):
        """Load the tokenizer and the 8-label classification model.

        Raises:
            Exception: whatever the Hugging Face loaders raise is printed and
                re-raised unchanged.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
            # Use the Auto class so the architecture (ModernBERT) is resolved
            # from the checkpoint's config; a BERT-specific class does not
            # match this checkpoint and is not imported in this module.
            self.model = AutoModelForSequenceClassification.from_pretrained(
                MODEL_NAME,
                num_labels=8,
                # NOTE(review): a classifier head whose shape differs from 8
                # labels is re-initialised randomly instead of failing —
                # confirm the checkpoint already ships an 8-way head.
                ignore_mismatched_sizes=True,
            ).to(self.device)
            # Half precision only on GPU: on CPU many fp16 kernels are missing
            # (older PyTorch) or much slower than fp32.
            if self.device.type == "cuda":
                self.model = self.model.half()
            self.model.eval()
            print("Model initialized successfully")
        except Exception as e:
            print(f"Error initializing model: {str(e)}")
            raise

    def process_batch(self, batch):
        """Return the argmax class id for every row of one collated batch.

        Args:
            batch: mapping with ``input_ids`` and ``attention_mask`` tensors
                (presumably shaped (batch, seq) as produced by the collator —
                confirm against caller).

        Returns:
            list[int]: one predicted class id per row. If anything fails, the
            error is printed and a list of zeros (one per row) is returned, so
            a failed batch is scored as class 0 rather than aborting the run.
        """
        try:
            input_ids = batch['input_ids'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)
            with torch.no_grad():
                outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
                predictions = torch.argmax(outputs.logits, dim=-1)
            return predictions.cpu().numpy().tolist()
        except Exception as e:
            print(f"Error in batch processing: {str(e)}")
            return [0] * len(batch['input_ids'])

    def __del__(self):
        """Drop the model reference and release cached CUDA memory, if any."""
        if hasattr(self, 'model'):
            del self.model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
@router.post(ROUTE, tags=["Text Task"], description=DESCRIPTION)
async def evaluate_text(request: TextEvaluationRequest):
    """Evaluate text classification for climate disinformation detection.

    Loads ``request.dataset_name`` and maps its string labels to integer ids.
    It then runs the classifier over the ``test`` split while the emissions
    tracker records an ``"inference"`` task, and reports accuracy alongside
    energy/emissions figures.

    Args:
        request: evaluation request; ``dataset_name``, ``test_size`` and
            ``test_seed`` are read from it.

    Returns:
        dict: submission metadata, ``accuracy``, energy (``energy_consumed``
        x 1000, reported as Wh), emissions (``emissions`` x 1000, reported as
        gCO2eq), the cleaned emissions record, and the dataset config.

    Raises:
        Exception: any failure is printed and re-raised as a generic
            ``Exception`` chained to the original error.
    """
    username, space_url = get_space_info()
    LABEL_MAPPING = {
        "0_not_relevant": 0,
        "1_not_happening": 1,
        "2_not_human": 2,
        "3_not_bad": 3,
        "4_solutions_harmful_unnecessary": 4,
        "5_science_unreliable": 5,
        "6_proponents_biased": 6,
        "7_fossil_fuels_needed": 7,
    }
    try:
        dataset = load_dataset(request.dataset_name)
        # Replace the string label in every split with its integer id.
        dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})
        # NOTE(review): this evaluates the dataset's own "test" split;
        # request.test_size / request.test_seed are only echoed in the
        # response below, never used to build the split — confirm intent.
        test_dataset = dataset["test"]

        # NOTE(review): if an error occurs after start_task(), the task is not
        # stopped on the failure path — confirm the tracker tolerates that.
        tracker.start()
        tracker.start_task("inference")

        classifier = TextClassifier()

        def preprocess_function(examples):
            # No padding at map time: DataCollatorWithPadding pads each
            # DataLoader batch to its own longest sequence. Padding here would
            # pad every row to the longest in its map chunk (datasets' default
            # batch_size is 1000) and waste inference compute.
            return classifier.tokenizer(
                examples["quote"],
                truncation=True,
                padding=False,
                max_length=512,
            )

        tokenized_test = test_dataset.map(preprocess_function, batched=True)
        tokenized_test.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])

        data_collator = DataCollatorWithPadding(tokenizer=classifier.tokenizer)
        test_loader = DataLoader(
            tokenized_test,
            batch_size=16,
            collate_fn=data_collator,
        )

        all_predictions = []
        for batch in test_loader:
            all_predictions.extend(classifier.process_batch(batch))

        emissions_data = tracker.stop_task()

        accuracy = accuracy_score(test_dataset["label"], all_predictions)

        results = {
            "username": username,
            "space_url": space_url,
            "submission_timestamp": datetime.now().isoformat(),
            "model_description": DESCRIPTION,
            "accuracy": float(accuracy),
            "energy_consumed_wh": emissions_data.energy_consumed * 1000,
            "emissions_gco2eq": emissions_data.emissions * 1000,
            "emissions_data": clean_emissions_data(emissions_data),
            "api_route": ROUTE,
            "dataset_config": {
                "dataset_name": request.dataset_name,
                "test_size": request.test_size,
                "test_seed": request.test_seed,
            },
        }
        return results
    except Exception as e:
        print(f"Error in evaluate_text: {str(e)}")
        # Chain the cause so the original traceback is preserved.
        raise Exception(f"Failed to process request: {str(e)}") from e