import streamlit as st
import os, gc, shutil, re, time, threading, queue
from itertools import islice
from llama_cpp import Llama
from llama_cpp.llama_speculative import LlamaPromptLookupDecoding
from huggingface_hub import hf_hub_download
from duckduckgo_search import DDGS
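# Streamlit chat app for local GGUF models via llama-cpp-python:
# pick a model in the sidebar, download it from the Hugging Face Hub on demand,
# optionally prepend DuckDuckGo search results as context, and stream the reply.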
# ---- Initialize session state ----
if "chat_history" not in st.session_state:
st.session_state.chat_history = []
if "pending_response" not in st.session_state:
st.session_state.pending_response = False
if "model_name" not in st.session_state:
st.session_state.model_name = None
if "llm" not in st.session_state:
st.session_state.llm = None
# ---- Custom CSS ----
st.markdown("""
<style>
ul.think-list { margin: 0.5em 0 1em 1.5em; padding: 0; list-style-type: disc; }
ul.think-list li { margin-bottom: 0.5em; }
.chat-assistant { background-color: #f9f9f9; padding: 1em; border-radius: 5px; margin-bottom: 1em; }
</style>
""", unsafe_allow_html=True)
# ---- Required storage space ----
REQUIRED_SPACE_BYTES = 5 * 1024 ** 3 # 5 GB
# ---- Function to retrieve web search context ----
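# DuckDuckGo text search: region "wt-wt" (worldwide), safesearch off,
# timelimit "y" (results from roughly the past year); snippets are truncated
# to max_chars_per_result characters before being stitched into the context.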
def retrieve_context(query, max_results=6, max_chars_per_result=600):
    try:
        with DDGS() as ddgs:
            results = list(islice(ddgs.text(query, region="wt-wt", safesearch="off", timelimit="y"), max_results))
        context = ""
        for i, result in enumerate(results, start=1):
            title = result.get("title", "No Title")
            snippet = result.get("body", "")[:max_chars_per_result]
            context += f"Result {i}:\nTitle: {title}\nSnippet: {snippet}\n\n"
        return context.strip()
    except Exception as e:
        st.error(f"Error during retrieval: {e}")
        return ""
# ---- Model definitions ----
MODELS = {
    "Qwen2.5-0.5B-Instruct (Q4_K_M)": {
        "repo_id": "Qwen/Qwen2.5-0.5B-Instruct-GGUF",
        "filename": "qwen2.5-0.5b-instruct-q4_k_m.gguf",
        "description": "Qwen2.5-0.5B-Instruct (Q4_K_M)"
    },
    "Gemma-3.1B-it (Q4_K_M)": {
        "repo_id": "unsloth/gemma-3-1b-it-GGUF",
        "filename": "gemma-3-1b-it-Q4_K_M.gguf",
        "description": "Gemma-3.1B-it (Q4_K_M)"
    },
    "Qwen2.5-1.5B-Instruct (Q4_K_M)": {
        "repo_id": "Qwen/Qwen2.5-1.5B-Instruct-GGUF",
        "filename": "qwen2.5-1.5b-instruct-q4_k_m.gguf",
        "description": "Qwen2.5-1.5B-Instruct (Q4_K_M)"
    },
    "Qwen2.5-3B-Instruct (Q4_K_M)": {
        "repo_id": "Qwen/Qwen2.5-3B-Instruct-GGUF",
        "filename": "qwen2.5-3b-instruct-q4_k_m.gguf",
        "description": "Qwen2.5-3B-Instruct (Q4_K_M)"
    },
    "Qwen2.5-7B-Instruct (Q2_K)": {
        "repo_id": "Qwen/Qwen2.5-7B-Instruct-GGUF",
        "filename": "qwen2.5-7b-instruct-q2_k.gguf",
        "description": "Qwen2.5-7B Instruct (Q2_K)"
    },
    "Gemma-3-4B-IT (Q4_K_M)": {
        "repo_id": "unsloth/gemma-3-4b-it-GGUF",
        "filename": "gemma-3-4b-it-Q4_K_M.gguf",
        "description": "Gemma 3 4B IT (Q4_K_M)"
    },
    "Phi-4-mini-Instruct (Q4_K_M)": {
        "repo_id": "unsloth/Phi-4-mini-instruct-GGUF",
        "filename": "Phi-4-mini-instruct-Q4_K_M.gguf",
        "description": "Phi-4 Mini Instruct (Q4_K_M)"
    },
    "Meta-Llama-3.1-8B-Instruct (Q2_K)": {
        "repo_id": "MaziyarPanahi/Meta-Llama-3.1-8B-Instruct-GGUF",
        "filename": "Meta-Llama-3.1-8B-Instruct.Q2_K.gguf",
        "description": "Meta-Llama-3.1-8B-Instruct (Q2_K)"
    },
    "DeepSeek-R1-Distill-Llama-8B (Q2_K)": {
        "repo_id": "unsloth/DeepSeek-R1-Distill-Llama-8B-GGUF",
        "filename": "DeepSeek-R1-Distill-Llama-8B-Q2_K.gguf",
        "description": "DeepSeek-R1-Distill-Llama-8B (Q2_K)"
    },
    "Mistral-7B-Instruct-v0.3 (IQ3_XS)": {
        "repo_id": "MaziyarPanahi/Mistral-7B-Instruct-v0.3-GGUF",
        "filename": "Mistral-7B-Instruct-v0.3.IQ3_XS.gguf",
        "description": "Mistral-7B-Instruct-v0.3 (IQ3_XS)"
    },
    "Qwen2.5-Coder-7B-Instruct (Q2_K)": {
        "repo_id": "Qwen/Qwen2.5-Coder-7B-Instruct-GGUF",
        "filename": "qwen2.5-coder-7b-instruct-q2_k.gguf",
        "description": "Qwen2.5-Coder-7B-Instruct (Q2_K)"
    },
}
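# Each entry maps a display name to a GGUF file on the Hugging Face Hub
# (repo_id + filename); the selected file is downloaded into ./models on demand.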
# ----- Sidebar settings -----
with st.sidebar:
    st.header("⚙️ Settings")
    selected_model_name = st.selectbox("Select Model", list(MODELS.keys()))
    system_prompt_base = st.text_area("System Prompt", value="You are a helpful assistant.", height=80)
    max_tokens = st.slider("Max tokens", 64, 1024, 256, step=32)
    temperature = st.slider("Temperature", 0.1, 2.0, 0.7)
    top_k = st.slider("Top-K", 1, 100, 40)
    top_p = st.slider("Top-P", 0.1, 1.0, 0.95)
    repeat_penalty = st.slider("Repetition Penalty", 1.0, 2.0, 1.1)
    enable_search = st.checkbox("Enable Web Search", value=False)
    # Web search configuration exposed to the user
    max_results = st.number_input("Max Results for Context", min_value=1, max_value=20, value=6, step=1)
    max_chars_per_result = st.number_input("Max Chars Per Result", min_value=100, max_value=2000, value=600, step=50)
# ---- Define selected model and manage its download/load ----
selected_model = MODELS[selected_model_name]
model_path = os.path.join("models", selected_model["filename"])
os.makedirs("models", exist_ok=True)
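# try_load_model returns a Llama instance on success, or the exception message
# (a string) on failure, so callers can tell the two apart with isinstance().
# Inference is CPU-only (n_gpu_layers=0) and uses prompt-lookup speculative
# decoding as the draft model.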
def try_load_model(path):
    try:
        return Llama(
            model_path=path,
            n_ctx=4096,  # Reduced context window
            n_threads=2,
            n_threads_batch=1,
            n_batch=256,
            n_gpu_layers=0,
            use_mlock=True,
            use_mmap=True,
            verbose=False,
            logits_all=True,
            draft_model=LlamaPromptLookupDecoding(num_pred_tokens=2),
        )
    except Exception as e:
        return str(e)
def download_model():
    with st.spinner(f"Downloading {selected_model['filename']}..."):
        hf_hub_download(
            repo_id=selected_model["repo_id"],
            filename=selected_model["filename"],
            local_dir="./models",
            local_dir_use_symlinks=False,
        )
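# validate_or_download_model: download the file if it is missing, try to load it,
# and on a load failure delete the (possibly corrupt) file, re-download once, and
# retry; if the second attempt still fails, stop the app.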
def validate_or_download_model():
    if not os.path.exists(model_path):
        if shutil.disk_usage(".").free < REQUIRED_SPACE_BYTES:
            st.info("Insufficient storage. Consider cleaning up old models.")
        download_model()
    result = try_load_model(model_path)
    if isinstance(result, str):
        st.warning(f"Initial load failed: {result}\nRe-downloading...")
        try:
            os.remove(model_path)
        except Exception:
            pass
        download_model()
        result = try_load_model(model_path)
        if isinstance(result, str):
            st.error(f"Model still failed after re-download: {result}")
            st.stop()
    return result
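# When the user switches models, drop the previous Llama instance and force a
# garbage-collection pass before loading the new one, so only one model stays in RAM.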
if st.session_state.model_name != selected_model_name:
    if st.session_state.llm is not None:
        del st.session_state.llm
        gc.collect()
    st.session_state.llm = validate_or_download_model()
    st.session_state.model_name = selected_model_name
llm = st.session_state.llm
# ---- Display title and existing chat history ----
st.title(f"🧠 {selected_model['description']} (Streamlit + GGUF)")
st.caption(f"Powered by `llama.cpp` | Model: {selected_model['filename']}")
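# Replay the stored conversation so it persists across Streamlit reruns.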
for chat in st.session_state.chat_history:
    with st.chat_message(chat["role"]):
        st.markdown(chat["content"])
# ---- Chat input and processing ----
user_input = st.chat_input("Ask something...")
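# pending_response acts as a simple lock: a new prompt is rejected while a reply
# is still being generated.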
if user_input:
    if st.session_state.pending_response:
        st.warning("Please wait for the assistant to finish responding.")
    else:
        # Display user input and update chat history
        with st.chat_message("user"):
            st.markdown(user_input)
        st.session_state.chat_history.append({"role": "user", "content": user_input})
        st.session_state.pending_response = True
        # Retrieve web search context using the sidebar settings
        retrieved_context = (
            retrieve_context(user_input, max_results=max_results, max_chars_per_result=max_chars_per_result)
            if enable_search else ""
        )
        st.sidebar.markdown("### Retrieved Context" if enable_search else "Web Search Disabled")
        st.sidebar.text(retrieved_context or "No context found.")
        # Build the augmented prompt: system prompt, optional web context, then the user query
        if enable_search and retrieved_context:
            augmented_user_input = (
                f"{system_prompt_base.strip()}\n\n"
                f"Use the following recent web search context to help answer the query:\n\n"
                f"{retrieved_context}\n\n"
                f"User Query: {user_input}"
            )
        else:
            augmented_user_input = f"{system_prompt_base.strip()}\n\nUser Query: {user_input}"
        # Limit conversation history to the last 2 user/assistant pairs
        MAX_TURNS = 2
        trimmed_history = st.session_state.chat_history[-(MAX_TURNS * 2):]
        if trimmed_history and trimmed_history[-1]["role"] == "user":
            messages = trimmed_history[:-1] + [{"role": "user", "content": augmented_user_input}]
        else:
            messages = trimmed_history + [{"role": "user", "content": augmented_user_input}]
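        # Generation runs in a background thread that pushes text deltas onto a queue;
        # the main Streamlit thread drains the queue and updates the placeholder, so
        # partial output appears while the model is still generating.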
        # ---- Set up a placeholder for the response and queue for streaming tokens ----
        visible_placeholder = st.empty()
        response_queue = queue.Queue()
        # Function to stream LLM response and push incremental updates into the queue
        def stream_response(msgs, max_tokens, temp, topk, topp, repeat_penalty):
            final_text = ""
            try:
                stream = llm.create_chat_completion(
                    messages=msgs,
                    max_tokens=max_tokens,
                    temperature=temp,
                    top_k=topk,
                    top_p=topp,
                    repeat_penalty=repeat_penalty,
                    stream=True,
                )
                for chunk in stream:
                    if "choices" in chunk:
                        delta = chunk["choices"][0]["delta"].get("content", "")
                        final_text += delta
                        response_queue.put(delta)
                        if chunk["choices"][0].get("finish_reason", ""):
                            break
            except Exception as e:
                response_queue.put(f"\nError: {e}")
            response_queue.put(None)  # Signal completion
        # Start streaming in a separate thread
        stream_thread = threading.Thread(
            target=stream_response,
            args=(messages, max_tokens, temperature, top_k, top_p, repeat_penalty),
            daemon=True
        )
        stream_thread.start()
        # Poll the queue in the main thread until the stream finishes or the overall timeout expires
        final_response = ""
        timeout = 300  # seconds
        start_time = time.time()
        while True:
            try:
                update = response_queue.get(timeout=0.1)
                if update is None:
                    break
                final_response += update
                # Hide <think>...</think> reasoning blocks (and any unfinished <think> tail)
                # from the displayed text while streaming.
                visible_response = re.sub(r"<think>.*?</think>", "", final_response, flags=re.DOTALL)
                visible_response = re.sub(r"<think>.*$", "", visible_response, flags=re.DOTALL)
                visible_placeholder.markdown(visible_response)
            except queue.Empty:
                if time.time() - start_time > timeout:
                    st.error("Response generation timed out.")
                    break
        st.session_state.chat_history.append({"role": "assistant", "content": final_response})
        st.session_state.pending_response = False
        gc.collect()