Luigi committed on
Commit
0813164
·
1 Parent(s): 37ee1f3

improve model management

Browse files
Files changed (1) hide show
  1. app.py +44 -27
app.py CHANGED
@@ -12,10 +12,10 @@ MODELS = {
12
  "filename": "qwen2.5-7b-instruct-q2_k.gguf",
13
  "description": "Qwen2.5-7B Instruct (Q2_K)"
14
  },
15
- "Gemma-3-4B-IT (Q5_K_M)": {
16
  "repo_id": "unsloth/gemma-3-4b-it-GGUF",
17
- "filename": "gemma-3-4b-it-Q5_K_M.gguf",
18
- "description": "Gemma 3 4B IT (Q5_K_M)"
19
  },
20
  "Phi-4-mini-Instruct (Q4_K_M)": {
21
  "repo_id": "unsloth/Phi-4-mini-instruct-GGUF",
@@ -38,19 +38,37 @@ with st.sidebar:
38
  selected_model = MODELS[selected_model_name]
39
  model_path = os.path.join("models", selected_model["filename"])
40
 
 
 
 
 
 
 
41
  # Make sure models dir exists
42
  os.makedirs("models", exist_ok=True)
43
 
44
- # Clear old models if new one isn't present
45
- if not os.path.exists(model_path):
46
- for file in os.listdir("models"):
47
- if file.endswith(".gguf"):
 
 
 
 
 
 
 
 
 
 
 
 
48
  try:
49
- os.remove(os.path.join("models", file))
50
  except Exception as e:
51
- st.warning(f"Failed to delete {file}: {e}")
52
 
53
- # Download the selected model
54
  with st.spinner(f"Downloading {selected_model['filename']}..."):
55
  hf_hub_download(
56
  repo_id=selected_model["repo_id"],
@@ -59,29 +77,28 @@ if not os.path.exists(model_path):
59
  local_dir_use_symlinks=False,
60
  )
61
 
62
- # Init state
63
- if "model_name" not in st.session_state:
64
- st.session_state.model_name = None
65
- if "llm" not in st.session_state:
66
- st.session_state.llm = None
67
 
68
  # Load model if changed
69
  if st.session_state.model_name != selected_model_name:
70
  if st.session_state.llm is not None:
71
  del st.session_state.llm
72
  gc.collect()
73
-
74
- st.session_state.llm = Llama(
75
- model_path=model_path,
76
- n_ctx=1024,
77
- n_threads=2,
78
- n_threads_batch=2,
79
- n_batch=4,
80
- n_gpu_layers=0,
81
- use_mlock=False,
82
- use_mmap=True,
83
- verbose=False,
84
- )
 
 
 
85
  st.session_state.model_name = selected_model_name
86
 
87
  llm = st.session_state.llm
 
12
  "filename": "qwen2.5-7b-instruct-q2_k.gguf",
13
  "description": "Qwen2.5-7B Instruct (Q2_K)"
14
  },
15
+ "Gemma-3-4B-IT (Q4_K_M)": {
16
  "repo_id": "unsloth/gemma-3-4b-it-GGUF",
17
+ "filename": "gemma-3-4b-it-Q4_K_M.gguf",
18
+ "description": "Gemma 3 4B IT (Q4_K_M)"
19
  },
20
  "Phi-4-mini-Instruct (Q4_K_M)": {
21
  "repo_id": "unsloth/Phi-4-mini-instruct-GGUF",
 
38
# Resolve the sidebar selection to its model spec and local file path.
selected_model = MODELS[selected_model_name]
model_path = os.path.join("models", selected_model["filename"])

# Init state: both keys default to None on first run; setdefault is a
# no-op on reruns where the keys already exist.
for _key in ("model_name", "llm"):
    st.session_state.setdefault(_key, None)

# Make sure models dir exists
os.makedirs("models", exist_ok=True)
49
 
50
# If the selected model file does not exist or is invalid, clean up and re-download
def validate_or_download_model():
    """Ensure the selected GGUF file exists locally and is loadable.

    Missing file -> delete stale .gguf files, then download the selected one.
    Corrupt file -> warn the user, delete stale files, then re-download.
    """
    if not os.path.exists(model_path):
        cleanup_old_models()
        download_model()
        return
    probe = None
    try:
        # Smoke test with a tiny context window — only verifies the file parses.
        probe = Llama(model_path=model_path, n_ctx=16, n_threads=1)  # dummy check
    except Exception as e:
        st.warning(f"Model file was invalid or corrupt: {e}\nRedownloading...")
        cleanup_old_models()
        download_model()
    finally:
        # BUG FIX: the probe instance was never released, leaving a full copy
        # of the model resident in RAM alongside the real instance loaded
        # later. Free it explicitly before continuing.
        if probe is not None:
            del probe
            gc.collect()
+
63
def cleanup_old_models():
    """Delete every .gguf file in models/ except the currently selected one.

    Deletion failures are reported as warnings but never raised, so a locked
    or missing file cannot break the app.
    """
    keep = selected_model["filename"]
    stale = (
        name
        for name in os.listdir("models")
        if name.endswith(".gguf") and name != keep
    )
    for f in stale:
        try:
            os.remove(os.path.join("models", f))
        except Exception as e:
            st.warning(f"Couldn't delete old model {f}: {e}")
70
 
71
def download_model():
    """Download the selected model's GGUF file from the Hub into models/."""
    with st.spinner(f"Downloading {selected_model['filename']}..."):
        hf_hub_download(
            repo_id=selected_model["repo_id"],
            # NOTE(review): the two arguments below were lost in the diff
            # extraction (hunk gap); reconstructed so the file lands exactly
            # at model_path = os.path.join("models", filename) — confirm
            # against the original app.py.
            filename=selected_model["filename"],
            local_dir="models",
            # Deprecated in recent huggingface_hub releases (real files are
            # now always written); kept for compatibility with the pinned
            # version this app was written against.
            local_dir_use_symlinks=False,
        )
79
 
80
# Make sure the selected model file is present and usable before loading it.
validate_or_download_model()

# Load model if changed
if st.session_state.model_name != selected_model_name:
    # Release the previously loaded model before constructing a new one,
    # so both never occupy memory at the same time.
    if st.session_state.llm is not None:
        del st.session_state.llm
        gc.collect()

    # CPU-only, small-footprint settings for the llama.cpp backend.
    llama_kwargs = dict(
        model_path=model_path,
        n_ctx=1024,
        n_threads=2,
        n_threads_batch=2,
        n_batch=4,
        n_gpu_layers=0,
        use_mlock=False,
        use_mmap=True,
        verbose=False,
    )
    try:
        st.session_state.llm = Llama(**llama_kwargs)
    except Exception as e:
        st.error(f"Failed to load model: {e}")
        st.stop()
    st.session_state.model_name = selected_model_name

llm = st.session_state.llm