Tonic committed on
Commit 6f0e9af · verified · 1 Parent(s): 7abed63

fix dataset loading

Files changed (1)
  1. tasks/text.py +84 -53
tasks/text.py CHANGED
@@ -123,59 +123,90 @@ async def evaluate_text(request: TextEvaluationRequest):
         "7_fossil_fuels_needed": 7
     }
 
-    # Load and prepare the dataset
-    dataset = load_dataset(request.dataset_name)
-    dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})
-    test_dataset = dataset["test"]
-
-    # Start tracking emissions
-    start_tracking()
+    try:
+        # Load and prepare the dataset
+        dataset = load_dataset("QuotaClimat/frugal-ai-challenge")
+
+        # Convert string labels to integers
+        dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})
+        test_dataset = dataset["test"]
+
+        # Start tracking emissions
+        tracker.start()
+        tracker.start_task("inference")
 
-    true_labels = test_dataset["label"]
-
-    # Initialize the model once
-    classifier = TextClassifier()
-
-    # Prepare batches
-    batch_size = 32  # Increased batch size for efficiency
-    quotes = test_dataset["quote"]
-    num_batches = len(quotes) // batch_size + (1 if len(quotes) % batch_size != 0 else 0)
-    batches = [
-        quotes[i * batch_size:(i + 1) * batch_size]
-        for i in range(num_batches)
-    ]
-
-    # Process batches sequentially to avoid memory issues
-    predictions = []
-    for idx, batch in enumerate(batches):
-        batch_preds, _ = classifier.process_batch(batch, idx)
-        predictions.extend(batch_preds)
-        print(f"Processed batch {idx + 1}/{num_batches}")
-
-    # Stop tracking emissions
-    emissions_data = stop_tracking()
-
-    # Calculate accuracy
-    accuracy = accuracy_score(true_labels, predictions)
-    print("accuracy:", accuracy)
-
-    # Prepare results
-    results = {
-        "username": username,
-        "space_url": space_url,
-        "submission_timestamp": datetime.now().isoformat(),
-        "model_description": DESCRIPTION,
-        "accuracy": float(accuracy),
-        "energy_consumed_wh": emissions_data.energy_consumed * 1000,
-        "emissions_gco2eq": emissions_data.emissions * 1000,
-        "emissions_data": clean_emissions_data(emissions_data),
-        "api_route": ROUTE,
-        "dataset_config": {
-            "dataset_name": request.dataset_name,
-            "test_size": request.test_size,
-            "test_seed": request.test_seed
+        true_labels = test_dataset["label"]
+
+        # Initialize the model once
+        classifier = TextClassifier()
+
+        # Prepare batches
+        batch_size = 32
+        quotes = test_dataset["quote"]
+        num_batches = len(quotes) // batch_size + (1 if len(quotes) % batch_size != 0 else 0)
+        batches = [
+            quotes[i * batch_size:(i + 1) * batch_size]
+            for i in range(num_batches)
+        ]
+
+        # Initialize batch_results
+        batch_results = [[] for _ in range(num_batches)]
+
+        # Process batches in parallel
+        max_workers = min(os.cpu_count(), 4)
+        print(f"Processing with {max_workers} workers")
+
+        with ThreadPoolExecutor(max_workers=max_workers) as executor:
+            future_to_batch = {
+                executor.submit(classifier.process_batch, batch, idx): idx
+                for idx, batch in enumerate(batches)
+            }
+
+            for future in future_to_batch:
+                batch_idx = future_to_batch[future]
+                try:
+                    predictions, idx = future.result()
+                    if predictions:
+                        batch_results[idx] = predictions
+                        print(f"Stored results for batch {idx} ({len(predictions)} predictions)")
+                except Exception as e:
+                    print(f"Failed to get results for batch {batch_idx}: {e}")
+                    batch_results[batch_idx] = [0] * len(batches[batch_idx])
+
+        # Flatten predictions
+        predictions = []
+        for batch_preds in batch_results:
+            if batch_preds is not None:
+                predictions.extend(batch_preds)
+
+        # Stop tracking emissions
+        emissions_data = tracker.stop_task()
+
+        # Calculate accuracy
+        accuracy = accuracy_score(true_labels, predictions)
+        print("accuracy:", accuracy)
+
+        # Prepare results
+        results = {
+            "username": username,
+            "space_url": space_url,
+            "submission_timestamp": datetime.now().isoformat(),
+            "model_description": DESCRIPTION,
+            "accuracy": float(accuracy),
+            "energy_consumed_wh": emissions_data.energy_consumed * 1000,
+            "emissions_gco2eq": emissions_data.emissions * 1000,
+            "emissions_data": clean_emissions_data(emissions_data),
+            "api_route": ROUTE,
+            "dataset_config": {
+                "dataset_name": request.dataset_name,
+                "test_size": request.test_size,
+                "test_seed": request.test_seed
+            }
         }
-    }
 
-    print("results:", results)
-    return results
+        print("results:", results)
+        return results
+
+    except Exception as e:
+        print(f"Error in evaluate_text: {str(e)}")
+        raise Exception(f"Failed to process request: {str(e)}")
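
For reference, the parallel batching introduced above boils down to the following standalone sketch. It is illustrative only: DummyClassifier, predict_in_parallel, and the sample quotes are placeholders standing in for the repository's TextClassifier and the real dataset, and it uses as_completed where the endpoint iterates the future_to_batch dict directly.

# Standalone sketch of the parallel batching pattern; names here are placeholders.
from concurrent.futures import ThreadPoolExecutor, as_completed
import os


class DummyClassifier:
    """Stand-in for the repo's TextClassifier: predicts label 0 for every quote."""

    def process_batch(self, batch, idx):
        # Mirror the (predictions, batch_index) return shape used in tasks/text.py.
        return [0] * len(batch), idx


def predict_in_parallel(quotes, batch_size=32, max_workers=None):
    classifier = DummyClassifier()
    max_workers = max_workers or min(os.cpu_count() or 1, 4)

    # Split the quotes into fixed-size batches.
    batches = [quotes[i:i + batch_size] for i in range(0, len(quotes), batch_size)]

    # Pre-size the result list so batches can finish in any order
    # without losing their position.
    batch_results = [[] for _ in batches]

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_batch = {
            executor.submit(classifier.process_batch, batch, idx): idx
            for idx, batch in enumerate(batches)
        }
        for future in as_completed(future_to_batch):
            batch_idx = future_to_batch[future]
            try:
                predictions, idx = future.result()
                batch_results[idx] = predictions
            except Exception as exc:
                # Fall back to a default label so the output length still matches.
                print(f"Batch {batch_idx} failed: {exc}")
                batch_results[batch_idx] = [0] * len(batches[batch_idx])

    # Flatten in batch order so predictions line up with the input quotes.
    return [pred for preds in batch_results for pred in preds]


if __name__ == "__main__":
    print(predict_in_parallel(["quote one", "quote two", "quote three"], batch_size=2))

Because results are written into a pre-sized list keyed by batch index, completion order does not matter: iterating future_to_batch in submission order (as the endpoint does) and iterating with as_completed both yield predictions aligned with the input quotes.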
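The emissions calls in the new code (tracker.start(), tracker.start_task("inference"), tracker.stop_task()) follow codecarbon's task-based API; the module-level tracker is presumably an EmissionsTracker created elsewhere in the app. A minimal sketch of that usage, under that assumption:

# Minimal sketch of task-based CodeCarbon tracking as assumed by the diff;
# in tasks/text.py the `tracker` object is presumably created elsewhere.
from codecarbon import EmissionsTracker

tracker = EmissionsTracker()
tracker.start()

tracker.start_task("inference")
# ... run model inference here ...
emissions_data = tracker.stop_task()

# stop_task() returns the per-task measurements; codecarbon reports energy in kWh
# and emissions in kg CO2eq, which the endpoint converts to Wh and gCO2eq (* 1000).
print("energy (Wh):", emissions_data.energy_consumed * 1000)
print("emissions (gCO2eq):", emissions_data.emissions * 1000)

tracker.stop()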