m1k3wn committed
Commit 3d1738d · verified · 1 Parent(s): 42717bf

Update app.py

refactor to try and stop request hangs

Files changed (1)
  1. app.py +45 -15
app.py CHANGED
```diff
@@ -70,21 +70,22 @@ class ModelManager:
             model_path = MODELS[model_name]
             logger.debug(f"Loading tokenizer and model from {model_path}")
 
-            # Simplified tokenizer loading
             tokenizer = T5Tokenizer.from_pretrained(
                 model_path,
                 token=HF_TOKEN,
-                use_fast=True # Added this
+                use_fast=True
             )
 
-            # Simplified model loading
             model = T5ForConditionalGeneration.from_pretrained(
                 model_path,
                 token=HF_TOKEN,
-                torch_dtype=torch.float32
+                torch_dtype=torch.float32,
+                low_cpu_mem_usage=True,
+                device_map='auto'
             )
 
             model.eval()
+            torch.set_num_threads(6) # Number of CPUs used
             cls._instances[model_name] = (model, tokenizer)
 
         except Exception as e:
```
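Note that both `low_cpu_mem_usage=True` and `device_map='auto'` depend on the `accelerate` package; `transformers` raises an ImportError at load time when it is missing. A minimal sketch of a guarded load, reusing names from the diff (`model_path` and `HF_TOKEN` are assumed from the surrounding module), not the commit's actual code:

```python
# Sketch only: fall back to a plain load when `accelerate` is unavailable.
import importlib.util

import torch
from transformers import T5ForConditionalGeneration

accelerate_kwargs = {}
if importlib.util.find_spec("accelerate") is not None:
    # Both options require `accelerate`; without it, from_pretrained raises.
    accelerate_kwargs.update(low_cpu_mem_usage=True, device_map="auto")

model = T5ForConditionalGeneration.from_pretrained(
    model_path,               # assumed defined, as in the diff
    token=HF_TOKEN,           # assumed defined, as in the diff
    torch_dtype=torch.float32,
    **accelerate_kwargs,
)
model.eval()
torch.set_num_threads(6)  # caps PyTorch's intra-op CPU thread pool
```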
```diff
@@ -154,6 +155,12 @@ async def predict(request: PredictionRequest, background_tasks: BackgroundTasks)
     )
 
     model, tokenizer = await ModelManager.get_model_and_tokenizer(request.model)
+
+    # Add immediate cleanup of memory before generation
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
    generation_params = DEFAULT_GENERATION_CONFIGS[request.model].copy()
 
    try:
```
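A caveat on the inline sweep: `gc.collect()` can pause for tens of milliseconds, and inside an async handler it blocks the event loop for every in-flight request. If the pre-generation sweep is kept, one option is to push it off the loop. A sketch with a hypothetical helper, assuming Python 3.9+ for `asyncio.to_thread`:

```python
# Sketch only: run the collection in a worker thread so other requests
# keep being served while the sweep runs.
import asyncio
import gc

import torch

async def pre_generation_cleanup() -> None:
    await asyncio.to_thread(gc.collect)  # full collection, off the event loop
    if torch.cuda.is_available():
        torch.cuda.empty_cache()         # return cached CUDA blocks to the driver
```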
```diff
@@ -181,17 +188,23 @@ async def predict(request: PredictionRequest, background_tasks: BackgroundTasks)
         )
 
         async def generate():
-            return model.generate(
-                **inputs,
-                **{k: v for k, v in generation_params.items() if k in [
-                    'max_length', 'min_length', 'do_sample', 'temperature',
-                    'top_p', 'top_k', 'num_beams', 'no_repeat_ngram_size',
-                    'repetition_penalty', 'early_stopping'
-                ]}
-            )
+            try:
+                return model.generate(
+                    **inputs,
+                    **{k: v for k, v in generation_params.items() if k in [
+                        'max_length', 'min_length', 'do_sample', 'temperature',
+                        'top_p', 'top_k', 'num_beams', 'no_repeat_ngram_size',
+                        'repetition_penalty', 'early_stopping'
+                    ]}
+                )
+            finally:
+                # Ensure cleanup happens even if generation fails
+                gc.collect()
+                if torch.cuda.is_available():
+                    torch.cuda.empty_cache()
 
         with torch.inference_mode():
-            outputs = await asyncio.wait_for(generate(), timeout=70.0)
+            outputs = await asyncio.wait_for(generate(), timeout=45.0) # Reduced timeout
 
         result = tokenizer.decode(outputs[0], skip_special_tokens=True)
         background_tasks.add_task(cleanup_memory)
```
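One likely reason the hangs survive the timeout: `model.generate` is synchronous, and the `async def generate()` wrapper contains no `await`, so the coroutine never yields to the event loop. `asyncio.wait_for` can only raise `TimeoutError` at a suspension point, which here arrives only after generation has already finished. A sketch of a timeout that actually fires, using a hypothetical helper name and assuming Python 3.9+ for `asyncio.to_thread`:

```python
# Sketch only: off-load blocking generation to a worker thread so the event
# loop stays free and wait_for's timeout can fire.
import asyncio

import torch

async def generate_with_timeout(model, inputs, generation_params, timeout=45.0):
    def _run():
        with torch.inference_mode():
            return model.generate(**inputs, **generation_params)

    # The await below is a real suspension point, so wait_for can time out.
    return await asyncio.wait_for(asyncio.to_thread(_run), timeout=timeout)
```

Note the worker thread still runs generation to completion after a timeout; stopping it early would take something like a transformers `StoppingCriteria` that watches a deadline.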
```diff
@@ -201,14 +214,31 @@ async def predict(request: PredictionRequest, background_tasks: BackgroundTasks)
             selected_model=request.model
         )
 
+    except asyncio.TimeoutError:
+        logger.error("Generation timed out")
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+        raise HTTPException(status_code=504, detail="Generation timed out")
     except Exception as e:
         error_msg = f"Error during prediction: {str(e)}\n{traceback.format_exc()}"
         logger.error(error_msg)
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
         raise HTTPException(status_code=500, detail=error_msg)
 
 def cleanup_memory():
-    gc.collect()
-    torch.cuda.empty_cache() if torch.cuda.is_available() else None
+    try:
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+        # Force Python garbage collection
+        gc.collect(generation=2)
+
+    except Exception as e:
+        logger.error(f"Error in cleanup: {str(e)}")
 
 if __name__ == "__main__":
     import uvicorn
```
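On the new `cleanup_memory`: a bare `gc.collect()` already performs a full, generation-2 collection, so the follow-up `gc.collect(generation=2)` repeats work rather than forcing anything extra. A minimal equivalent sketch (`logger` is assumed from the surrounding module):

```python
# Sketch only: one full collection is enough; gc.collect() with no arguments
# collects all generations, including generation 2.
import gc

import torch

def cleanup_memory() -> None:
    try:
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    except Exception as e:
        logger.error(f"Error in cleanup: {e}")  # `logger` from the module scope
```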
 