File size: 3,352 Bytes
4d6e8c2
 
 
8b796b7
b552ad5
 
4d6e8c2
b552ad5
4d6e8c2
 
 
 
 
 
b552ad5
39aa6d2
70f5f26
8b796b7
860f09c
4d6e8c2
b552ad5
4d6e8c2
39aa6d2
4d6e8c2
 
39aa6d2
4d6e8c2
 
 
 
 
 
 
 
 
 
 
39aa6d2
4d6e8c2
860f09c
39aa6d2
860f09c
 
8b796b7
 
 
 
 
 
 
39aa6d2
4d6e8c2
 
5d2f9b2
b552ad5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8b796b7
b552ad5
 
39aa6d2
8b796b7
b552ad5
39aa6d2
 
 
8b796b7
39aa6d2
8b796b7
 
39aa6d2
 
4d6e8c2
 
 
39aa6d2
4d6e8c2
 
 
 
39aa6d2
4d6e8c2
 
8b796b7
b552ad5
 
4d6e8c2
8b796b7
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from fastapi import APIRouter
from datetime import datetime
from datasets import load_dataset
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import accuracy_score
from sklearn.pipeline import Pipeline

from .utils.evaluation import TextEvaluationRequest
from .utils.emissions import tracker, clean_emissions_data, get_space_info

# Route metadata: DESCRIPTION is passed as the endpoint description and echoed
# back in each result's "model_description"; ROUTE is the endpoint path.
DESCRIPTION = "TF-IDF + Logistic Regression"
ROUTE = "/text"

# Router holding this module's text-task endpoint.
router = APIRouter()

# Maps the dataset's string labels to integer class ids.
_LABEL_MAPPING = {
    "0_not_relevant": 0,
    "1_not_happening": 1,
    "2_not_human": 2,
    "3_not_bad": 3,
    "4_solutions_harmful_unnecessary": 4,
    "5_science_unreliable": 5,
    "6_proponents_biased": 6,
    "7_fossil_fuels_needed": 7,
}

# Hyperparameters searched by GridSearchCV (3 x 2 x 3 = 18 candidates).
_PARAM_GRID = {
    "tfidf__max_features": [5000, 10000, 15000],
    "tfidf__ngram_range": [(1, 1), (1, 2)],
    "clf__C": [0.1, 1, 10],  # inverse regularization strength
}


def _fit_best_model(train_texts, train_labels):
    """Tune a TF-IDF + Logistic Regression pipeline with 3-fold grid search.

    Args:
        train_texts: Training documents.
        train_labels: Integer class ids aligned with ``train_texts``.

    Returns:
        The fitted ``GridSearchCV``; its ``best_estimator_`` is refit on the
        full training data and ``best_params_`` holds the winning settings.
    """
    pipeline = Pipeline([
        # max_features and ngram_range are always set from _PARAM_GRID, so
        # they are not given defaults here.
        ("tfidf", TfidfVectorizer(stop_words="english")),
        ("clf", LogisticRegression(max_iter=1000, random_state=42)),
    ])
    grid_search = GridSearchCV(pipeline, _PARAM_GRID, cv=3, scoring="accuracy", verbose=2)
    grid_search.fit(train_texts, train_labels)
    return grid_search


@router.post(ROUTE, tags=["Text Task"], description=DESCRIPTION)
async def evaluate_text(request: TextEvaluationRequest):
    """
    Evaluate text classification for climate disinformation detection using TF-IDF and Logistic Regression.

    Loads ``request.dataset_name``, tunes the pipeline on its "train" split,
    predicts on its "test" split, and returns accuracy together with the
    energy and emissions measured for the inference step only.

    NOTE(review): training and prediction run synchronously inside this
    coroutine, so the event loop is blocked for the whole request.
    """
    username, space_url = get_space_info()

    # Load the dataset and convert string labels to integer class ids.
    dataset = load_dataset(request.dataset_name)
    dataset = dataset.map(lambda x: {"label": _LABEL_MAPPING[x["label"]]})

    # Use the dataset's predefined splits.
    train_data = dataset["train"]
    test_data = dataset["test"]
    test_texts, test_labels = test_data["text"], test_data["label"]

    # Train before the "inference" task starts, so the grid-search fits are
    # not reported as inference energy/emissions.
    grid_search = _fit_best_model(train_data["text"], train_data["label"])
    best_model = grid_search.best_estimator_

    # Track emissions for model inference only.
    tracker.start()
    tracker.start_task("inference")
    try:
        predictions = best_model.predict(test_texts)
    finally:
        # Always close the task, even if prediction fails, so it is not left
        # open for the next request.
        emissions_data = tracker.stop_task()

    accuracy = accuracy_score(test_labels, predictions)

    return {
        "username": username,
        "space_url": space_url,
        "submission_timestamp": datetime.now().isoformat(),
        "model_description": DESCRIPTION,
        "accuracy": float(accuracy),
        # Scaled by 1000 for the Wh / gCO2eq units named in these keys.
        "energy_consumed_wh": emissions_data.energy_consumed * 1000,
        "emissions_gco2eq": emissions_data.emissions * 1000,
        "emissions_data": clean_emissions_data(emissions_data),
        "api_route": ROUTE,
        "dataset_config": {
            "dataset_name": request.dataset_name,
            "test_size": len(test_data),
        },
        "best_params": grid_search.best_params_,
    }