fix dataset loading

tasks/text.py  (+84 −53)
@@ -123,59 +123,90 @@ async def evaluate_text(request: TextEvaluationRequest):
         "7_fossil_fuels_needed": 7
     }
 
[53 deleted lines not shown]
+    try:
+        # Load and prepare the dataset
+        dataset = load_dataset("QuotaClimat/frugal-ai-challenge")
+
+        # Convert string labels to integers
+        dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})
+        test_dataset = dataset["test"]
+
+        # Start tracking emissions
+        tracker.start()
+        tracker.start_task("inference")
 
+        true_labels = test_dataset["label"]
+
+        # Initialize the model once
+        classifier = TextClassifier()
+
+        # Prepare batches
+        batch_size = 32
+        quotes = test_dataset["quote"]
+        num_batches = len(quotes) // batch_size + (1 if len(quotes) % batch_size != 0 else 0)
+        batches = [
+            quotes[i * batch_size:(i + 1) * batch_size]
+            for i in range(num_batches)
+        ]
+
+        # Initialize batch_results
+        batch_results = [[] for _ in range(num_batches)]
+
+        # Process batches in parallel
+        max_workers = min(os.cpu_count(), 4)
+        print(f"Processing with {max_workers} workers")
+
+        with ThreadPoolExecutor(max_workers=max_workers) as executor:
+            future_to_batch = {
+                executor.submit(classifier.process_batch, batch, idx): idx
+                for idx, batch in enumerate(batches)
+            }
+
+            for future in future_to_batch:
+                batch_idx = future_to_batch[future]
+                try:
+                    predictions, idx = future.result()
+                    if predictions:
+                        batch_results[idx] = predictions
+                        print(f"Stored results for batch {idx} ({len(predictions)} predictions)")
+                except Exception as e:
+                    print(f"Failed to get results for batch {batch_idx}: {e}")
+                    batch_results[batch_idx] = [0] * len(batches[batch_idx])
+
+        # Flatten predictions
+        predictions = []
+        for batch_preds in batch_results:
+            if batch_preds is not None:
+                predictions.extend(batch_preds)
+
+        # Stop tracking emissions
+        emissions_data = tracker.stop_task()
+
+        # Calculate accuracy
+        accuracy = accuracy_score(true_labels, predictions)
+        print("accuracy:", accuracy)
+
+        # Prepare results
+        results = {
+            "username": username,
+            "space_url": space_url,
+            "submission_timestamp": datetime.now().isoformat(),
+            "model_description": DESCRIPTION,
+            "accuracy": float(accuracy),
+            "energy_consumed_wh": emissions_data.energy_consumed * 1000,
+            "emissions_gco2eq": emissions_data.emissions * 1000,
+            "emissions_data": clean_emissions_data(emissions_data),
+            "api_route": ROUTE,
+            "dataset_config": {
+                "dataset_name": request.dataset_name,
+                "test_size": request.test_size,
+                "test_seed": request.test_seed
+            }
+        }
 
+        print("results:", results)
+        return results
+
+    except Exception as e:
+        print(f"Error in evaluate_text: {str(e)}")
+        raise Exception(f"Failed to process request: {str(e)}")
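
The parallel section relies on a specific contract: `TextClassifier.process_batch(batch, idx)` must return the batch's predictions together with its index, since the collection loop unpacks `predictions, idx = future.result()`. The classifier itself is defined elsewhere in the Space and not shown in this diff; the following is a minimal, hypothetical sketch of the expected interface only:

```python
from typing import List, Tuple

class TextClassifier:
    """Hypothetical stand-in for the Space's real classifier (not in the diff)."""

    def process_batch(self, batch: List[str], idx: int) -> Tuple[List[int], int]:
        # Must return one integer label per quote plus the batch index,
        # so the caller can write results back to batch_results[idx].
        predictions = [0 for _ in batch]  # placeholder for real model inference
        return predictions, idx
```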
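
Two details of the batching code are worth noting. The `num_batches` expression is plain ceiling division, and the collection loop iterates `future_to_batch` in submission order, blocking on each future in turn; `concurrent.futures.as_completed` would yield futures as they finish instead. Because every result is written back to `batch_results[idx]`, the final prediction order is identical either way. A self-contained sketch of that variant, with a placeholder worker standing in for the classifier:

```python
import math
from concurrent.futures import ThreadPoolExecutor, as_completed

def process_batch(batch, idx):
    # Placeholder worker standing in for classifier.process_batch.
    return [0] * len(batch), idx

quotes = [f"quote {i}" for i in range(100)]
batch_size = 32
num_batches = math.ceil(len(quotes) / batch_size)  # same as the // + remainder form
batches = [quotes[i * batch_size:(i + 1) * batch_size] for i in range(num_batches)]
batch_results = [[] for _ in range(num_batches)]

with ThreadPoolExecutor(max_workers=4) as executor:
    future_to_batch = {
        executor.submit(process_batch, batch, idx): idx
        for idx, batch in enumerate(batches)
    }
    for future in as_completed(future_to_batch):  # yields futures as they finish
        predictions, idx = future.result()
        batch_results[idx] = predictions

assert [len(b) for b in batch_results] == [32, 32, 32, 4]  # 100 = 3*32 + 4
```

Note also the failure path in the diff: a batch whose future raises is backfilled with `[0] * len(batches[batch_idx])`, which keeps `predictions` aligned with `true_labels` for `accuracy_score`, at the cost of silently scoring that batch as class 0.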
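
The `* 1000` conversions in the results dict follow from codecarbon's units: `energy_consumed` is reported in kWh and `emissions` in kg CO2eq, so the payload carries Wh and g CO2eq. A minimal sketch of the task-based tracking pattern, assuming codecarbon 2.x's `EmissionsTracker` task API (in the Space, `tracker` is created elsewhere; the project name below is made up):

```python
from codecarbon import EmissionsTracker

tracker = EmissionsTracker(project_name="frugal-ai-text")  # hypothetical name

tracker.start()
tracker.start_task("inference")
# ... run the batched inference here ...
emissions_data = tracker.stop_task()  # EmissionsData for this task
tracker.stop()

# codecarbon reports energy in kWh and emissions in kg CO2eq.
print(f"energy: {emissions_data.energy_consumed * 1000:.2f} Wh")
print(f"emissions: {emissions_data.emissions * 1000:.4f} g CO2eq")
```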