File size: 5,110 Bytes
4d6e8c2
 
 
f6107f3
860f09c
e2f75a8
4d6e8c2
 
 
 
 
860f09c
4d6e8c2
 
860f09c
 
70f5f26
860f09c
 
 
4d6e8c2
860f09c
4d6e8c2
 
 
860f09c
4d6e8c2
 
 
 
 
 
 
 
 
 
 
860f09c
4d6e8c2
 
 
860f09c
4d6e8c2
860f09c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5d2f9b2
860f09c
 
 
 
 
 
 
 
 
 
f6107f3
 
 
 
 
 
e2f75a8
 
f6107f3
5d2f9b2
860f09c
4d6e8c2
 
f6107f3
4d6e8c2
5d2f9b2
860f09c
f6107f3
5d2f9b2
860f09c
4d6e8c2
 
 
860f09c
4d6e8c2
 
 
 
860f09c
4d6e8c2
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
from fastapi import APIRouter
from datetime import datetime
from datasets import load_dataset
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score

from .utils.evaluation import TextEvaluationRequest
from .utils.emissions import tracker, clean_emissions_data, get_space_info

# Model descriptions, reused both as the OpenAPI description of each route
# and as the "model_description" field returned in its response.
DESCRIPTION_NAIVE_BAYES = "Naive Bayes Text Classifier"
DESCRIPTION_SVM = "SVM Text Classifier with TF-IDF"

# Router that collects the text-classification evaluation endpoints below.
router = APIRouter()

# Naive Bayes Endpoint
@router.post("/text", tags=["Text Task"], description=DESCRIPTION_NAIVE_BAYES)
async def evaluate_text(request: TextEvaluationRequest):
    """
    Evaluate text classification using Naive Bayes.

    Loads ``request.dataset_name``, maps its string labels to integer ids,
    splits the dataset's ``train`` split into train/test parts using
    ``request.test_size`` and ``request.test_seed``, fits a TF-IDF
    (max 5000 features) + MultinomialNB model on the train part, and measures
    emissions only while predicting the test part.

    Args:
        request: Evaluation parameters (dataset name, test split size, seed).

    Returns:
        A dict with submission metadata (username, space URL, timestamp,
        model description, route), the test accuracy, the inference-step
        energy and emissions (tracker values scaled by 1000), the cleaned
        emissions data, and the dataset configuration used.

    Raises:
        KeyError: if the dataset contains a label not present in
            ``LABEL_MAPPING``.
    """
    # NOTE(review): load_dataset and model fitting below are synchronous and
    # run directly inside this async handler, blocking the event loop for the
    # duration of the request; consider offloading if concurrency matters.
    username, space_url = get_space_info()

    # Map the dataset's string class names to integer class ids.
    LABEL_MAPPING = {
        "0_not_relevant": 0,
        "1_not_happening": 1,
        "2_not_human": 2,
        "3_not_bad": 3,
        "4_solutions_harmful_unnecessary": 4,
        "5_science_unreliable": 5,
        "6_proponents_biased": 6,
        "7_fossil_fuels_needed": 7
    }

    # Load and prepare dataset
    dataset = load_dataset(request.dataset_name)
    dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})

    # Deterministic train/test split of the "train" split, driven by the request.
    train_test = dataset["train"].train_test_split(test_size=request.test_size, seed=request.test_seed)
    train_texts = [x["text"] for x in train_test["train"]]
    train_labels = [x["label"] for x in train_test["train"]]
    test_texts = [x["text"] for x in train_test["test"]]
    test_labels = [x["label"] for x in train_test["test"]]

    # TF-IDF vocabulary is fitted on the training texts only, then reused for test.
    vectorizer = TfidfVectorizer(max_features=5000)
    train_vectors = vectorizer.fit_transform(train_texts)
    test_vectors = vectorizer.transform(test_texts)

    # Train Naive Bayes Classifier (not included in the emissions measurement).
    model = MultinomialNB()
    model.fit(train_vectors, train_labels)

    # Measure only the inference step. stop_task() runs in ``finally`` so the
    # "inference" task is closed even if predict() raises; otherwise it would
    # presumably stay open on the module-level tracker shared by both endpoints.
    tracker.start()
    tracker.start_task("inference")
    try:
        predictions = model.predict(test_vectors)
    finally:
        emissions_data = tracker.stop_task()

    # Calculate Accuracy
    accuracy = accuracy_score(test_labels, predictions)

    return {
        "username": username,
        "space_url": space_url,
        "submission_timestamp": datetime.now().isoformat(),
        "model_description": DESCRIPTION_NAIVE_BAYES,
        "accuracy": float(accuracy),
        "energy_consumed_wh": emissions_data.energy_consumed * 1000,
        "emissions_gco2eq": emissions_data.emissions * 1000,
        "emissions_data": clean_emissions_data(emissions_data),
        "api_route": "/text",
        "dataset_config": {
            "dataset_name": request.dataset_name,
            "test_size": request.test_size,
            "test_seed": request.test_seed
        }
    }

# SVM Endpoint
@router.post("/text_svm", tags=["Text Task"], description=DESCRIPTION_SVM)
async def evaluate_text_svm(request: TextEvaluationRequest):
    """
    Evaluate text classification using SVM.

    Loads ``request.dataset_name``, maps its string labels to integer ids,
    splits the dataset's ``train`` split into train/test parts using
    ``request.test_size`` and ``request.test_seed``, fits a TF-IDF
    (max 5000 features) + linear-kernel SVC model on the train part, and
    measures emissions only while predicting the test part.

    Args:
        request: Evaluation parameters (dataset name, test split size, seed).

    Returns:
        A dict with submission metadata (username, space URL, timestamp,
        model description, route), the test accuracy, the inference-step
        energy and emissions (tracker values scaled by 1000), the cleaned
        emissions data, and the dataset configuration used.

    Raises:
        KeyError: if the dataset contains a label not present in
            ``LABEL_MAPPING``.
    """
    # NOTE(review): load_dataset and model fitting below are synchronous and
    # run directly inside this async handler, blocking the event loop for the
    # duration of the request; consider offloading if concurrency matters.
    username, space_url = get_space_info()

    # Map the dataset's string class names to integer class ids.
    LABEL_MAPPING = {
        "0_not_relevant": 0,
        "1_not_happening": 1,
        "2_not_human": 2,
        "3_not_bad": 3,
        "4_solutions_harmful_unnecessary": 4,
        "5_science_unreliable": 5,
        "6_proponents_biased": 6,
        "7_fossil_fuels_needed": 7
    }

    # Load and prepare dataset
    dataset = load_dataset(request.dataset_name)
    dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})

    # Deterministic train/test split of the "train" split, driven by the request.
    train_test = dataset["train"].train_test_split(test_size=request.test_size, seed=request.test_seed)
    train_texts = [x["text"] for x in train_test["train"]]
    train_labels = [x["label"] for x in train_test["train"]]
    test_texts = [x["text"] for x in train_test["test"]]
    test_labels = [x["label"] for x in train_test["test"]]

    # TF-IDF vocabulary is fitted on the training texts only, then reused for test.
    vectorizer = TfidfVectorizer(max_features=5000)
    train_vectors = vectorizer.fit_transform(train_texts)
    test_vectors = vectorizer.transform(test_texts)

    # Train SVM Classifier (not included in the emissions measurement).
    # probability=True is deliberately not set: only predict() is used, which
    # does not rely on the Platt-scaling model, and enabling it makes fit()
    # run an extra internal 5-fold cross-validation.
    model = SVC(kernel="linear")
    model.fit(train_vectors, train_labels)

    # Measure only the inference step. stop_task() runs in ``finally`` so the
    # "inference" task is closed even if predict() raises; otherwise it would
    # presumably stay open on the module-level tracker shared by both endpoints.
    tracker.start()
    tracker.start_task("inference")
    try:
        predictions = model.predict(test_vectors)
    finally:
        emissions_data = tracker.stop_task()

    # Calculate Accuracy
    accuracy = accuracy_score(test_labels, predictions)

    return {
        "username": username,
        "space_url": space_url,
        "submission_timestamp": datetime.now().isoformat(),
        "model_description": DESCRIPTION_SVM,
        "accuracy": float(accuracy),
        "energy_consumed_wh": emissions_data.energy_consumed * 1000,
        "emissions_gco2eq": emissions_data.emissions * 1000,
        "emissions_data": clean_emissions_data(emissions_data),
        "api_route": "/text_svm",
        "dataset_config": {
            "dataset_name": request.dataset_name,
            "test_size": request.test_size,
            "test_seed": request.test_seed
        }
    }