File size: 5,638 Bytes
e6d07cd
4d6e8c2
4477f42
4d6e8c2
 
e6d07cd
 
 
ece5856
6af9c73
1f08781
 
21262c6
 
4d6e8c2
 
1a885c6
4d6e8c2
21262c6
 
 
 
1f08781
21262c6
 
 
4d6e8c2
 
1f08781
1c33274
de4e4d7
1f08781
70f5f26
e6d07cd
 
1f08781
7abed63
c3f000b
 
89f8be4
c3f000b
 
89f8be4
c3f000b
89f8be4
 
c3f000b
 
89f8be4
c3f000b
 
 
 
7abed63
c3f000b
 
 
e6d07cd
c3f000b
4357468
c3f000b
 
 
4357468
 
 
c3f000b
1f08781
 
 
4357468
 
1f08781
c3f000b
e6d07cd
7eb6153
4357468
 
7eb6153
 
 
4477f42
4d6e8c2
4477f42
e6d07cd
 
 
 
 
 
 
 
 
85c5204
e6d07cd
 
 
 
6f0e9af
1f08781
 
6f0e9af
1f08781
 
6f0e9af
 
 
 
 
f3f30d7
1f08781
6f0e9af
 
c3f000b
 
 
 
 
 
 
 
 
 
 
 
6f0e9af
1f08781
c3f000b
1f08781
c3f000b
 
1f08781
 
 
 
 
 
c3f000b
1f08781
6f0e9af
 
 
 
 
c3f000b
6f0e9af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e6d07cd
 
6f0e9af
 
 
 
85c5204
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
from fastapi import APIRouter
from datetime import datetime
import time
from datasets import load_dataset
from sklearn.metrics import accuracy_score
import os
from concurrent.futures import ThreadPoolExecutor
from typing import List, Dict, Tuple
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding
from huggingface_hub import login
from dotenv import load_dotenv

from .utils.evaluation import TextEvaluationRequest
from .utils.emissions import tracker, clean_emissions_data, get_space_info

# Read key/value pairs from a .env file into the process environment before
# any settings are looked up below.
load_dotenv()

# Hugging Face authentication is optional: only log in when a token is set
# (an unset or empty HF_TOKEN skips the login call entirely).
HF_TOKEN = os.environ.get('HF_TOKEN')
if HF_TOKEN:
    login(token=HF_TOKEN)

# Endpoint metadata and the model/tokenizer identifiers used by TextClassifier.
DESCRIPTION = "Climate Guard Toxic Agent is a ModernBERT for Climate Disinformation Detection"
ROUTE = "/text"
MODEL_NAME = "Tonic/climate-guard-toxic-agent"
TOKENIZER_NAME = "answerdotai/ModernBERT-base"

router = APIRouter()

class TextClassifier:
    """Tokenizer + sequence-classification model used by the /text route.

    The tokenizer is loaded from ``TOKENIZER_NAME`` and the classifier weights
    from ``MODEL_NAME`` (8 output labels). The model runs on CUDA when it is
    available, otherwise on CPU, and is always put in eval mode.
    """

    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        try:
            # Initialize tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)

            # AutoModelForSequenceClassification picks the architecture from the
            # checkpoint's config. The previous BertForSequenceClassification was
            # never imported (NameError) and is the wrong class for ModernBERT.
            model = AutoModelForSequenceClassification.from_pretrained(
                MODEL_NAME,
                num_labels=8,
                ignore_mismatched_sizes=True
            )
            self.model = self._prepare_model(model, self.device)

            print("Model initialized successfully")

        except Exception as e:
            print(f"Error initializing model: {str(e)}")
            raise

    @staticmethod
    def _prepare_model(model, device):
        """Move *model* to *device*, use fp16 only on CUDA, and set eval mode.

        Half precision is restricted to CUDA because many CPU kernels have no
        float16 implementation, so calling ``.half()`` on a CPU model can make
        inference fail.

        Returns:
            The prepared model (in eval mode).
        """
        model = model.to(device)
        if device.type == "cuda":
            model = model.half()
        model.eval()
        return model

    def process_batch(self, batch):
        """Predict label ids for one collated batch.

        Args:
            batch: mapping with ``input_ids`` and ``attention_mask`` tensors,
                as produced by the DataLoader/collator.

        Returns:
            list[int]: the argmax class id for each row. If inference raises,
            the error is printed and a list of ``0`` of the same length is
            returned instead. Note that this fallback silently counts as
            label 0 in any downstream accuracy.
        """
        try:
            # Move batch to device
            input_ids = batch['input_ids'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)

            # Get predictions without building an autograd graph
            with torch.no_grad():
                outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
                predictions = torch.argmax(outputs.logits, dim=-1)

            return predictions.cpu().numpy().tolist()

        except Exception as e:
            print(f"Error in batch processing: {str(e)}")
            return [0] * len(batch['input_ids'])

    def __del__(self):
        # Drop the model reference and release cached CUDA memory, if any.
        if hasattr(self, 'model'):
            del self.model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

@router.post(ROUTE, tags=["Text Task"], description=DESCRIPTION)
async def evaluate_text(request: TextEvaluationRequest):
    """Evaluate text classification for climate disinformation detection.

    Loads ``request.dataset_name`` and maps its string labels to integer ids.
    It then runs ``TextClassifier`` over the dataset's ``test`` split while the
    shared emissions ``tracker`` measures an "inference" task.

    Returns:
        dict: submission metadata, accuracy, energy (Wh) and emissions
        (gCO2eq) figures, cleaned emissions data, and the dataset config.

    Raises:
        Exception: wrapping (and chained to) any error raised while loading,
            predicting or scoring.
    """
    username, space_url = get_space_info()

    LABEL_MAPPING = {
        "0_not_relevant": 0,
        "1_not_happening": 1,
        "2_not_human": 2,
        "3_not_bad": 3,
        "4_solutions_harmful_unnecessary": 4,
        "5_science_unreliable": 5,
        "6_proponents_biased": 6,
        "7_fossil_fuels_needed": 7
    }

    # True while an "inference" task is open on the shared tracker. On failure
    # the task is closed so it is not left running for the next request.
    task_open = False
    try:
        # Load dataset
        dataset = load_dataset(request.dataset_name)

        # Convert string labels to integer ids
        dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})
        test_dataset = dataset["test"]

        # Start tracking emissions (model loading below is included in the measurement)
        tracker.start()
        tracker.start_task("inference")
        task_open = True

        # Initialize model
        classifier = TextClassifier()

        # Prepare tokenization function
        def preprocess_function(examples):
            # No padding here: DataCollatorWithPadding pads each DataLoader
            # batch to its own longest sequence. Padding inside map() would
            # pad to the longest of each (much larger) map batch and waste
            # inference compute on padding tokens.
            return classifier.tokenizer(
                examples["quote"],
                truncation=True,
                max_length=512
            )

        # Tokenize dataset
        tokenized_test = test_dataset.map(preprocess_function, batched=True)
        tokenized_test.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])

        # Create DataLoader with dynamic per-batch padding
        data_collator = DataCollatorWithPadding(tokenizer=classifier.tokenizer)
        test_loader = DataLoader(
            tokenized_test,
            batch_size=16,
            collate_fn=data_collator
        )

        # Get predictions
        all_predictions = []
        for batch in test_loader:
            all_predictions.extend(classifier.process_batch(batch))

        # Stop tracking emissions
        emissions_data = tracker.stop_task()
        task_open = False

        # Calculate accuracy
        accuracy = accuracy_score(test_dataset["label"], all_predictions)

        # Prepare results
        results = {
            "username": username,
            "space_url": space_url,
            "submission_timestamp": datetime.now().isoformat(),
            "model_description": DESCRIPTION,
            "accuracy": float(accuracy),
            "energy_consumed_wh": emissions_data.energy_consumed * 1000,
            "emissions_gco2eq": emissions_data.emissions * 1000,
            "emissions_data": clean_emissions_data(emissions_data),
            "api_route": ROUTE,
            "dataset_config": {
                "dataset_name": request.dataset_name,
                "test_size": request.test_size,
                "test_seed": request.test_seed
            }
        }

        return results

    except Exception as e:
        if task_open:
            # Best-effort cleanup of the shared tracker. The original error is
            # the one reported, so a failure here is logged, not raised.
            try:
                tracker.stop_task()
            except Exception as stop_err:
                print(f"Error stopping emissions task: {str(stop_err)}")
        print(f"Error in evaluate_text: {str(e)}")
        raise Exception(f"Failed to process request: {str(e)}") from e