File size: 5,893 Bytes
4d6e8c2
 
 
 
90194b0
73febd1
4d6e8c2
a1a5fb1
 
 
4d6e8c2
 
73febd1
1c33274
70f5f26
48ed843
f193c7f
 
5ba80f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48ed843
 
 
 
 
 
 
 
 
 
73febd1
48ed843
 
73febd1
5ba80f7
 
a1a5fb1
4d6e8c2
48ed843
4d6e8c2
a1a5fb1
 
 
48ed843
 
 
 
 
5ba80f7
48ed843
 
 
5ba80f7
a1a5fb1
 
 
 
 
73febd1
a1a5fb1
aa18df0
941eb28
 
 
 
 
 
 
 
 
 
73febd1
 
 
 
 
941eb28
73febd1
0a7a34d
73febd1
 
48ed843
 
 
4d6e8c2
48ed843
a1a5fb1
 
 
4d6e8c2
48ed843
5ba80f7
4d6e8c2
 
a1a5fb1
 
4d6e8c2
70f5f26
4d6e8c2
a1a5fb1
 
 
1c33274
4d6e8c2
48ed843
 
 
5ba80f7
4d6e8c2
5ba80f7
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from fastapi import APIRouter
from datetime import datetime
from datasets import load_dataset
from sklearn.metrics import accuracy_score
from sentence_transformers import SentenceTransformer
import numpy as np

from .utils.emissions import clean_emissions_data, get_space_info, tracker
from .utils.evaluation import TextEvaluationRequest

# Router on which the text-task endpoint defined in this module is registered.
router = APIRouter()

# DESCRIPTION is passed to the route decorator and is also reported back as
# "model_description" in the evaluation results; ROUTE is the endpoint path
# and is reported back as "api_route".
DESCRIPTION = "Efficient embedding-based classification with similarity threshold"
ROUTE = "/text"

# Load the custom embedding model once at import time so that the class
# embeddings and every request share the same instance.
model_name = "pedro-thenewsroom/climate-misinfo-embed"
# Alternative checkpoint, currently disabled:
# model_name = "pedro-thenewsroom/climate-misinfo-embed-8bit"
embedding_model = SentenceTransformer(model_name)

# Dataset label strings listed in id order: the position of each name in this
# tuple is the integer class id it maps to.
_LABEL_NAMES = (
    "0_not_relevant",
    "1_not_happening",
    "2_not_human",
    "3_not_bad",
    "4_solutions_harmful_unnecessary",
    "5_science_unreliable",
    "6_proponents_biased",
    "7_fossil_fuels_needed",
)

# Maps each dataset label string to its integer class id.
LABEL_MAPPING = {name: index for index, name in enumerate(_LABEL_NAMES)}

# One prototype sentence per disinformation class. Incoming quotes are
# compared against the embeddings of these sentences; "0_not_relevant" has no
# prototype on purpose, since it is the fallback when nothing is similar enough.
# The literals are wrapped with implicit string concatenation only; the
# resulting runtime strings are single sentences without line breaks.
class_descriptions = {
    "1_not_happening": (
        "Despite the alarmists' claims of global warming, temperatures have "
        "remained steady or even dropped in many areas, proving that climate "
        "change is nothing more than natural variability."
    ),
    "2_not_human": (
        "The Earth's climate has always changed due to natural cycles and "
        "external factors, and the role of human activity or CO2 emissions in "
        "driving these changes is negligible or unsupported by evidence."
    ),
    "3_not_bad": (
        "Contrary to the alarmist rhetoric, rising CO2 levels and modest "
        "warming are fostering a greener planet, boosting crop yields, and "
        "enhancing global prosperity while posing no significant threat to "
        "human or environmental health."
    ),
    "4_solutions_harmful_unnecessary": (
        "Global climate policies are costly, ineffective, and fail to address "
        "the unchecked emissions from developing nations, rendering efforts by "
        "industrialized countries futile and economically damaging."
    ),
    "5_science_unreliable": (
        "The so-called consensus on climate change relies on flawed models, "
        "manipulated data, and a refusal to address legitimate scientific "
        "uncertainties, all to serve a predetermined political narrative."
    ),
    "6_proponents_biased": (
        "Climate change is nothing more than a fabricated agenda pushed by "
        "corrupt elites, politicians, and scientists to control the masses, "
        "gain wealth, and suppress freedom."
    ),
    "7_fossil_fuels_needed": (
        "Fossil fuels have powered centuries of progress, lifted billions out "
        "of poverty, and remain the backbone of global energy, while "
        "alternatives, though promising, cannot yet match their scale, "
        "reliability, or affordability."
    ),
}

# Embed every class description once at import time. The vectors are requested
# normalized, so a dot product against another unit-length embedding is their
# cosine similarity. class_labels[i] names the class behind class_embeddings[i].
class_labels = list(class_descriptions)
class_sentences = [class_descriptions[label] for label in class_labels]
class_embeddings = embedding_model.encode(
    class_sentences,
    batch_size=8,
    convert_to_numpy=True,
    normalize_embeddings=True,
)

def _classify_quotes(quotes: list, threshold: float = 0.8) -> list:
    """Assign a class id to each quote by its most similar class description.

    Every quote is embedded and compared with the precomputed class
    embeddings. The best-matching class is accepted only when its similarity
    is strictly greater than ``threshold``; otherwise the quote is labelled
    ``"0_not_relevant"``.

    Args:
        quotes: Texts to classify.
        threshold: Exclusive lower bound on the best similarity for a class
            to be accepted.

    Returns:
        One integer id from ``LABEL_MAPPING`` per quote, in input order.
    """
    # The class embeddings are normalized, so the quote embeddings must be
    # normalized too; otherwise the dot product below is not a cosine
    # similarity and the threshold is compared against an unbounded score.
    quote_embeddings = embedding_model.encode(
        quotes,
        convert_to_numpy=True,
        normalize_embeddings=True,
    )

    # Rows are quotes, columns are classes (same order as class_labels).
    similarity_matrix = np.dot(quote_embeddings, class_embeddings.T)
    best_indices = similarity_matrix.argmax(axis=1)
    best_similarities = similarity_matrix.max(axis=1)

    # Translate column positions into label ids, then apply the threshold in
    # one vectorized step.
    class_ids = np.array([LABEL_MAPPING[label] for label in class_labels])
    fallback_id = LABEL_MAPPING["0_not_relevant"]
    predicted = np.where(
        best_similarities > threshold,
        class_ids[best_indices],
        fallback_id,
    )
    return predicted.tolist()


@router.post(ROUTE, tags=["Text Task"], description=DESCRIPTION)
async def evaluate_text(request: TextEvaluationRequest):
    """
    Evaluate text classification for climate disinformation detection using cosine similarity.

    The dataset named by ``request.dataset_name`` is loaded, its string labels
    are converted with ``LABEL_MAPPING``, and its "train" split is divided
    using ``request.test_size`` and ``request.test_seed``. The "quote" column
    of the resulting test split is classified while the emissions tracker
    measures the "inference" task.

    Returns:
        A dict with the space info, a submission timestamp, the model
        description, the accuracy on the test split, the tracked energy and
        emissions figures, the API route and the dataset configuration used.
    """
    # Get space info
    username, space_url = get_space_info()

    # Load and prepare the dataset
    dataset = load_dataset(request.dataset_name)

    # Convert dataset labels to integers based on LABEL_MAPPING
    dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})

    # Split dataset into train and test
    train_test = dataset["train"].train_test_split(
        test_size=request.test_size,
        seed=request.test_seed,
    )
    test_dataset = train_test["test"]

    # Start tracking emissions
    tracker.start()
    tracker.start_task("inference")

    # --------------------------------------------------------------------------------------------
    # Cosine similarity-based classification with threshold
    # --------------------------------------------------------------------------------------------

    # Materialize the column as a plain list so the encoder receives ordinary
    # strings, then classify all quotes in one call.
    predictions = _classify_quotes(list(test_dataset["quote"]))

    # Get ground truth labels
    true_labels = test_dataset["label"]

    # --------------------------------------------------------------------------------------------
    # Stop tracking emissions
    emissions_data = tracker.stop_task()

    # Calculate accuracy
    accuracy = accuracy_score(true_labels, predictions)

    # Prepare results dictionary
    results = {
        "username": username,
        "space_url": space_url,
        "submission_timestamp": datetime.now().isoformat(),
        "model_description": DESCRIPTION,
        "accuracy": float(accuracy),
        # NOTE(review): the x1000 presumably converts kWh -> Wh and kg -> g,
        # as the key names suggest; confirm against the tracker's units.
        "energy_consumed_wh": emissions_data.energy_consumed * 1000,
        "emissions_gco2eq": emissions_data.emissions * 1000,
        "emissions_data": clean_emissions_data(emissions_data),
        "api_route": ROUTE,
        "dataset_config": {
            "dataset_name": request.dataset_name,
            "test_size": request.test_size,
            "test_seed": request.test_seed,
        },
    }

    return results