hmrizal committed
Commit 22b2e5f · verified · 1 parent: 95e8f89

Add error handling for the Phi-4, DeepSeek Lite, and Flan T5 models, plus a fallback model mechanism

Files changed (1): app.py (+74 -10)
app.py CHANGED
@@ -170,6 +170,22 @@ def initialize_model_once(model_key):
             )
             MODEL_CACHE["is_gguf"] = False
 
+        # Special handling for models that cause memory issues
+        elif model_key in ["Phi-4 Mini Instruct", "DeepSeek Lite Chat"]:
+            # Reduce memory footprint
+            os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:64"
+
+            # For CPU-only environments, load with 8-bit quantization
+            MODEL_CACHE["tokenizer"] = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+            MODEL_CACHE["model"] = AutoModelForCausalLM.from_pretrained(
+                model_name,
+                load_in_8bit=True,  # Use 8-bit instead of 4-bit
+                device_map="auto" if torch.cuda.is_available() else None,
+                low_cpu_mem_usage=True,
+                trust_remote_code=True
+            )
+            MODEL_CACHE["is_gguf"] = False
+
         # Handle standard HF models
         else:
             # Only use quantization if CUDA is available
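
The branch added above always passes load_in_8bit=True, but bitsandbytes 8-bit quantization generally requires a CUDA device; on a CPU-only Space it is the low_cpu_mem_usage flag that actually keeps peak RAM down. A minimal sketch of a guarded variant, where load_light_model is an illustrative helper and not part of app.py:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def load_light_model(model_name: str):
    # Illustrative: quantize to 8-bit only when CUDA (and bitsandbytes) can be used,
    # otherwise fall back to a plain low-memory CPU load.
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if torch.cuda.is_available():
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            load_in_8bit=True,       # needs the bitsandbytes package
            device_map="auto",
            trust_remote_code=True,
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            low_cpu_mem_usage=True,  # stream weights to reduce peak memory
            trust_remote_code=True,
        )
    return tokenizer, model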
@@ -247,7 +263,7 @@ def create_llm_pipeline(model_key):
                 max_new_tokens=256,  # Increased for more comprehensive answers
                 temperature=0.3,
                 top_p=0.9,
-                return_full_text=False,
+                # Remove return_full_text parameter for T5 models
             )
         else:
             print("Creating causal LM pipeline")
@@ -271,22 +287,47 @@ def create_llm_pipeline(model_key):
         print(traceback.format_exc())
         raise RuntimeError(f"Failed to create pipeline: {str(e)}")
 
+def get_fallback_model(current_model):
+    """Get appropriate fallback model for problematic models"""
+    fallback_map = {
+        "Phi-4 Mini Instruct": "TinyLlama Chat",
+        "DeepSeek Lite Chat": "DeepSeek Coder Instruct",
+        "Flan T5 Small": "Llama 2 Chat"
+    }
+    return fallback_map.get(current_model, "TinyLlama Chat")
+
+# Modified handle_model_loading_error function
 def handle_model_loading_error(model_key, session_id):
-    """Handle model loading errors by providing alternative model suggestions"""
+    """Handle model loading errors by providing alternative model suggestions or fallbacks"""
+    # Get the appropriate fallback model
+    fallback_model = get_fallback_model(model_key)
+
+    # Try to load the fallback model automatically
+    if fallback_model != model_key:
+        print(f"Automatically trying fallback model: {fallback_model} for {model_key}")
+
+        try:
+            # Try to initialize the fallback model
+            tokenizer, model, is_gguf = initialize_model_once(fallback_model)
+            return tokenizer, model, is_gguf, f"Model {model_key} couldn't be loaded. Automatically switched to {fallback_model}."
+        except Exception as e:
+            print(f"Fallback model {fallback_model} also failed: {str(e)}")
+            # If fallback fails, continue with regular suggestion logic
+
+    # Regular suggestion logic for when fallbacks don't work or aren't applicable
     suggested_models = [
         "DeepSeek Coder Instruct",  # 1.3B model
-        "Phi-4 Mini Instruct",  # Light model
         "TinyLlama Chat",  # 1.1B model
-        "Flan T5 Small"  # Lightweight T5
+        "Qwen2.5 Coder Instruct"  # Another option
     ]
 
-    # Remove the current model from suggestions if it's in the list
-    if model_key in suggested_models:
-        suggested_models.remove(model_key)
+    # Remove problematic models and current model from suggestions
+    problem_models = ["Phi-4 Mini Instruct", "DeepSeek Lite Chat", "Flan T5 Small"]
+    suggested_models = [m for m in suggested_models if m not in problem_models and m != model_key]
 
     suggestions = ", ".join(suggested_models[:3])  # Only show top 3 suggestions
-    return None, f"Unable to load model {model_key}. Please try another model such as: {suggestions}"
-
+    return None, None, None, f"Unable to load model {model_key}. Please try another model such as: {suggestions}"
+
 def create_conversational_chain(db, file_path, model_key):
     llm = create_llm_pipeline(model_key)
 
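
Assuming get_fallback_model and the updated handle_model_loading_error above are importable from app.py, the fallback resolution works out as follows; note that the error path now returns a 4-tuple so it matches the (tokenizer, model, is_gguf, message) shape of the success path:

# Explicitly mapped models fall back to a lighter or more stable alternative.
assert get_fallback_model("Phi-4 Mini Instruct") == "TinyLlama Chat"
assert get_fallback_model("DeepSeek Lite Chat") == "DeepSeek Coder Instruct"
assert get_fallback_model("Flan T5 Small") == "Llama 2 Chat"
# Any model without an explicit entry defaults to TinyLlama Chat.
assert get_fallback_model("Qwen2.5 Coder Instruct") == "TinyLlama Chat"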
@@ -359,6 +400,15 @@ def create_conversational_chain(db, file_path, model_key):
 
             # Clean the result
             cleaned_result = raw_result.strip()
+
+            # Add special handling for T5 models
+            if MODEL_CONFIG.get(model_key, {}).get("is_t5", False):
+                # T5 models sometimes return lists instead of strings
+                if isinstance(raw_result, list) and len(raw_result) > 0:
+                    if isinstance(raw_result[0], dict) and "generated_text" in raw_result[0]:
+                        raw_result = raw_result[0]["generated_text"]
+                    else:
+                        raw_result = str(raw_result[0])
 
             # If result is empty after cleaning, use a fallback
             if not cleaned_result:
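
Within this hunk, cleaned_result is taken from raw_result before the T5 branch rewrites raw_result, so the normalized string only takes effect if cleaned_result is recomputed further down. A helper that flattens the pipeline output before stripping would keep the two in sync; normalize_generation is an illustrative name, not part of app.py:

def normalize_generation(raw):
    # Flatten list-shaped pipeline outputs such as [{"generated_text": "..."}]
    # into a plain string before any .strip() is applied.
    if isinstance(raw, list) and raw:
        first = raw[0]
        raw = first.get("generated_text", str(first)) if isinstance(first, dict) else str(first)
    return str(raw).strip()

cleaned_result = normalize_generation(raw_result)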
@@ -615,8 +665,9 @@ def create_gradio_interface():
         outputs=[model_info]
     )
 
-    # Process file handler - disables model selection after file is processed
+    # Modified handle_process_file function
     def handle_process_file(file, model_key, sess_id):
+        """Process uploaded file with fallback model handling"""
        if file is None:
            return None, None, False, "Please upload a CSV file first."
 
@@ -628,6 +679,19 @@ def create_gradio_interface():
             import traceback
             print(f"Error processing file with {model_key}: {str(e)}")
             print(traceback.format_exc())
+
+            # Try with fallback model if original fails
+            fallback = get_fallback_model(model_key)
+            if fallback != model_key:
+                try:
+                    print(f"Trying fallback model: {fallback}")
+                    chatbot = ChatBot(sess_id, fallback)
+                    result = chatbot.process_file(file)
+                    message = f"Original model {model_key} failed. Using {fallback} instead.\n\n{result}"
+                    return chatbot, True, [(None, message)]
+                except Exception as fallback_error:
+                    print(f"Fallback model also failed: {str(fallback_error)}")
+
             error_msg = f"Error with model {model_key}: {str(e)}\n\nPlease try another model."
             return None, False, [(None, error_msg)]
 
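
The retry logic added here mirrors the fallback path in handle_model_loading_error and could be factored into one shared helper. A sketch under the assumption that ChatBot(session_id, model_key) and chatbot.process_file(file) keep the signatures used above; process_with_fallback itself is illustrative, not part of app.py:

def process_with_fallback(model_key, sess_id, file):
    # Try the requested model first, then its configured fallback (deduplicated).
    last_error = None
    for key in dict.fromkeys((model_key, get_fallback_model(model_key))):
        try:
            chatbot = ChatBot(sess_id, key)
            return chatbot, key, chatbot.process_file(file)
        except Exception as e:
            print(f"Model {key} failed: {e}")
            last_error = e
    raise RuntimeError(f"Model {model_key} and its fallback both failed") from last_error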
697