File size: 3,858 Bytes
4d6e8c2
 
 
 
1e0fe77
 
 
4d6e8c2
 
 
 
 
 
1e0fe77
1c33274
70f5f26
1e0fe77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f40d4de
1e0fe77
 
 
 
4d6e8c2
 
1e0fe77
4d6e8c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70f5f26
1e0fe77
4d6e8c2
 
1e0fe77
 
 
 
 
 
 
 
 
 
 
 
4d6e8c2
 
 
 
 
 
 
 
70f5f26
4d6e8c2
 
 
 
1c33274
4d6e8c2
 
 
 
 
 
 
1e0fe77
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from fastapi import APIRouter
from datetime import datetime
from datasets import load_dataset
from sklearn.metrics import accuracy_score
import torch
from transformers import AutoTokenizer, RobertaForSequenceClassification
from torch.utils.data import Dataset, DataLoader

from .utils.evaluation import TextEvaluationRequest
from .utils.emissions import tracker, clean_emissions_data, get_space_info

# Metadata advertised for this task's evaluation endpoint.
DESCRIPTION = "RoBERTa Climate Disinformation Classifier"
ROUTE = "/text"

# Endpoints in this module register on this router (presumably mounted by the app).
router = APIRouter()

class FrugalDataClass(Dataset):
    """Map-style dataset that tokenizes one text/label pair on each access.

    Args:
        texts: Indexable sequence of inputs; each item is passed through ``str()``.
        labels: Indexable sequence parallel to ``texts``; each value becomes a
            ``torch.long`` tensor.
        tokenizer: Callable invoked as ``tokenizer(text, max_length=...,
            padding='max_length', truncation=True, return_tensors="pt")`` and
            expected to return a mapping with ``input_ids`` and ``attention_mask``.
        max_len: Forwarded to the tokenizer as ``max_length``.
    """

    def __init__(self, texts, labels, tokenizer, max_len=128):
        self.texts, self.labels = texts, labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        # The sample count is taken from ``texts``.
        return len(self.texts)

    def __getitem__(self, idx):
        """Return 1-D ``input_ids``/``attention_mask`` tensors and a long ``labels`` tensor."""
        raw_text, target = str(self.texts[idx]), self.labels[idx]

        encoded = self.tokenizer(
            raw_text,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors="pt",
        )

        # The tokenizer returns a leading batch dimension for the single text;
        # flatten it so the DataLoader can stack samples into a batch.
        sample = {name: encoded[name].flatten() for name in ('input_ids', 'attention_mask')}
        sample['labels'] = torch.tensor(target, dtype=torch.long)
        return sample

# Everything below runs once, at module import time.
# Use CUDA when it is available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = AutoTokenizer.from_pretrained("roberta-base")
# Build the RoBERTa architecture with an 8-way classification head (one output
# per class in the label mapping used by evaluate_text). The "roberta-base"
# weights fetched here are overwritten by the local checkpoint loaded next.
model = RobertaForSequenceClassification.from_pretrained(
    "roberta-base",
    num_labels=8
)
# The checkpoint holds a state_dict (it is fed to load_state_dict). The path is
# relative, so it resolves against the process's current working directory.
# map_location lets tensors saved on a GPU load onto the selected device.
# NOTE(review): torch.load unpickles the file; if the installed torch supports
# it, consider weights_only=True since only tensors are needed here.
model.load_state_dict(torch.load('tasks/best_roberta_model.pth', map_location=device))
model.to(device)
# Evaluation mode (disables dropout) for inference-only use.
model.eval()

# Maps the dataset's string labels to the classifier's output indices 0..7.
LABEL_MAPPING = {
    "0_not_relevant": 0,
    "1_not_happening": 1,
    "2_not_human": 2,
    "3_not_bad": 3,
    "4_solutions_harmful_unnecessary": 4,
    "5_science_unreliable": 5,
    "6_proponents_biased": 6,
    "7_fossil_fuels_needed": 7
}


def _predict(loader):
    """Run the module-level ``model`` over ``loader`` and return predicted class ids.

    Each batch must be a mapping with ``input_ids`` and ``attention_mask``
    tensors, as produced by batching ``FrugalDataClass`` samples.
    """
    predictions = []
    with torch.no_grad():
        for batch in loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            logits = model(input_ids, attention_mask=attention_mask).logits
            # Highest-scoring class per row, as plain Python ints.
            predictions.extend(torch.argmax(logits, dim=1).cpu().tolist())
    return predictions


@router.post(ROUTE, description=DESCRIPTION)
async def evaluate_text(request: TextEvaluationRequest):
    """
    Evaluate text classification for climate disinformation detection using RoBERTa.

    Loads ``request.dataset_name`` and maps its string labels to integers via
    ``LABEL_MAPPING``. It then splits the ``train`` split with
    ``request.test_size``/``request.test_seed`` and classifies the resulting
    test portion while the emissions tracker's "inference" task is active.

    Returns:
        dict with submission metadata, ``accuracy`` on the test portion, energy
        and emissions figures from the tracker, and the dataset configuration.
    """
    username, space_url = get_space_info()

    dataset = load_dataset(request.dataset_name)
    dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})

    train_test = dataset["train"].train_test_split(test_size=request.test_size, seed=request.test_seed)
    test_split = train_test["test"]

    tracker.start()
    tracker.start_task("inference")

    # NOTE: inference is synchronous, so it blocks the event loop for its duration.
    try:
        # Materialise the columns as plain lists so positional indexing in
        # FrugalDataClass and sklearn's accuracy_score see ordinary sequences.
        test_texts = list(test_split["quote"])
        true_labels = list(test_split["label"])

        test_loader = DataLoader(
            FrugalDataClass(test_texts, true_labels, tokenizer),
            batch_size=16,
            shuffle=False,
        )
        predictions = _predict(test_loader)
    except BaseException:
        # Close the "inference" task so a failed request does not leave it
        # running on the shared tracker, then propagate the original error.
        tracker.stop_task()
        raise
    emissions_data = tracker.stop_task()

    accuracy = accuracy_score(true_labels, predictions)

    results = {
        "username": username,
        "space_url": space_url,
        "submission_timestamp": datetime.now().isoformat(),
        "model_description": DESCRIPTION,
        "accuracy": float(accuracy),
        # Scaled by 1000 (presumably kWh -> Wh and kg -> g; confirm against the tracker).
        "energy_consumed_wh": emissions_data.energy_consumed * 1000,
        "emissions_gco2eq": emissions_data.emissions * 1000,
        "emissions_data": clean_emissions_data(emissions_data),
        "api_route": ROUTE,
        "dataset_config": {
            "dataset_name": request.dataset_name,
            "test_size": request.test_size,
            "test_seed": request.test_seed
        }
    }

    return results