TheoLvs's picture
Update tasks/text.py
8ba1dd9 verified
from fastapi import APIRouter, Query
from datetime import datetime
from datasets import load_dataset
from sklearn.metrics import accuracy_score
import numpy as np
import random
import os
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer
from .utils.evaluation import TextEvaluationRequest
from .utils.emissions import tracker, clean_emissions_data, get_space_info
# Router object on which this module's endpoint is registered
# (see the `@router.post(ROUTE, ...)` decorator on `evaluate_text`).
router = APIRouter()
# Human-readable description constant; it is not referenced anywhere else in
# the visible part of this module.
DESCRIPTION = "Random Baseline"
# URL path of the text-evaluation endpoint; also echoed back as "api_route"
# in the evaluation results.
ROUTE = "/text"
# Lookup table from the `model` name accepted by the evaluation endpoint to
# the description string reported under "model_description" in its results.
models_descriptions = {
    # Random-prediction baseline.
    "baseline": "random baseline",
    # Transformer checkpoints handled by `bert_classifier`.
    "distilbert_frugalai": "distilbert frugal ai",
    "deberta_frugalai": "deberta frugal ai",
    "modernbert_frugalai": "modernbert frugal ai",
    "distilroberta_frugalai": "distilroberta frugal ai",
}
def baseline_model(dataset_length: int):
    """Return random class predictions, one per sample.

    This is a placeholder for real model inference: every prediction is drawn
    independently with ``random.randint(0, 7)``, i.e. from the integer labels
    0 through 7 inclusive.

    Args:
        dataset_length: Number of predictions to generate.

    Returns:
        A list of ``dataset_length`` integers in the range 0..7.
    """
    random_labels = []
    for _ in range(dataset_length):
        random_labels.append(random.randint(0, 7))
    return random_labels
class TextDataset(Dataset):
    """Dataset that tokenizes every text once, at construction time.

    The tokenizer is called a single time on the full ``texts`` collection,
    and individual samples are then served by indexing into its output.
    """

    def __init__(self, texts, tokenizer, max_length=512):
        """Store ``texts`` and the tokenizer output for all of them.

        Args:
            texts: Collection of input texts; its length defines the dataset size.
            tokenizer: Callable invoked as ``tokenizer(texts, **options)`` that
                returns a mapping exposing ``.items()``.
            max_length: Forwarded to the tokenizer as ``max_length``.
        """
        self.texts = texts
        # Options forwarded verbatim to the tokenizer: truncation and padding
        # enabled, with "pt" requested as the tensor format.
        encode_options = {
            "truncation": True,
            "padding": True,
            "max_length": max_length,
            "return_tensors": "pt",
        }
        self.tokenized_texts = tokenizer(texts, **encode_options)

    def __len__(self) -> int:
        """Return the number of texts held by the dataset."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return a dict holding the ``idx``-th entry of every tokenizer field."""
        sample = {}
        for field, values in self.tokenized_texts.items():
            sample[field] = values[idx]
        return sample
# Model names for which a checkpoint repository "evgeniiarazum/<name>" is loaded.
_SUPPORTED_BERT_MODELS = frozenset({
    "distilbert_frugalai",
    "deberta_frugalai",
    "modernbert_frugalai",
    "distilroberta_frugalai",
})


def bert_classifier(test_dataset: dict, model: str):
    """Classify every quote in ``test_dataset`` with a transformer checkpoint.

    Args:
        test_dataset: Mapping-like dataset whose ``"quote"`` entry holds the
            input texts.
        model: Name of the checkpoint to use; must be one of
            ``_SUPPORTED_BERT_MODELS``. The weights and tokenizer are loaded
            from the repository ``evgeniiarazum/<model>``.

    Returns:
        A 1-D NumPy integer array with one predicted class index per input
        text, in input order (the argmax over the model's logits).

    Raises:
        ValueError: If ``model`` is not a supported checkpoint name.
    """
    print("Starting BERT model run")

    # Validate before touching the hub so an unknown name fails fast instead
    # of first attempting to download a tokenizer for a non-existent repo.
    if model not in _SUPPORTED_BERT_MODELS:
        raise ValueError(
            f"Unsupported model {model!r}; "
            f"expected one of {sorted(_SUPPORTED_BERT_MODELS)}"
        )

    # Materialize as a plain list: the column may be a lazy sequence type
    # rather than a list (depends on the dataset library version -- TODO
    # confirm), and list() is harmless when it already is one.
    texts = list(test_dataset["quote"])

    model_repo = f"evgeniiarazum/{model}"
    tokenizer = AutoTokenizer.from_pretrained(model_repo)
    classifier = AutoModelForSequenceClassification.from_pretrained(model_repo)

    # Use CUDA if available, otherwise fall back to CPU so the endpoint still
    # works on hosts without a GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    classifier = classifier.to(device)
    classifier.eval()

    # Prepare dataset; shuffle stays off so predictions line up with labels.
    dataset = TextDataset(texts, tokenizer=tokenizer)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=False)

    # Collect per-batch results and concatenate once at the end; this avoids
    # the repeated full-array copies of np.append and keeps the integer dtype.
    batch_predictions = []
    with torch.no_grad():
        for batch in dataloader:
            outputs = classifier(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
            )
            predicted = torch.argmax(outputs.logits, dim=1)
            batch_predictions.append(predicted.cpu().numpy())

    if batch_predictions:
        predictions = np.concatenate(batch_predictions)
    else:
        # No batches were produced (empty input): return an empty label array.
        predictions = np.array([], dtype=np.int64)

    print("Finished BERT model run")
    return predictions
@router.post(ROUTE, tags=["Text Task"])
async def evaluate_text(request: TextEvaluationRequest,
                        model: str = "distilbert_frugalai"):
    """
    Evaluate text classification for climate disinformation detection.

    Loads the dataset named in the request, runs the selected model over its
    "test" split while emissions tracking is active, and reports accuracy
    together with the tracked energy/emissions figures.

    Args:
        request: Evaluation request. ``dataset_name`` selects the dataset;
            ``test_size`` and ``test_seed`` are echoed back under
            ``dataset_config`` (this handler does not use them to build a
            split -- it reads the dataset's existing "test" split).
        model: A key of ``models_descriptions``. ``"baseline"`` makes random
            predictions from the label space (0-7); every other key is
            forwarded to ``bert_classifier``.

    Returns:
        A dict with the keys ``username``, ``space_url``,
        ``submission_timestamp``, ``model_description``, ``accuracy``,
        ``energy_consumed_wh``, ``emissions_gco2eq``, ``emissions_data``,
        ``api_route`` and ``dataset_config``.

    Raises:
        HTTPException: With status 400 when ``model`` is not a key of
            ``models_descriptions``.
    """
    # Local import: only this handler needs HTTPException.
    from fastapi import HTTPException

    # Reject unknown model names up front, before any dataset download or
    # emissions tracking starts. Otherwise the request would fail later with
    # an unbound `predictions` variable or a KeyError on the description.
    if model not in models_descriptions:
        raise HTTPException(
            status_code=400,
            detail=(
                f"Unknown model {model!r}; "
                f"expected one of {sorted(models_descriptions)}"
            ),
        )

    # Get space info
    username, space_url = get_space_info()

    # Define the label mapping
    LABEL_MAPPING = {
        "0_not_relevant": 0,
        "1_not_happening": 1,
        "2_not_human": 2,
        "3_not_bad": 3,
        "4_solutions_harmful_unnecessary": 4,
        "5_science_unreliable": 5,
        "6_proponents_biased": 6,
        "7_fossil_fuels_needed": 7
    }

    # Load the dataset
    dataset = load_dataset(request.dataset_name, token=os.getenv("HF_TOKEN"))

    # Only the "test" split is evaluated, so convert string labels to
    # integers on that split alone instead of mapping every split.
    test_dataset = dataset["test"].map(
        lambda x: {"label": LABEL_MAPPING[x["label"]]}
    )

    # Start tracking emissions
    tracker.start()
    tracker.start_task("inference")

    #--------------------------------------------------------------------------------------------
    # MODEL INFERENCE (everything in this section is covered by emissions tracking)
    #--------------------------------------------------------------------------------------------
    try:
        true_labels = test_dataset["label"]

        if model == "baseline":
            predictions = baseline_model(len(true_labels))
        else:
            # Every non-baseline key of `models_descriptions` is a
            # transformer checkpoint handled by `bert_classifier`.
            predictions = bert_classifier(test_dataset, model)
    finally:
        # Stop tracking emissions even if inference raised, so the
        # "inference" task is not left running for later requests.
        emissions_data = tracker.stop_task()
    #--------------------------------------------------------------------------------------------
    # MODEL INFERENCE STOPS HERE
    #--------------------------------------------------------------------------------------------

    # Calculate accuracy
    accuracy = accuracy_score(true_labels, predictions)

    # Prepare results dictionary
    results = {
        "username": username,
        "space_url": space_url,
        "submission_timestamp": datetime.now().isoformat(),
        "model_description": models_descriptions[model],
        "accuracy": float(accuracy),
        # Tracker figures are scaled by 1000 to match the unit named in each key.
        "energy_consumed_wh": emissions_data.energy_consumed * 1000,
        "emissions_gco2eq": emissions_data.emissions * 1000,
        "emissions_data": clean_emissions_data(emissions_data),
        "api_route": ROUTE,
        "dataset_config": {
            "dataset_name": request.dataset_name,
            "test_size": request.test_size,
            "test_seed": request.test_seed
        }
    }

    return results