Spaces:
Sleeping
Sleeping
revert inference code
Browse files- tasks/text.py +182 -132
tasks/text.py
CHANGED
@@ -1,151 +1,201 @@
|
|
1 |
-
|
2 |
-
from fastapi import APIRouter, HTTPException
|
3 |
from datetime import datetime
|
4 |
from datasets import load_dataset
|
5 |
from sklearn.metrics import accuracy_score
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
import torch
|
7 |
-
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
8 |
-
from torch.utils.data import Dataset, DataLoader
|
9 |
-
import logging
|
10 |
|
11 |
from .utils.evaluation import TextEvaluationRequest
|
12 |
-
from .utils.emissions import
|
13 |
|
14 |
-
#
|
15 |
-
|
16 |
-
logger = logging.getLogger(__name__)
|
17 |
|
18 |
router = APIRouter()
|
19 |
|
20 |
-
DESCRIPTION = "
|
21 |
ROUTE = "/text"
|
22 |
|
23 |
-
class
|
24 |
-
def __init__(self
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
|
51 |
-
@router.post(ROUTE, tags=["Text Task"],
|
|
|
52 |
async def evaluate_text(request: TextEvaluationRequest):
|
53 |
"""
|
54 |
Evaluate text classification for climate disinformation detection.
|
|
|
|
|
|
|
|
|
55 |
"""
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
}
|
71 |
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
results = {
|
125 |
-
"username": username,
|
126 |
-
"space_url": space_url,
|
127 |
-
"submission_timestamp": datetime.now().isoformat(),
|
128 |
-
"model_description": DESCRIPTION,
|
129 |
-
"accuracy": float(accuracy),
|
130 |
-
"energy_consumed_wh": float(emissions_data.energy_consumed * 1000),
|
131 |
-
"emissions_gco2eq": float(emissions_data.emissions * 1000),
|
132 |
-
"emissions_data": clean_emissions_data(emissions_data.__dict__),
|
133 |
-
"api_route": ROUTE,
|
134 |
-
"dataset_config": {
|
135 |
-
"dataset_name": request.dataset_name,
|
136 |
-
"test_size": request.test_size,
|
137 |
-
"test_seed": request.test_seed
|
138 |
-
}
|
139 |
-
}
|
140 |
-
|
141 |
-
logger.info("Evaluation completed successfully")
|
142 |
-
return results
|
143 |
-
|
144 |
-
except Exception as e:
|
145 |
-
logger.error(f"Error during evaluation: {str(e)}")
|
146 |
-
stop_tracking()
|
147 |
-
raise HTTPException(status_code=500, detail=str(e))
|
148 |
-
|
149 |
-
except Exception as e:
|
150 |
-
logger.error(f"Error in evaluate_text: {str(e)}")
|
151 |
-
raise HTTPException(status_code=500, detail=str(e))
|
|
|
1 |
+
from fastapi import APIRouter
|
|
|
2 |
from datetime import datetime
|
3 |
from datasets import load_dataset
|
4 |
from sklearn.metrics import accuracy_score
|
5 |
+
import random
|
6 |
+
from transformers import pipeline, AutoConfig
|
7 |
+
import os
|
8 |
+
from concurrent.futures import ThreadPoolExecutor
|
9 |
+
from typing import List, Dict, Tuple
|
10 |
+
import numpy as np
|
11 |
import torch
|
|
|
|
|
|
|
12 |
|
13 |
from .utils.evaluation import TextEvaluationRequest
|
14 |
+
from .utils.emissions import tracker, clean_emissions_data, get_space_info
|
15 |
|
16 |
+
# Disable torch compile
|
17 |
+
os.environ["TORCH_COMPILE_DISABLE"] = "1"
|
|
|
18 |
|
19 |
router = APIRouter()
|
20 |
|
21 |
+
DESCRIPTION = "Random Baseline"
|
22 |
ROUTE = "/text"
|
23 |
|
24 |
+
class TextClassifier:
|
25 |
+
def __init__(self):
|
26 |
+
# Add retry mechanism for model initialization
|
27 |
+
max_retries = 3
|
28 |
+
for attempt in range(max_retries):
|
29 |
+
try:
|
30 |
+
self.config = AutoConfig.from_pretrained("Tonic/climate-guard-toxic-agent")
|
31 |
+
self.label2id = self.config.label2id
|
32 |
+
self.classifier = pipeline(
|
33 |
+
"text-classification",
|
34 |
+
"Tonic/climate-guard-toxic-agent",
|
35 |
+
device="cpu",
|
36 |
+
batch_size=16
|
37 |
+
)
|
38 |
+
print("Model initialized successfully")
|
39 |
+
break
|
40 |
+
except Exception as e:
|
41 |
+
if attempt == max_retries - 1:
|
42 |
+
raise Exception(f"Failed to initialize model after {max_retries} attempts: {str(e)}")
|
43 |
+
print(f"Attempt {attempt + 1} failed, retrying...")
|
44 |
+
time.sleep(1)
|
45 |
+
|
46 |
+
def process_batch(self, batch: List[str], batch_idx: int) -> Tuple[List[int], int]:
|
47 |
+
"""Process a batch of texts and return their predictions"""
|
48 |
+
max_retries = 3
|
49 |
+
for attempt in range(max_retries):
|
50 |
+
try:
|
51 |
+
print(f"Processing batch {batch_idx} with {len(batch)} items (attempt {attempt + 1})")
|
52 |
+
# Process texts one by one in case of errors
|
53 |
+
predictions = []
|
54 |
+
for text in batch:
|
55 |
+
try:
|
56 |
+
pred = self.classifier(text)
|
57 |
+
pred_label = self.label2id[pred[0]["label"]]
|
58 |
+
predictions.append(pred_label)
|
59 |
+
except Exception as e:
|
60 |
+
print(f"Error processing text in batch {batch_idx}: {str(e)}")
|
61 |
+
|
62 |
+
if not predictions:
|
63 |
+
raise Exception("No predictions generated for batch")
|
64 |
+
|
65 |
+
print(f"Completed batch {batch_idx} with {len(predictions)} predictions")
|
66 |
+
return predictions, batch_idx
|
67 |
+
|
68 |
+
except Exception as e:
|
69 |
+
if attempt == max_retries - 1:
|
70 |
+
print(f"Final error in batch {batch_idx}: {str(e)}")
|
71 |
+
return [0] * len(batch), batch_idx # Return default predictions instead of empty list
|
72 |
+
print(f"Error in batch {batch_idx} (attempt {attempt + 1}): {str(e)}")
|
73 |
+
time.sleep(1)
|
74 |
+
|
75 |
|
76 |
+
@router.post(ROUTE, tags=["Text Task"],
|
77 |
+
description=DESCRIPTION)
|
78 |
async def evaluate_text(request: TextEvaluationRequest):
|
79 |
"""
|
80 |
Evaluate text classification for climate disinformation detection.
|
81 |
+
|
82 |
+
Current Model: Random Baseline
|
83 |
+
- Makes random predictions from the label space (0-7)
|
84 |
+
- Used as a baseline for comparison
|
85 |
"""
|
86 |
+
# Get space info
|
87 |
+
username, space_url = get_space_info()
|
88 |
+
|
89 |
+
# Define the label mapping
|
90 |
+
LABEL_MAPPING = {
|
91 |
+
"0_not_relevant": 0,
|
92 |
+
"1_not_happening": 1,
|
93 |
+
"2_not_human": 2,
|
94 |
+
"3_not_bad": 3,
|
95 |
+
"4_solutions_harmful_unnecessary": 4,
|
96 |
+
"5_science_unreliable": 5,
|
97 |
+
"6_proponents_biased": 6,
|
98 |
+
"7_fossil_fuels_needed": 7
|
99 |
+
}
|
100 |
+
|
101 |
+
# Load and prepare the dataset
|
102 |
+
dataset = load_dataset(request.dataset_name)
|
103 |
+
|
104 |
+
# Convert string labels to integers
|
105 |
+
dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})
|
106 |
+
|
107 |
+
# Split dataset
|
108 |
+
train_test = dataset["train"]
|
109 |
+
test_dataset = dataset["test"]
|
110 |
+
|
111 |
+
# Start tracking emissions
|
112 |
+
tracker.start()
|
113 |
+
tracker.start_task("inference")
|
114 |
+
|
115 |
+
#--------------------------------------------------------------------------------------------
|
116 |
+
# YOUR MODEL INFERENCE CODE HERE
|
117 |
+
# Update the code below to replace the random baseline by your model inference within the inference pass where the energy consumption and emissions are tracked.
|
118 |
+
#--------------------------------------------------------------------------------------------
|
119 |
+
|
120 |
+
true_labels = test_dataset["label"]
|
121 |
+
|
122 |
+
# Initialize the model once
|
123 |
+
classifier = TextClassifier()
|
124 |
+
|
125 |
+
# Prepare batches
|
126 |
+
batch_size = 32
|
127 |
+
quotes = test_dataset["quote"]
|
128 |
+
num_batches = len(quotes) // batch_size + (1 if len(quotes) % batch_size != 0 else 0)
|
129 |
+
batches = [
|
130 |
+
quotes[i * batch_size:(i + 1) * batch_size]
|
131 |
+
for i in range(num_batches)
|
132 |
+
]
|
133 |
+
|
134 |
+
# Initialize batch_results before parallel processing
|
135 |
+
batch_results = [[] for _ in range(num_batches)]
|
136 |
+
|
137 |
+
# Process batches in parallel
|
138 |
+
max_workers = min(os.cpu_count(), 4) # Limit to 4 workers or CPU count
|
139 |
+
print(f"Processing with {max_workers} workers")
|
140 |
+
|
141 |
+
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
142 |
+
# Submit all batches for processing
|
143 |
+
future_to_batch = {
|
144 |
+
executor.submit(
|
145 |
+
classifier.process_batch,
|
146 |
+
batch,
|
147 |
+
idx
|
148 |
+
): idx for idx, batch in enumerate(batches)
|
149 |
}
|
150 |
|
151 |
+
# Collect results in order
|
152 |
+
for future in future_to_batch:
|
153 |
+
batch_idx = future_to_batch[future]
|
154 |
+
try:
|
155 |
+
predictions, idx = future.result()
|
156 |
+
if predictions: # Only store non-empty predictions
|
157 |
+
batch_results[idx] = predictions
|
158 |
+
print(f"Stored results for batch {idx} ({len(predictions)} predictions)")
|
159 |
+
except Exception as e:
|
160 |
+
print(f"Failed to get results for batch {batch_idx}: {e}")
|
161 |
+
# Use default predictions instead of empty list
|
162 |
+
batch_results[batch_idx] = [0] * len(batches[batch_idx])
|
163 |
+
|
164 |
+
# Flatten predictions while maintaining order
|
165 |
+
predictions = []
|
166 |
+
for batch_preds in batch_results:
|
167 |
+
if batch_preds is not None:
|
168 |
+
predictions.extend(batch_preds)
|
169 |
+
|
170 |
+
#--------------------------------------------------------------------------------------------
|
171 |
+
# YOUR MODEL INFERENCE STOPS HERE
|
172 |
+
#--------------------------------------------------------------------------------------------
|
173 |
+
|
174 |
+
# Stop tracking emissions
|
175 |
+
emissions_data = tracker.stop_task()
|
176 |
+
|
177 |
+
# Calculate accuracy
|
178 |
+
accuracy = accuracy_score(true_labels, predictions)
|
179 |
+
print("accuracy : ", accuracy)
|
180 |
+
|
181 |
+
# Prepare results dictionary
|
182 |
+
results = {
|
183 |
+
"username": username,
|
184 |
+
"space_url": space_url,
|
185 |
+
"submission_timestamp": datetime.now().isoformat(),
|
186 |
+
"model_description": DESCRIPTION,
|
187 |
+
"accuracy": float(accuracy),
|
188 |
+
"energy_consumed_wh": emissions_data.energy_consumed * 1000,
|
189 |
+
"emissions_gco2eq": emissions_data.emissions * 1000,
|
190 |
+
"emissions_data": clean_emissions_data(emissions_data),
|
191 |
+
"api_route": ROUTE,
|
192 |
+
"dataset_config": {
|
193 |
+
"dataset_name": request.dataset_name,
|
194 |
+
"test_size": request.test_size,
|
195 |
+
"test_seed": request.test_seed
|
196 |
+
}
|
197 |
+
}
|
198 |
+
|
199 |
+
print("results : ", results)
|
200 |
+
|
201 |
+
return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|