Luigi committed on
Commit 248f5a7 · 1 Parent(s): a7fdfe6

Code simplification

Files changed (1)
  1. app.py +97 -115
app.py CHANGED
@@ -1,61 +1,49 @@
  import streamlit as st

  from llama_cpp import Llama
  from llama_cpp.llama_speculative import LlamaPromptLookupDecoding
  from huggingface_hub import hf_hub_download
- import os, gc, shutil, re
- from itertools import islice
- from duckduckgo_search import DDGS # Latest class-based interface

- # ----- Custom CSS for pretty formatting of internal reasoning -----
- CUSTOM_CSS = """
- <style>
- /* Styles for the internal reasoning bullet list */
- ul.think-list {
-     margin: 0.5em 0 1em 1.5em;
-     padding: 0;
-     list-style-type: disc;
- }
- ul.think-list li {
-     margin-bottom: 0.5em;
- }

- /* Container style for the "in progress" internal reasoning */
- .chat-assistant {
-     background-color: #f9f9f9;
-     padding: 1em;
-     border-radius: 5px;
-     margin-bottom: 1em;
- }
  </style>
- """
- st.markdown(CUSTOM_CSS, unsafe_allow_html=True)

- # ----- Set a threshold for required free storage (in bytes) -----
  REQUIRED_SPACE_BYTES = 5 * 1024 ** 3 # 5 GB

- # ----- Function to perform DuckDuckGo search and retrieve concise context -----
  def retrieve_context(query, max_results=2, max_chars_per_result=150):
-     """
-     Query DuckDuckGo for the given search query and return a concatenated context string.
-     Uses the DDGS().text() generator (with region, safesearch, and timelimit parameters)
-     and limits the results using islice. Each result's title and snippet are combined into context.
-     """
      try:
          with DDGS() as ddgs:
-             results_gen = ddgs.text(query, region="wt-wt", safesearch="off", timelimit="y")
-             results = list(islice(results_gen, max_results))
              context = ""
-             if results:
-                 for i, result in enumerate(results, start=1):
-                     title = result.get("title", "No Title")
-                     snippet = result.get("body", "")[:max_chars_per_result]
-                     context += f"Result {i}:\nTitle: {title}\nSnippet: {snippet}\n\n"
              return context.strip()
      except Exception as e:
          st.error(f"Error during retrieval: {e}")
          return ""

- # ----- Available models -----
  MODELS = {
      "Qwen2.5-0.5B-Instruct (Q4_K_M)": {
          "repo_id": "Qwen/Qwen2.5-0.5B-Instruct-GGUF",
@@ -109,40 +97,26 @@ with st.sidebar:
      st.header("⚙️ Settings")
      selected_model_name = st.selectbox("Select Model", list(MODELS.keys()))
      system_prompt_base = st.text_area("System Prompt", value="You are a helpful assistant.", height=80)
-     max_tokens = st.slider("Max tokens", 64, 1024, 256, step=32) # Adjust for lower memory usage
      temperature = st.slider("Temperature", 0.1, 2.0, 0.7)
      top_k = st.slider("Top-K", 1, 100, 40)
      top_p = st.slider("Top-P", 0.1, 1.0, 0.95)
      repeat_penalty = st.slider("Repetition Penalty", 1.0, 2.0, 1.1)
-
-     # Checkbox to enable the DuckDuckGo search feature (disabled by default)
      enable_search = st.checkbox("Enable Web Search", value=False)
-
-     if st.button("📦 Show Disk Usage"):
-         try:
-             usage = shutil.disk_usage(".")
-             used = usage.used / (1024 ** 3)
-             free = usage.free / (1024 ** 3)
-             st.info(f"Disk Used: {used:.2f} GB | Free: {free:.2f} GB")
-         except Exception as e:
-             st.error(f"Disk usage error: {e}")

- # ----- Define selected model and path -----
  selected_model = MODELS[selected_model_name]
  model_path = os.path.join("models", selected_model["filename"])
-
- # Ensure model directory exists
  os.makedirs("models", exist_ok=True)

- # ----- Helper functions for model management -----
  def try_load_model(path):
      try:
          return Llama(
              model_path=path,
-             n_ctx=512, # Reduced context window to save memory
-             n_threads=2, # Fewer threads for resource-constrained environments
              n_threads_batch=1,
-             n_batch=64, # Lower batch size to conserve memory
              n_gpu_layers=0,
              use_mlock=False,
              use_mmap=True,
@@ -164,13 +138,12 @@ def download_model():

  def validate_or_download_model():
      if not os.path.exists(model_path):
-         free_space = shutil.disk_usage(".").free
-         if free_space < REQUIRED_SPACE_BYTES:
              st.info("Insufficient storage. Consider cleaning up old models.")
          download_model()
      result = try_load_model(model_path)
      if isinstance(result, str):
-         st.warning(f"Initial load failed: {result}\nAttempting re-download...")
          try:
              os.remove(model_path)
          except Exception:
@@ -180,20 +153,8 @@ def validate_or_download_model():
      if isinstance(result, str):
          st.error(f"Model still failed after re-download: {result}")
          st.stop()
-         return result
      return result

- # ----- Session state initialization -----
- if "model_name" not in st.session_state:
-     st.session_state.model_name = None
- if "llm" not in st.session_state:
-     st.session_state.llm = None
- if "chat_history" not in st.session_state:
-     st.session_state.chat_history = []
- if "pending_response" not in st.session_state:
-     st.session_state.pending_response = False
-
- # ----- Load model if changed -----
  if st.session_state.model_name != selected_model_name:
      if st.session_state.llm is not None:
          del st.session_state.llm
@@ -203,40 +164,32 @@ if st.session_state.model_name != selected_model_name:

  llm = st.session_state.llm

- # ----- Display title and caption -----
  st.title(f"🧠 {selected_model['description']} (Streamlit + GGUF)")
  st.caption(f"Powered by `llama.cpp` | Model: {selected_model['filename']}")

- # Render existing chat history
  for chat in st.session_state.chat_history:
      with st.chat_message(chat["role"]):
          st.markdown(chat["content"])

- # ----- Chat input and integrated RAG with memory optimizations -----
  user_input = st.chat_input("Ask something...")
-
  if user_input:
      if st.session_state.pending_response:
          st.warning("Please wait for the assistant to finish responding.")
      else:
-         # Display the raw user input immediately in the chat view.
          with st.chat_message("user"):
              st.markdown(user_input)
-
-         # Append the plain user message to chat history for display purposes.
-         # (We will later override the last user message in the API call with the augmented version.)
          st.session_state.chat_history.append({"role": "user", "content": user_input})
          st.session_state.pending_response = True

-         # Retrieve extra context from web search if enabled
-         if enable_search:
-             retrieved_context = retrieve_context(user_input, max_results=2, max_chars_per_result=150)
-         else:
-             retrieved_context = ""
          st.sidebar.markdown("### Retrieved Context" if enable_search else "Web Search Disabled")
          st.sidebar.text(retrieved_context or "No context found.")

-         # Build an augmented user query by merging the system prompt (and search context when available)
          if enable_search and retrieved_context:
              augmented_user_input = (
                  f"{system_prompt_base.strip()}\n\n"
@@ -247,39 +200,68 @@ if user_input:
          else:
              augmented_user_input = f"{system_prompt_base.strip()}\n\nUser Query: {user_input}"

-         # Limit conversation history to the last MAX_TURNS turns (user/assistant pairs)
          MAX_TURNS = 2
          trimmed_history = st.session_state.chat_history[-(MAX_TURNS * 2):]
-
-         # Replace the last user message (which is plain) with the augmented version for model input.
          if trimmed_history and trimmed_history[-1]["role"] == "user":
              messages = trimmed_history[:-1] + [{"role": "user", "content": augmented_user_input}]
          else:
              messages = trimmed_history + [{"role": "user", "content": augmented_user_input}]

-         # Generate response with the LLM in a streaming fashion
-         with st.chat_message("assistant"):
-             visible_placeholder = st.empty()
-             full_response = ""
-             stream = llm.create_chat_completion(
-                 messages=messages,
-                 max_tokens=max_tokens,
-                 temperature=temperature,
-                 top_k=top_k,
-                 top_p=top_p,
-                 repeat_penalty=repeat_penalty,
-                 stream=True,
-             )
-             for chunk in stream:
-                 if "choices" in chunk:
-                     delta = chunk["choices"][0]["delta"].get("content", "")
-                     full_response += delta
-                     # Clean internal reasoning markers before display
-                     visible_response = re.sub(r"<think>.*?</think>", "", full_response, flags=re.DOTALL)
-                     visible_response = re.sub(r"<think>.*$", "", visible_response, flags=re.DOTALL)
-                     visible_placeholder.markdown(visible_response)
-
-         # Append the assistant's response to conversation history.
-         st.session_state.chat_history.append({"role": "assistant", "content": full_response})
          st.session_state.pending_response = False
-         gc.collect() # Free memory
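
Aside: both the old and new versions share the same retrieval helper, which wraps duckduckgo_search's DDGS.text() and folds each hit's title and truncated snippet into one context string. Below is a minimal standalone sketch of that lookup outside Streamlit (the st.error handling is omitted since there is no Streamlit session); the __main__ query is purely illustrative and not part of the commit.

from itertools import islice

from duckduckgo_search import DDGS


def retrieve_context(query, max_results=2, max_chars_per_result=150):
    # Mirror of the app's helper: take the first few results and build a
    # "Result N / Title / Snippet" block for each one.
    with DDGS() as ddgs:
        results = list(islice(ddgs.text(query, region="wt-wt", safesearch="off", timelimit="y"), max_results))
    context = ""
    for i, result in enumerate(results, start=1):
        title = result.get("title", "No Title")
        snippet = result.get("body", "")[:max_chars_per_result]
        context += f"Result {i}:\nTitle: {title}\nSnippet: {snippet}\n\n"
    return context.strip()


if __name__ == "__main__":
    print(retrieve_context("llama.cpp GGUF quantization"))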
 
  import streamlit as st
+ import os, gc, shutil, re, time, threading, queue
+ from itertools import islice
  from llama_cpp import Llama
  from llama_cpp.llama_speculative import LlamaPromptLookupDecoding
  from huggingface_hub import hf_hub_download
+ from duckduckgo_search import DDGS

+ # ---- Initialize session state ----
+ if "chat_history" not in st.session_state:
+     st.session_state.chat_history = []
+ if "pending_response" not in st.session_state:
+     st.session_state.pending_response = False
+ if "model_name" not in st.session_state:
+     st.session_state.model_name = None
+ if "llm" not in st.session_state:
+     st.session_state.llm = None

+ # ---- Custom CSS ----
+ st.markdown("""
+ <style>
+ ul.think-list { margin: 0.5em 0 1em 1.5em; padding: 0; list-style-type: disc; }
+ ul.think-list li { margin-bottom: 0.5em; }
+ .chat-assistant { background-color: #f9f9f9; padding: 1em; border-radius: 5px; margin-bottom: 1em; }
  </style>
+ """, unsafe_allow_html=True)

+ # ---- Required storage space ----
  REQUIRED_SPACE_BYTES = 5 * 1024 ** 3 # 5 GB

+ # ---- Function to retrieve web search context ----
  def retrieve_context(query, max_results=2, max_chars_per_result=150):
      try:
          with DDGS() as ddgs:
+             results = list(islice(ddgs.text(query, region="wt-wt", safesearch="off", timelimit="y"), max_results))
              context = ""
+             for i, result in enumerate(results, start=1):
+                 title = result.get("title", "No Title")
+                 snippet = result.get("body", "")[:max_chars_per_result]
+                 context += f"Result {i}:\nTitle: {title}\nSnippet: {snippet}\n\n"
              return context.strip()
      except Exception as e:
          st.error(f"Error during retrieval: {e}")
          return ""

+ # ---- Model definitions ----
  MODELS = {
      "Qwen2.5-0.5B-Instruct (Q4_K_M)": {
          "repo_id": "Qwen/Qwen2.5-0.5B-Instruct-GGUF",

      st.header("⚙️ Settings")
      selected_model_name = st.selectbox("Select Model", list(MODELS.keys()))
      system_prompt_base = st.text_area("System Prompt", value="You are a helpful assistant.", height=80)
+     max_tokens = st.slider("Max tokens", 64, 1024, 256, step=32)
      temperature = st.slider("Temperature", 0.1, 2.0, 0.7)
      top_k = st.slider("Top-K", 1, 100, 40)
      top_p = st.slider("Top-P", 0.1, 1.0, 0.95)
      repeat_penalty = st.slider("Repetition Penalty", 1.0, 2.0, 1.1)
      enable_search = st.checkbox("Enable Web Search", value=False)

+ # ---- Define selected model and manage its download/load ----
  selected_model = MODELS[selected_model_name]
  model_path = os.path.join("models", selected_model["filename"])
  os.makedirs("models", exist_ok=True)

  def try_load_model(path):
      try:
          return Llama(
              model_path=path,
+             n_ctx=512, # Reduced context window
+             n_threads=2,
              n_threads_batch=1,
+             n_batch=64,
              n_gpu_layers=0,
              use_mlock=False,
              use_mmap=True,

  def validate_or_download_model():
      if not os.path.exists(model_path):
+         if shutil.disk_usage(".").free < REQUIRED_SPACE_BYTES:
              st.info("Insufficient storage. Consider cleaning up old models.")
          download_model()
      result = try_load_model(model_path)
      if isinstance(result, str):
+         st.warning(f"Initial load failed: {result}\nRe-downloading...")
          try:
              os.remove(model_path)
          except Exception:

      if isinstance(result, str):
          st.error(f"Model still failed after re-download: {result}")
          st.stop()
      return result

  if st.session_state.model_name != selected_model_name:
      if st.session_state.llm is not None:
          del st.session_state.llm

  llm = st.session_state.llm

+ # ---- Display title and existing chat history ----
  st.title(f"🧠 {selected_model['description']} (Streamlit + GGUF)")
  st.caption(f"Powered by `llama.cpp` | Model: {selected_model['filename']}")

  for chat in st.session_state.chat_history:
      with st.chat_message(chat["role"]):
          st.markdown(chat["content"])

+ # ---- Chat input and processing ----
  user_input = st.chat_input("Ask something...")
  if user_input:
      if st.session_state.pending_response:
          st.warning("Please wait for the assistant to finish responding.")
      else:
+         # Display user input and update chat history
          with st.chat_message("user"):
              st.markdown(user_input)
          st.session_state.chat_history.append({"role": "user", "content": user_input})
          st.session_state.pending_response = True

+         # Optionally retrieve extra context
+         retrieved_context = retrieve_context(user_input, max_results=2, max_chars_per_result=150) if enable_search else ""
          st.sidebar.markdown("### Retrieved Context" if enable_search else "Web Search Disabled")
          st.sidebar.text(retrieved_context or "No context found.")

+         # Build augmented query
          if enable_search and retrieved_context:
              augmented_user_input = (
                  f"{system_prompt_base.strip()}\n\n"

          else:
              augmented_user_input = f"{system_prompt_base.strip()}\n\nUser Query: {user_input}"

+         # Limit conversation history (last 2 pairs)
          MAX_TURNS = 2
          trimmed_history = st.session_state.chat_history[-(MAX_TURNS * 2):]
          if trimmed_history and trimmed_history[-1]["role"] == "user":
              messages = trimmed_history[:-1] + [{"role": "user", "content": augmented_user_input}]
          else:
              messages = trimmed_history + [{"role": "user", "content": augmented_user_input}]

+         # ---- Set up a placeholder for the response and queue for streaming tokens ----
+         visible_placeholder = st.empty()
+         response_queue = queue.Queue()
+
+         # Function to stream LLM response and push incremental updates into the queue
+         def stream_response(msgs, max_tokens, temp, topk, topp, repeat_penalty):
+             final_text = ""
+             try:
+                 stream = llm.create_chat_completion(
+                     messages=msgs,
+                     max_tokens=max_tokens,
+                     temperature=temp,
+                     top_k=topk,
+                     top_p=topp,
+                     repeat_penalty=repeat_penalty,
+                     stream=True,
+                 )
+                 for chunk in stream:
+                     if "choices" in chunk:
+                         delta = chunk["choices"][0]["delta"].get("content", "")
+                         final_text += delta
+                         response_queue.put(delta)
+                         if chunk["choices"][0].get("finish_reason", ""):
+                             break
+             except Exception as e:
+                 response_queue.put(f"\nError: {e}")
+             response_queue.put(None) # Signal completion
+
+         # Start streaming in a separate thread
+         stream_thread = threading.Thread(
+             target=stream_response,
+             args=(messages, max_tokens, temperature, top_k, top_p, repeat_penalty),
+             daemon=True
+         )
+         stream_thread.start()
+
+         # Poll the queue in the main thread until the stream finishes or the timeout is reached
+         final_response = ""
+         timeout = 120 # seconds
+         start_time = time.time()
+         while True:
+             try:
+                 update = response_queue.get(timeout=0.1)
+                 if update is None:
+                     break
+                 final_response += update
+                 visible_response = re.sub(r"<think>.*?</think>", "", final_response, flags=re.DOTALL)
+                 visible_response = re.sub(r"<think>.*$", "", visible_response, flags=re.DOTALL)
+                 visible_placeholder.markdown(visible_response)
+             except queue.Empty:
+                 if time.time() - start_time > timeout:
+                     st.error("Response generation timed out.")
+                     break
+
+         st.session_state.chat_history.append({"role": "assistant", "content": final_response})
          st.session_state.pending_response = False
+         gc.collect()
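
Aside: the rewritten response path is a standard producer/consumer arrangement: a daemon thread iterates the streamed llm.create_chat_completion(..., stream=True) output and pushes each text delta into a queue.Queue, while the main Streamlit script drains the queue and re-renders the partial answer until it sees the None sentinel or hits the timeout. Below is a minimal, self-contained sketch of that pattern using only the standard library; fake_token_stream stands in for the llama.cpp call, and every name in it is illustrative rather than part of the commit.

import queue
import threading
import time


def fake_token_stream():
    # Stand-in for llm.create_chat_completion(..., stream=True): yields text deltas.
    for token in ["Hello", ",", " ", "world", "!"]:
        time.sleep(0.05)
        yield token


def producer(q):
    # Worker thread: forward each delta, then a None sentinel to signal completion.
    try:
        for delta in fake_token_stream():
            q.put(delta)
    except Exception as exc:
        q.put(f"\nError: {exc}")
    q.put(None)


def consume(q, timeout_s=5.0):
    # Main thread: drain the queue until the sentinel arrives or the timeout is hit.
    text = ""
    start = time.time()
    while True:
        try:
            update = q.get(timeout=0.1)
        except queue.Empty:
            if time.time() - start > timeout_s:
                break  # give up; the daemon worker is simply abandoned
            continue
        if update is None:
            break
        text += update
        # In the app, the partial text would be re-rendered here (placeholder.markdown).
    return text


if __name__ == "__main__":
    q = queue.Queue()
    threading.Thread(target=producer, args=(q,), daemon=True).start()
    print(consume(q))

The None sentinel is what lets the consumer tell "stream finished" apart from "no token yet", which is why the app's polling loop checks update is None before appending to the visible response.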