hmrizal committed
Commit 3dda2b6 · verified · 1 Parent(s): 516ac46

add reset_model_cache to prevent memory leaks; force CPU-only loading and disable 8-bit quantization for Phi-4

Files changed (1)
  1. app.py +69 -33
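
The gist of the change: the Phi-4 and DeepSeek branches now load on CPU in float32 instead of going through bitsandbytes 8-bit quantization, which generally assumes a CUDA GPU. Below is a minimal sketch of that loading pattern with Hugging Face transformers; the model id is illustrative, not necessarily the one configured in app.py.

# Sketch of the CPU-only float32 load used by the new branches (illustrative model id).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "microsoft/Phi-4-mini-instruct"  # assumption: any causal LM id works here
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="cpu",           # pin to CPU instead of device_map="auto"
    torch_dtype=torch.float32,  # safe dtype for CPU inference
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)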
app.py CHANGED
@@ -107,20 +107,24 @@ performance_tracker = PerformanceTracker()
 
 def initialize_model_once(model_key):
     with MODEL_CACHE["init_lock"]:
-        current_model = MODEL_CACHE["model_name"]
-        if MODEL_CACHE["model"] is None or current_model != model_key:
-            # Clear previous model
-            if MODEL_CACHE["model"] is not None:
-                del MODEL_CACHE["model"]
-            if MODEL_CACHE["tokenizer"] is not None:
-                del MODEL_CACHE["tokenizer"]
+        try:
+            current_model = MODEL_CACHE["model_name"]
+            if MODEL_CACHE["model"] is None or current_model != model_key:
+                # Clear previous model
+                if MODEL_CACHE["model"] is not None:
+                    del MODEL_CACHE["model"]
+                if MODEL_CACHE["tokenizer"] is not None:
+                    del MODEL_CACHE["tokenizer"]
+
+                # Force garbage collection
+                gc.collect()
                 torch.cuda.empty_cache() if torch.cuda.is_available() else None
+                time.sleep(1)  # Give system time to release memory
 
-            model_info = MODEL_CONFIG[model_key]
-            model_name = model_info["name"]
-            MODEL_CACHE["model_name"] = model_key
+                model_info = MODEL_CONFIG[model_key]
+                model_name = model_info["name"]
+                MODEL_CACHE["model_name"] = model_key
 
-            try:
                 print(f"Loading model: {model_name}")
 
                 # Check if this is a GGUF model
@@ -169,22 +173,30 @@ def initialize_model_once(model_key):
                         low_cpu_mem_usage=True
                     )
                     MODEL_CACHE["is_gguf"] = False
-
-                # Special handling for models that cause memory issues
-                elif model_key in ["Phi-4 Mini Instruct", "DeepSeek Lite Chat"]:
-                    # Reduce memory footprint
-                    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:64"
-
-                    # For CPU-only environments, load with 8-bit quantization
+
+                # For Phi-4 specifically
+                elif "Phi-4" in model_key:
                     MODEL_CACHE["tokenizer"] = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
                     MODEL_CACHE["model"] = AutoModelForCausalLM.from_pretrained(
                         model_name,
-                        load_in_8bit=True,  # Use 8-bit instead of 4-bit
-                        device_map="auto" if torch.cuda.is_available() else None,
+                        device_map="cpu",           # Force CPU explicitly
+                        torch_dtype=torch.float32,  # Use float32 for CPU
+                        low_cpu_mem_usage=True,
+                        trust_remote_code=True
+                    )
+                    MODEL_CACHE["is_gguf"] = False
+
+                # Special handling for DeepSeek Lite Chat
+                elif model_key == "DeepSeek Lite Chat":
+                    MODEL_CACHE["tokenizer"] = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+                    MODEL_CACHE["model"] = AutoModelForCausalLM.from_pretrained(
+                        model_name,
+                        device_map="cpu",           # Force CPU
+                        torch_dtype=torch.float32,  # Use float32 for CPU
                         low_cpu_mem_usage=True,
                         trust_remote_code=True
                     )
-                    MODEL_CACHE["is_gguf"] = False
+                    MODEL_CACHE["is_gguf"] = False
 
                 # Handle standard HF models
                 else:
@@ -219,19 +231,26 @@ def initialize_model_once(model_key):
                     MODEL_CACHE["is_gguf"] = False
 
                 print(f"Model {model_name} loaded successfully")
-            except Exception as e:
-                import traceback
-                print(f"Error loading model {model_name}: {str(e)}")
-                print(traceback.format_exc())
-                raise RuntimeError(f"Failed to load model {model_name}: {str(e)}")
-
-            # Final verification that model loaded correctly
-            if MODEL_CACHE["model"] is None:
-                print(f"WARNING: Model {model_name} appears to be None after loading")
-                # Try to free memory before returning
-                torch.cuda.empty_cache() if torch.cuda.is_available() else None
+
+            # Final verification that model loaded correctly
+            if MODEL_CACHE["model"] is None:
+                print(f"WARNING: Model {model_name} appears to be None after loading")
+                # Try to free memory before returning
+                torch.cuda.empty_cache() if torch.cuda.is_available() else None
+                gc.collect()
+
+        except Exception as e:
+            # Reset model cache on error
+            MODEL_CACHE["model"] = None
+            MODEL_CACHE["tokenizer"] = None
+            # Force garbage collection
             gc.collect()
-
+            torch.cuda.empty_cache() if torch.cuda.is_available() else None
+            import traceback
+            print(f"Error loading model {model_key}: {str(e)}")
+            print(traceback.format_exc())
+            raise RuntimeError(f"Failed to load model {model_key}: {str(e)}")
+
         return MODEL_CACHE["tokenizer"], MODEL_CACHE["model"], MODEL_CACHE.get("is_gguf", False)
 
 def get_fallback_model(current_model):
@@ -312,6 +331,22 @@ def create_llm_pipeline(model_key):
         print(traceback.format_exc())
         raise RuntimeError(f"Failed to create pipeline: {str(e)}")
 
+# add a reset function to clear models between sessions
+def reset_model_cache():
+    """Force clear all model cache"""
+    with MODEL_CACHE["init_lock"]:
+        if MODEL_CACHE["model"] is not None:
+            del MODEL_CACHE["model"]
+        if MODEL_CACHE["tokenizer"] is not None:
+            del MODEL_CACHE["tokenizer"]
+        MODEL_CACHE["model"] = None
+        MODEL_CACHE["tokenizer"] = None
+        MODEL_CACHE["model_name"] = None
+        MODEL_CACHE["is_gguf"] = False
+        gc.collect()
+        torch.cuda.empty_cache() if torch.cuda.is_available() else None
+        time.sleep(1)
+
 # Modified handle_model_loading_error function
 def handle_model_loading_error(model_key, session_id):
     """Handle model loading errors by providing alternative model suggestions or fallbacks"""
@@ -724,6 +759,7 @@ def create_gradio_interface():
 
     # Reset handler - enables model selection again
    def reset_session():
+        reset_model_cache()  # call reset model cache
        return None, False, [], gr.update(interactive=True)
 
    reset_button.click(
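
The diff ends partway through the reset_button.click( wiring. For reference, a hedged sketch of how such a handler is typically hooked up in Gradio; the component names (session_state, model_locked, chatbot, model_dropdown) are assumptions, not necessarily the ones used in app.py:

# Assumed component names; reset_session returns four values matching these outputs.
reset_button.click(
    fn=reset_session,  # now also calls reset_model_cache() to free the loaded model
    inputs=[],
    outputs=[session_state, model_locked, chatbot, model_dropdown],
)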