Luigi commited on
Commit
afa19a3
·
1 Parent(s): 1bd76fd

add missing part of app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -16
app.py CHANGED
@@ -3,6 +3,7 @@ from llama_cpp import Llama
3
  from huggingface_hub import hf_hub_download
4
  import os
5
  import gc
 
6
 
7
  # Available models
8
  MODELS = {
@@ -24,22 +25,22 @@ MODELS = {
24
  "Meta-Llama-3.1-8B-Instruct (Q2_K)": {
25
  "repo_id": "MaziyarPanahi/Meta-Llama-3.1-8B-Instruct-GGUF",
26
  "filename": "Meta-Llama-3.1-8B-Instruct.Q2_K.gguf",
27
- "description": "Meta Llama 3.1 8B Instruct (Q2_K)"
28
  },
29
  "DeepSeek-R1-Distill-Llama-8B (Q2_K)": {
30
  "repo_id": "unsloth/DeepSeek-R1-Distill-Llama-8B-GGUF",
31
  "filename": "DeepSeek-R1-Distill-Llama-8B-Q2_K.gguf",
32
- "description": "DeepSeek R1 Distill Llama 8B (Q2_K)"
33
  },
34
  "Mistral-7B-Instruct-v0.3 (IQ3_XS)": {
35
  "repo_id": "MaziyarPanahi/Mistral-7B-Instruct-v0.3-GGUF",
36
  "filename": "Mistral-7B-Instruct-v0.3.IQ3_XS.gguf",
37
- "description": "Mistral 7B Instruct v0.3 (IQ3_XS)"
38
  },
39
  "Qwen2.5-Coder-7B-Instruct (Q2_K)": {
40
  "repo_id": "Qwen/Qwen2.5-Coder-7B-Instruct-GGUF",
41
  "filename": "qwen2.5-coder-7b-instruct-q2_k.gguf",
42
- "description": "Qwen2.5 Coder 7B Instruct (Q2_K)"
43
  },
44
  }
45
 
@@ -54,10 +55,34 @@ with st.sidebar:
54
  top_p = st.slider("Top-P", 0.1, 1.0, 0.95)
55
  repeat_penalty = st.slider("Repetition Penalty", 1.0, 2.0, 1.1)
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  # Model info
58
  selected_model = MODELS[selected_model_name]
59
  model_path = os.path.join("models", selected_model["filename"])
60
 
 
 
 
 
 
 
61
  # Ensure model directory exists
62
  os.makedirs("models", exist_ok=True)
63
 
@@ -70,7 +95,6 @@ def cleanup_old_models():
70
  except Exception as e:
71
  st.warning(f"Couldn't delete old model {f}: {e}")
72
 
73
- # Function to download the selected model
74
  def download_model():
75
  with st.spinner(f"Downloading {selected_model['filename']}..."):
76
  hf_hub_download(
@@ -80,26 +104,86 @@ def download_model():
80
  local_dir_use_symlinks=False,
81
  )
82
 
83
- # Function to validate or download the model
 
 
 
 
 
84
  def validate_or_download_model():
85
  if not os.path.exists(model_path):
86
  cleanup_old_models()
87
  download_model()
88
- try:
89
- # Attempt to load the model with minimal resources to validate
90
- _ = Llama(model_path=model_path, n_ctx=16, n_threads=1)
91
- except Exception as e:
92
- st.warning(f"Model file was invalid or corrupt: {e}\nRedownloading...")
93
  try:
94
  os.remove(model_path)
95
  except:
96
  pass
97
  cleanup_old_models()
98
  download_model()
99
-
100
- # Validate or download the selected model
101
- validate_or_download_model()
 
 
 
102
 
103
  # Load model if changed
104
- if "model_name" not in st.session_state or st.session_state.model_name != selected_model_name:
105
- if "llm" in st.session_state and st.session_state.llm
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  from huggingface_hub import hf_hub_download
4
  import os
5
  import gc
6
+ import shutil
7
 
8
  # Available models
9
  MODELS = {
 
25
  "Meta-Llama-3.1-8B-Instruct (Q2_K)": {
26
  "repo_id": "MaziyarPanahi/Meta-Llama-3.1-8B-Instruct-GGUF",
27
  "filename": "Meta-Llama-3.1-8B-Instruct.Q2_K.gguf",
28
+ "description": "Meta-Llama-3.1-8B-Instruct (Q2_K)"
29
  },
30
  "DeepSeek-R1-Distill-Llama-8B (Q2_K)": {
31
  "repo_id": "unsloth/DeepSeek-R1-Distill-Llama-8B-GGUF",
32
  "filename": "DeepSeek-R1-Distill-Llama-8B-Q2_K.gguf",
33
+ "description": "DeepSeek-R1-Distill-Llama-8B (Q2_K)"
34
  },
35
  "Mistral-7B-Instruct-v0.3 (IQ3_XS)": {
36
  "repo_id": "MaziyarPanahi/Mistral-7B-Instruct-v0.3-GGUF",
37
  "filename": "Mistral-7B-Instruct-v0.3.IQ3_XS.gguf",
38
+ "description": "Mistral-7B-Instruct-v0.3 (IQ3_XS)"
39
  },
40
  "Qwen2.5-Coder-7B-Instruct (Q2_K)": {
41
  "repo_id": "Qwen/Qwen2.5-Coder-7B-Instruct-GGUF",
42
  "filename": "qwen2.5-coder-7b-instruct-q2_k.gguf",
43
+ "description": "Qwen2.5-Coder-7B-Instruct (Q2_K)"
44
  },
45
  }
46
 
 
55
  top_p = st.slider("Top-P", 0.1, 1.0, 0.95)
56
  repeat_penalty = st.slider("Repetition Penalty", 1.0, 2.0, 1.1)
57
 
58
+ if st.button("🧹 Clear All Cached Models"):
59
+ try:
60
+ for f in os.listdir("models"):
61
+ if f.endswith(".gguf"):
62
+ os.remove(os.path.join("models", f))
63
+ st.success("Model cache cleared.")
64
+ except Exception as e:
65
+ st.error(f"Failed to clear models: {e}")
66
+
67
+ if st.button("📦 Show Disk Usage"):
68
+ try:
69
+ usage = shutil.disk_usage(".")
70
+ used = usage.used / (1024**3)
71
+ free = usage.free / (1024**3)
72
+ st.info(f"Disk Used: {used:.2f} GB | Free: {free:.2f} GB")
73
+ except Exception as e:
74
+ st.error(f"Disk usage error: {e}")
75
+
76
  # Model info
77
  selected_model = MODELS[selected_model_name]
78
  model_path = os.path.join("models", selected_model["filename"])
79
 
80
+ # Init state
81
+ if "model_name" not in st.session_state:
82
+ st.session_state.model_name = None
83
+ if "llm" not in st.session_state:
84
+ st.session_state.llm = None
85
+
86
  # Ensure model directory exists
87
  os.makedirs("models", exist_ok=True)
88
 
 
95
  except Exception as e:
96
  st.warning(f"Couldn't delete old model {f}: {e}")
97
 
 
98
  def download_model():
99
  with st.spinner(f"Downloading {selected_model['filename']}..."):
100
  hf_hub_download(
 
104
  local_dir_use_symlinks=False,
105
  )
106
 
107
+ def try_load_model(path):
108
+ try:
109
+ return Llama(model_path=path, n_ctx=1024, n_threads=2, n_threads_batch=2, n_batch=4, n_gpu_layers=0, use_mlock=False, use_mmap=True, verbose=False)
110
+ except Exception as e:
111
+ return str(e)
112
+
113
  def validate_or_download_model():
114
  if not os.path.exists(model_path):
115
  cleanup_old_models()
116
  download_model()
117
+
118
+ # First load attempt
119
+ result = try_load_model(model_path)
120
+ if isinstance(result, str):
121
+ st.warning(f"Initial load failed: {result}\nAttempting re-download...")
122
  try:
123
  os.remove(model_path)
124
  except:
125
  pass
126
  cleanup_old_models()
127
  download_model()
128
+ result = try_load_model(model_path)
129
+ if isinstance(result, str):
130
+ st.error(f"Model still failed after re-download: {result}")
131
+ st.stop()
132
+ return result
133
+ return result
134
 
135
  # Load model if changed
136
+ if st.session_state.model_name != selected_model_name:
137
+ if st.session_state.llm is not None:
138
+ del st.session_state.llm
139
+ gc.collect()
140
+ st.session_state.llm = validate_or_download_model()
141
+ st.session_state.model_name = selected_model_name
142
+
143
+ llm = st.session_state.llm
144
+
145
+ # Chat history state
146
+ if "chat_history" not in st.session_state:
147
+ st.session_state.chat_history = []
148
+
149
+ st.title(f"🧠 {selected_model['description']} (Streamlit + GGUF)")
150
+ st.caption(f"Powered by `llama.cpp` | Model: {selected_model['filename']}")
151
+
152
+ user_input = st.chat_input("Ask something...")
153
+
154
+ if user_input:
155
+ # Prevent appending user message if assistant hasn't replied yet
156
+ if len(st.session_state.chat_history) % 2 == 1:
157
+ st.warning("Please wait for the assistant to respond before sending another message.")
158
+ else:
159
+ st.session_state.chat_history.append({"role": "user", "content": user_input})
160
+
161
+ with st.chat_message("user"):
162
+ st.markdown(user_input)
163
+
164
+ # Trim conversation history to max 8 turns (user+assistant)
165
+ MAX_TURNS = 8
166
+ trimmed_history = st.session_state.chat_history[-MAX_TURNS * 2:]
167
+ messages = [{"role": "system", "content": system_prompt}] + trimmed_history
168
+
169
+ with st.chat_message("assistant"):
170
+ full_response = ""
171
+ response_area = st.empty()
172
+
173
+ stream = llm.create_chat_completion(
174
+ messages=messages,
175
+ max_tokens=max_tokens,
176
+ temperature=temperature,
177
+ top_k=top_k,
178
+ top_p=top_p,
179
+ repeat_penalty=repeat_penalty,
180
+ stream=True,
181
+ )
182
+
183
+ for chunk in stream:
184
+ if "choices" in chunk:
185
+ delta = chunk["choices"][0]["delta"].get("content", "")
186
+ full_response += delta
187
+ response_area.markdown(full_response)
188
+
189
+ st.session_state.chat_history.append({"role": "assistant", "content": full_response})