# tasks/text.py
from fastapi import APIRouter, HTTPException
from datetime import datetime
from datasets import load_dataset
from sklearn.metrics import accuracy_score
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch.utils.data import Dataset, DataLoader
import logging

from .utils.evaluation import TextEvaluationRequest
from .utils.emissions import tracker, clean_emissions_data, get_space_info

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

router = APIRouter()

DESCRIPTION = "Climate Guard Toxic Agent Model"
ROUTE = "/text"

class TextDataset(Dataset):
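    """Minimal torch Dataset that tokenizes raw texts on access and pairs them with integer labels."""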
    def __init__(self, texts, labels, tokenizer, max_len=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = str(self.texts[idx])
        label = self.labels[idx]
        
        encoding = self.tokenizer(
            text,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors="pt"
        )
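        # return_tensors="pt" adds a leading batch dimension of size 1; squeeze(0)
        # below removes it so the DataLoader can stack samples into batches itself.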
        
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': torch.tensor(label, dtype=torch.long)
        }

@router.post(ROUTE, tags=["Text Task"], description=DESCRIPTION)
async def evaluate_text(request: TextEvaluationRequest):
    """
    Evaluate text classification for climate disinformation detection.
    """
    try:
        logger.info("Starting evaluation")
        username, space_url = get_space_info()

        # Label mapping
        LABEL_MAPPING = {
            "0_not_relevant": 0,
            "1_not_happening": 1,
            "2_not_human": 2,
            "3_not_bad": 3,
            "4_solutions_harmful_unnecessary": 4,
            "5_science_unreliable": 5,
            "6_proponents_biased": 6,
            "7_fossil_fuels_needed": 7
        }

        logger.info("Loading dataset")
        # Load dataset
        dataset = load_dataset(request.dataset_name)
        
        # Convert string labels to integers
        dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})
        
        # Get test dataset
        test_dataset = dataset["test"]
        
        logger.info("Starting emissions tracking")
        # Start tracking emissions
        tracker.start()
        
        try:
            # Load model and tokenizer
            logger.info("Loading model and tokenizer")
            model_name = "Tonic/climate-guard-toxic-agent"
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=len(LABEL_MAPPING))
            
            # Prepare dataset
            logger.info("Preparing dataset")
            test_data = TextDataset(
                texts=test_dataset["text"],
                labels=test_dataset["label"],
                tokenizer=tokenizer
            )
            
            test_loader = DataLoader(test_data, batch_size=16)
            
            # Model inference
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            logger.info(f"Using device: {device}")
            model = model.to(device)
            model.eval()
            
            predictions = []
            ground_truth = []
            
            logger.info("Running inference")
            with torch.no_grad():
                for batch in test_loader:
                    input_ids = batch['input_ids'].to(device)
                    attention_mask = batch['attention_mask'].to(device)
                    labels = batch['labels'].to(device)
                    
                    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
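                    # torch.max over dim 1 returns (values, indices); the indices are
                    # the argmax of the logits, i.e. the predicted class ids.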
                    _, predicted = torch.max(outputs.logits, 1)
                    
                    predictions.extend(predicted.cpu().numpy())
                    ground_truth.extend(labels.cpu().numpy())
            
            # Calculate accuracy
            accuracy = accuracy_score(ground_truth, predictions)
            logger.info(f"Accuracy: {accuracy}")
            
            # Stop tracking emissions
            emissions_data = tracker.stop()
            
            # Prepare results
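            # Assuming a codecarbon-style tracker, energy_consumed is reported in kWh
            # and emissions in kg CO2eq; the *1000 factors below convert to Wh and g.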
            results = {
                "username": username,
                "space_url": space_url,
                "submission_timestamp": datetime.now().isoformat(),
                "model_description": DESCRIPTION,
                "accuracy": float(accuracy),
                "energy_consumed_wh": float(emissions_data.energy_consumed * 1000),
                "emissions_gco2eq": float(emissions_data.emissions * 1000),
                "emissions_data": clean_emissions_data(emissions_data),
                "api_route": ROUTE,
                "dataset_config": {
                    "dataset_name": request.dataset_name,
                    "test_size": request.test_size,
                    "test_seed": request.test_seed
                }
            }
            
            logger.info("Evaluation completed successfully")
            return results
            
        except Exception as e:
            logger.error(f"Error during evaluation: {str(e)}")
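            # Stop the emissions tracker even on failure so it is not left running.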
            tracker.stop()
            raise
            
    except Exception as e:
        logger.error(f"Error in evaluate_text: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
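

# ---------------------------------------------------------------------------
# Usage sketch (an assumption, not part of this module): the router defined
# above is expected to be mounted by a FastAPI app elsewhere in the repo. A
# minimal app.py might look like the following; the module path "tasks.text"
# and the port are illustrative only.
#
#     from fastapi import FastAPI
#     from tasks.text import router as text_router
#
#     app = FastAPI()
#     app.include_router(text_router)
#
# Then serve it with, e.g.:  uvicorn app:app --host 0.0.0.0 --port 7860
# ---------------------------------------------------------------------------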