Update tasks/text.py
Browse files- tasks/text.py +34 -3
tasks/text.py
CHANGED
@@ -16,10 +16,16 @@ DESCRIPTION = "electra fine tune"
|
|
16 |
ROUTE = "/text"
|
17 |
|
18 |
@router.post(ROUTE, tags=["Text Task"],
|
19 |
-
|
20 |
-
|
|
|
|
|
|
|
|
|
|
|
21 |
- Used as a baseline for comparison
|
22 |
"""
|
|
|
23 |
# Download from Google Drive
|
24 |
import gdown
|
25 |
|
@@ -32,6 +38,32 @@ ROUTE = "/text"
|
|
32 |
# Get space info
|
33 |
username, space_url = get_space_info()
|
34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
#--------------------------------------------------------------------------------------------
|
36 |
|
37 |
# Make random predictions (placeholder for actual model inference)
|
@@ -44,7 +76,6 @@ ROUTE = "/text"
|
|
44 |
# Get true labels
|
45 |
true_labels = test_dataset["label"]
|
46 |
|
47 |
-
|
48 |
#--------------------------------------------------------------------------------------------
|
49 |
# YOUR MODEL INFERENCE STOPS HERE
|
50 |
#--------------------------------------------------------------------------------------------
|
|
|
16 |
ROUTE = "/text"
|
17 |
|
18 |
@router.post(ROUTE, tags=["Text Task"],
|
19 |
+
description=DESCRIPTION)
|
20 |
+
|
21 |
+
async def evaluate_text(request: TextEvaluationRequest):
|
22 |
+
"""
|
23 |
+
Evaluate text classification for climate disinformation detection.
|
24 |
+
|
25 |
+
Current Model: Electra
|
26 |
- Used as a baseline for comparison
|
27 |
"""
|
28 |
+
|
29 |
# Download from Google Drive
|
30 |
import gdown
|
31 |
|
|
|
38 |
# Get space info
|
39 |
username, space_url = get_space_info()
|
40 |
|
41 |
+
# Define the label mapping
|
42 |
+
LABEL_MAPPING = {
|
43 |
+
"0_not_relevant": 0,
|
44 |
+
"1_not_happening": 1,
|
45 |
+
"2_not_human": 2,
|
46 |
+
"3_not_bad": 3,
|
47 |
+
"4_solutions_harmful_unnecessary": 4,
|
48 |
+
"5_science_unreliable": 5,
|
49 |
+
"6_proponents_biased": 6,
|
50 |
+
"7_fossil_fuels_needed": 7
|
51 |
+
}
|
52 |
+
|
53 |
+
# Load and prepare the dataset
|
54 |
+
dataset = load_dataset(request.dataset_name)
|
55 |
+
|
56 |
+
# Convert string labels to integers
|
57 |
+
dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})
|
58 |
+
|
59 |
+
# Split dataset
|
60 |
+
train_test = dataset["train"]
|
61 |
+
test_dataset = dataset["test"]
|
62 |
+
|
63 |
+
# Start tracking emissions
|
64 |
+
tracker.start()
|
65 |
+
tracker.start_task("inference")
|
66 |
+
|
67 |
#--------------------------------------------------------------------------------------------
|
68 |
|
69 |
# Make random predictions (placeholder for actual model inference)
|
|
|
76 |
# Get true labels
|
77 |
true_labels = test_dataset["label"]
|
78 |
|
|
|
79 |
#--------------------------------------------------------------------------------------------
|
80 |
# YOUR MODEL INFERENCE STOPS HERE
|
81 |
#--------------------------------------------------------------------------------------------
|