Tonic committed on
Commit 08f1c39 · unverified · 1 Parent(s): 7eb6153

switch model loading technique

Files changed (1): tasks/text.py +9 -25
tasks/text.py CHANGED
@@ -1,3 +1,4 @@
+
 from fastapi import APIRouter
 from datetime import datetime
 import time
@@ -134,21 +135,8 @@ async def evaluate_text(request: TextEvaluationRequest):
     }
 
     try:
-        # Load and prepare the dataset with retry mechanism
-        max_retries = 3
-        for attempt in range(max_retries):
-            try:
-                dataset = load_dataset(
-                    "QuotaClimat/frugalaichallenge-text-train",
-                    token=HF_TOKEN,
-                    trust_remote_code=True
-                )
-                break
-            except Exception as e:
-                if attempt == max_retries - 1:
-                    raise Exception(f"Failed to load dataset after {max_retries} attempts: {str(e)}")
-                print(f"Dataset loading attempt {attempt + 1} failed, retrying... Error: {str(e)}")
-                time.sleep(2)
+        # Load and prepare the dataset
+        dataset = load_dataset("QuotaClimat/frugalaichallenge-text-train", token=HF_TOKEN)
 
         # Convert string labels to integers
        dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})
@@ -165,8 +153,8 @@ async def evaluate_text(request: TextEvaluationRequest):
         # Initialize the model once
         classifier = TextClassifier()
 
-        # Prepare batches with smaller batch size
-        batch_size = 16  # Reduced batch size
+        # Prepare batches
+        batch_size = 24
         quotes = test_dataset["quote"]
         num_batches = len(quotes) // batch_size + (1 if len(quotes) % batch_size != 0 else 0)
         batches = [
@@ -177,8 +165,8 @@ async def evaluate_text(request: TextEvaluationRequest):
         # Initialize batch_results
         batch_results = [[] for _ in range(num_batches)]
 
-        # Process batches in parallel with fewer workers
-        max_workers = min(os.cpu_count(), 2)  # Reduced number of workers
+        # Process batches in parallel
+        max_workers = min(os.cpu_count(), 4)
         print(f"Processing with {max_workers} workers")
 
         with ThreadPoolExecutor(max_workers=max_workers) as executor:
@@ -211,11 +199,6 @@ async def evaluate_text(request: TextEvaluationRequest):
         accuracy = accuracy_score(true_labels, predictions)
         print("accuracy:", accuracy)
 
-        # Clean up
-        del classifier
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-
         # Prepare results
         results = {
             "username": username,
@@ -239,4 +222,5 @@ async def evaluate_text(request: TextEvaluationRequest):
 
     except Exception as e:
         print(f"Error in evaluate_text: {str(e)}")
-        raise Exception(f"Failed to process request: {str(e)}")
+        raise Exception(f"Failed to process request: {str(e)}")
+
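For reference, below is a minimal, self-contained sketch of the batch-splitting and thread-pool pattern this file uses after the change (batch_size of 24, at most 4 workers). predict_batch and the quotes list are hypothetical stand-ins for TextClassifier inference and test_dataset["quote"], and gathering results with as_completed is an assumption for illustration, not the repository's exact code.

# Minimal sketch of the batching + ThreadPoolExecutor pattern (assumptions noted above)
import os
from concurrent.futures import ThreadPoolExecutor, as_completed

def predict_batch(batch):
    # Hypothetical stand-in for classifier inference: one dummy label per quote.
    return [0] * len(batch)

quotes = [f"quote {i}" for i in range(100)]  # stand-in for test_dataset["quote"]

batch_size = 24
num_batches = len(quotes) // batch_size + (1 if len(quotes) % batch_size != 0 else 0)
batches = [quotes[i * batch_size:(i + 1) * batch_size] for i in range(num_batches)]

max_workers = min(os.cpu_count() or 1, 4)  # cap the worker count at 4
print(f"Processing with {max_workers} workers")

# Store results by batch index so the flattened predictions keep dataset order.
batch_results = [[] for _ in range(num_batches)]
with ThreadPoolExecutor(max_workers=max_workers) as executor:
    futures = {executor.submit(predict_batch, batch): idx for idx, batch in enumerate(batches)}
    for future in as_completed(futures):
        batch_results[futures[future]] = future.result()

predictions = [label for batch in batch_results for label in batch]
print(len(predictions))  # 100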