# Header captured from the Hugging Face file viewer (kept for provenance):
# Luigi's picture
# use another version of qwen 1.5 moe
# 96e60d6
# raw
# history blame
# 14.6 kB
import streamlit as st
import os, gc, shutil, re, time, threading, queue
from itertools import islice
from llama_cpp import Llama
from llama_cpp.llama_speculative import LlamaPromptLookupDecoding
from huggingface_hub import hf_hub_download
from duckduckgo_search import DDGS
# ------------------------------
# Initialize Session State
# ------------------------------
# Seed the per-session keys on the first script run only; on later reruns the
# existing values are left untouched.
for _state_key, _initial in (
    ("chat_history", []),
    ("pending_response", False),
    ("model_name", None),
    ("llm", None),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial
# ------------------------------
# Custom CSS for Improved Look & Feel
# ------------------------------
# The chat-user / chat-assistant / message-time classes are used when the
# chat messages are rendered further down in this script.
_CUSTOM_CSS = """
<style>
.chat-container { margin: 1em 0; }
.chat-assistant { background-color: #eef7ff; padding: 1em; border-radius: 10px; margin-bottom: 1em; }
.chat-user { background-color: #e6ffe6; padding: 1em; border-radius: 10px; margin-bottom: 1em; }
.message-time { font-size: 0.8em; color: #555; text-align: right; }
.loading-spinner { font-size: 1.1em; color: #ff6600; }
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# ------------------------------
# Required Storage and Model Definitions
# ------------------------------
# Free-disk threshold compared against before a model file is downloaded.
REQUIRED_SPACE_BYTES = 5 * 1024 ** 3  # 5 GB

# One row per selectable model:
# (sidebar label, Hugging Face repo id, GGUF filename, display description)
_MODEL_SPECS = [
    (
        "Qwen1.5-MoE-A2.7B-20-experts-SFT-trained (Q3_K_S)",
        "mradermacher/Qwen1.5-MoE-A2.7B-20-experts-SFT-trained-GGUF",
        "Qwen1.5-MoE-A2.7B-20-experts-SFT-trained.Q3_K_S.gguf",
        "Qwen1.5-MoE-A2.7B-20-experts-SFT-trained (Q3_K_S)",
    ),
    (
        "Qwen2.5-MOE-6x1.5B-DeepSeek-Reasoning-e32-8.71B",
        "DavidAU/Qwen2.5-MOE-6x1.5B-DeepSeek-Reasoning-e32-8.71B-gguf",
        "QW-6X1.5B-DeepSeek-Qwen-LAM-e32-Q4_K_S.gguf",
        "Qwen2.5-MOE-6x1.5B-DeepSeek-Reasoning-e32-8.71B",
    ),
    (
        "Qwen2.5-3B-Instruct (Q4_K_M)",
        "Qwen/Qwen2.5-3B-Instruct-GGUF",
        "qwen2.5-3b-instruct-q4_k_m.gguf",
        "Qwen2.5-3B-Instruct (Q4_K_M)",
    ),
    (
        "Qwen2.5-7B-Instruct (Q2_K)",
        "Qwen/Qwen2.5-7B-Instruct-GGUF",
        "qwen2.5-7b-instruct-q2_k.gguf",
        "Qwen2.5-7B Instruct (Q2_K)",
    ),
    (
        "Gemma-3-4B-IT (Q4_K_M)",
        "unsloth/gemma-3-4b-it-GGUF",
        "gemma-3-4b-it-Q4_K_M.gguf",
        "Gemma 3 4B IT (Q4_K_M)",
    ),
    (
        "Phi-4-mini-Instruct (Q4_K_M)",
        "unsloth/Phi-4-mini-instruct-GGUF",
        "Phi-4-mini-instruct-Q4_K_M.gguf",
        "Phi-4 Mini Instruct (Q4_K_M)",
    ),
    (
        "Meta-Llama-3.1-8B-Instruct (Q2_K)",
        "MaziyarPanahi/Meta-Llama-3.1-8B-Instruct-GGUF",
        "Meta-Llama-3.1-8B-Instruct.Q2_K.gguf",
        "Meta-Llama-3.1-8B-Instruct (Q2_K)",
    ),
    (
        "DeepSeek-R1-Distill-Llama-8B (Q2_K)",
        "unsloth/DeepSeek-R1-Distill-Llama-8B-GGUF",
        "DeepSeek-R1-Distill-Llama-8B-Q2_K.gguf",
        "DeepSeek-R1-Distill-Llama-8B (Q2_K)",
    ),
    (
        "Mistral-7B-Instruct-v0.3 (IQ3_XS)",
        "MaziyarPanahi/Mistral-7B-Instruct-v0.3-GGUF",
        "Mistral-7B-Instruct-v0.3.IQ3_XS.gguf",
        "Mistral-7B-Instruct-v0.3 (IQ3_XS)",
    ),
    (
        "Qwen2.5-Coder-7B-Instruct (Q2_K)",
        "Qwen/Qwen2.5-Coder-7B-Instruct-GGUF",
        "qwen2.5-coder-7b-instruct-q2_k.gguf",
        "Qwen2.5-Coder-7B-Instruct (Q2_K)",
    ),
]

# Sidebar label -> {"repo_id", "filename", "description"}; insertion order is
# the order shown in the model selector.
MODELS = {
    label: {"repo_id": repo_id, "filename": filename, "description": description}
    for label, repo_id, filename, description in _MODEL_SPECS
}
# ------------------------------
# Helper Functions
# ------------------------------
def retrieve_context(query, max_results=6, max_chars_per_result=600):
    """Retrieve web search context for *query* using DuckDuckGo.

    Args:
        query: Search string sent to DuckDuckGo.
        max_results: Maximum number of search hits to include.
        max_chars_per_result: Each hit's snippet is truncated to this many characters.

    Returns:
        Text with one "Result N:" section (title + snippet) per hit, separated
        by blank lines, or "" if the search fails. Failures are reported in the
        UI with ``st.error`` instead of being raised.
    """
    try:
        with DDGS() as ddgs:
            hits = ddgs.text(query, region="wt-wt", safesearch="off", timelimit="y")
            # islice works whether ``text`` returns a list or a lazy generator.
            results = list(islice(hits or [], int(max_results)))
        sections = []
        for i, result in enumerate(results, start=1):
            # A missing or None field falls back per result instead of raising
            # and discarding every other hit.
            title = result.get("title") or "No Title"
            snippet = (result.get("body") or "")[:max_chars_per_result]
            sections.append(f"Result {i}:\nTitle: {title}\nSnippet: {snippet}")
        # join() avoids repeated string concatenation in the loop.
        return "\n\n".join(sections).strip()
    except Exception as e:
        st.error(f"Error during web retrieval: {e}")
        return ""
def try_load_model(model_path):
    """Load the GGUF file at *model_path* with llama.cpp.

    Returns the ``Llama`` instance on success. On any failure the exception is
    not raised; its message is returned as a ``str`` instead, which callers
    detect with ``isinstance(result, str)``.
    """
    try:
        settings = dict(
            model_path=model_path,
            n_ctx=4096,
            n_threads=2,
            n_threads_batch=1,
            n_batch=256,
            n_gpu_layers=0,  # no layers offloaded to a GPU
            use_mlock=True,
            use_mmap=True,
            verbose=False,
            logits_all=True,
            # Prompt-lookup speculative decoding proposing 2 draft tokens per step.
            draft_model=LlamaPromptLookupDecoding(num_pred_tokens=2),
        )
        return Llama(**settings)
    except Exception as exc:
        return str(exc)
def download_model(selected_model):
    """Download ``selected_model['filename']`` from ``selected_model['repo_id']`` into ./models.

    A Streamlit spinner is shown while the download runs; exceptions raised by
    ``hf_hub_download`` propagate to the caller.
    """
    filename = selected_model["filename"]
    with st.spinner(f"Downloading {filename}..."):
        hf_hub_download(
            repo_id=selected_model["repo_id"],
            filename=filename,
            local_dir="./models",
            local_dir_use_symlinks=False,
        )
def validate_or_download_model(selected_model):
    """Return a loaded ``Llama`` for *selected_model*, downloading the file if needed.

    The GGUF file is expected at ``models/<filename>``. If it is missing it is
    downloaded first. If loading fails, the file is deleted, downloaded again
    and loaded once more. A second failure is reported with ``st.error`` and
    the current script run is halted with ``st.stop()``.

    Args:
        selected_model: A ``MODELS`` entry with "repo_id" and "filename" keys.

    Returns:
        The loaded model object returned by ``try_load_model``.
    """
    model_path = os.path.join("models", selected_model["filename"])
    os.makedirs("models", exist_ok=True)
    if not os.path.exists(model_path):
        if shutil.disk_usage(".").free < REQUIRED_SPACE_BYTES:
            # Advisory only: the download is still attempted.
            st.info("Insufficient storage space. Consider cleaning up old models.")
        download_model(selected_model)
    result = try_load_model(model_path)
    if isinstance(result, str):
        st.warning(f"Initial model load failed: {result}\nAttempting re-download...")
        try:
            os.remove(model_path)
        except FileNotFoundError:
            pass  # Nothing to delete; the download below recreates the file.
        except OSError as err:
            # Previously swallowed silently. If the bad file stays on disk the
            # re-download may not replace it, so make the cause visible.
            st.warning(f"Could not delete {model_path}: {err}")
        download_model(selected_model)
        result = try_load_model(model_path)
        if isinstance(result, str):
            st.error(f"Model failed to load after re-download: {result}")
            st.stop()
    return result
# ------------------------------
# Caching the Model Loading
# ------------------------------
@st.cache_resource(max_entries=1)
def load_cached_model(selected_model):
    """Load *selected_model* once and share the instance across reruns and sessions.

    ``max_entries=1`` keeps only the most recently requested model in the
    resource cache. Without a bound, every model ever selected stays cached
    (and, with ``use_mlock=True`` in ``try_load_model``, pinned in RAM), so
    switching between several-GB GGUF files can exhaust memory. A session that
    still holds an evicted model in ``st.session_state.llm`` keeps using it;
    selecting a different model reloads it from the local ``models`` directory.
    """
    return validate_or_download_model(selected_model)
def stream_response(llm, messages, max_tokens, temperature, top_k, top_p, repeat_penalty, response_queue):
    """Stream a chat completion from *llm* into *response_queue*.

    Meant to run on a worker thread while the UI thread drains the queue.

    Queue protocol:
      * each non-empty text delta is put as a ``str``;
      * if generation raises, ``"\\nError: <message>"`` is put instead;
      * ``None`` is always put last as the end-of-stream sentinel. It is
        emitted from ``finally`` so the consumer is never left waiting, even
        on exceptions that are not ``Exception`` subclasses.

    The sampling arguments are passed straight to ``llm.create_chat_completion``.
    """
    try:
        stream = llm.create_chat_completion(
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repeat_penalty=repeat_penalty,
            stream=True,
        )
        for chunk in stream:
            choices = chunk.get("choices")
            if not choices:
                continue  # Skip chunks with no choice rather than raising IndexError.
            choice = choices[0]
            # Some chunks carry no text (e.g. role-only or finish-only deltas),
            # so "content" may be absent or None.
            delta = (choice.get("delta") or {}).get("content") or ""
            if delta:
                response_queue.put(delta)
            if choice.get("finish_reason"):
                break
    except Exception as e:
        response_queue.put(f"\nError: {e}")
    finally:
        response_queue.put(None)  # Signal the end of streaming
# ------------------------------
# Sidebar: Settings and Advanced Options
# ------------------------------
with st.sidebar:
    st.header("⚙️ Settings")

    # Basic settings: which model to run and the base system prompt.
    selected_model_name = st.selectbox(
        "Select Model",
        options=list(MODELS),
        help="Choose from the available model configurations.",
    )
    system_prompt_base = st.text_area(
        "System Prompt",
        value="You are a helpful assistant.",
        height=80,
        help="Define the base context for the AI's responses.",
    )

    # Generation parameters, later passed through to the chat completion call.
    st.subheader("Generation Parameters")
    max_tokens = st.slider(
        "Max Tokens",
        min_value=64,
        max_value=1024,
        value=256,
        step=32,
        help="The maximum number of tokens the assistant can generate.",
    )
    temperature = st.slider(
        "Temperature",
        min_value=0.1,
        max_value=2.0,
        value=0.7,
        help="Controls randomness. Lower values are more deterministic.",
    )
    top_k = st.slider(
        "Top-K",
        min_value=1,
        max_value=100,
        value=40,
        help="Limits the token candidates to the top-k tokens.",
    )
    top_p = st.slider(
        "Top-P",
        min_value=0.1,
        max_value=1.0,
        value=0.95,
        help="Nucleus sampling parameter; restricts to a cumulative probability.",
    )
    repeat_penalty = st.slider(
        "Repetition Penalty",
        min_value=1.0,
        max_value=2.0,
        value=1.1,
        help="Penalizes token repetition to improve output variety.",
    )

    # Optional web-search augmentation, tucked into an expander.
    with st.expander("Web Search Settings"):
        enable_search = st.checkbox(
            "Enable Web Search",
            value=False,
            help="Include recent web search context to augment the prompt.",
        )
        max_results = st.number_input(
            "Max Results for Context",
            min_value=1,
            max_value=20,
            value=6,
            step=1,
            help="How many search results to use.",
        )
        max_chars_per_result = st.number_input(
            "Max Chars per Result",
            min_value=100,
            max_value=2000,
            value=600,
            step=50,
            help="Max characters to extract from each search result.",
        )
# ------------------------------
# Model Loading/Reloading if Needed
# ------------------------------
selected_model = MODELS[selected_model_name]
model_changed = st.session_state.model_name != selected_model_name
if model_changed:
    # load_cached_model is cached with st.cache_resource; this only rebinds the
    # session's handle when the sidebar selection differs from the last one.
    with st.spinner("Loading selected model..."):
        st.session_state.llm = load_cached_model(selected_model)
    st.session_state.model_name = selected_model_name
llm = st.session_state.llm
# ------------------------------
# Main Title and Chat History Display
# ------------------------------
st.title(f"🧠 {selected_model['description']} (Streamlit + GGUF)")
st.caption(f"Powered by `llama.cpp` | Model: {selected_model['filename']}")

# Render chat history with improved styling.
for chat in st.session_state.chat_history:
    role = chat["role"]
    content = chat["content"]
    if role == "assistant":
        # Hide <think>...</think> reasoning here as the live streaming view
        # does; otherwise it reappears in the history after the next rerun.
        content = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL)
        st.markdown(f"<div class='chat-assistant'>{content}</div>", unsafe_allow_html=True)
    else:
        st.markdown(f"<div class='chat-user'>{content}</div>", unsafe_allow_html=True)
# ------------------------------
# Chat Input and Processing
# ------------------------------
# Number of previous complete user/assistant exchanges sent to the model.
MAX_TURNS = 2

# Timestamp markup appended to every message stored in chat_history.
_TIMESTAMP_SUFFIX_RE = re.compile(r"\n\n<span class='message-time'>[^<]*</span>\Z")
# A closed <think>...</think> block, or an unterminated <think> running to the
# end of the text (as happens mid-stream or when generation is cut off).
_THINK_RE = re.compile(r"<think>.*?(?:</think>|\Z)", re.DOTALL)


def _history_to_messages(history, max_turns):
    """Return the last *max_turns* complete user->assistant exchanges as chat messages.

    Some GGUF chat templates (e.g. Gemma's and Mistral's) raise an error unless
    roles strictly alternate starting with a user turn. A plain tail slice of
    the history can start with an assistant message, so the window is built
    from whole pairs; a user message with no reply (e.g. an interrupted
    generation) is dropped. UI-only markup (timestamps, <think> reasoning) is
    stripped so it is not fed back to the model.
    """
    pairs = []
    unanswered_user = None
    for entry in history:
        if entry["role"] == "user":
            unanswered_user = entry  # a later user message supersedes an unanswered one
        elif entry["role"] == "assistant" and unanswered_user is not None:
            pairs.append((unanswered_user, entry))
            unanswered_user = None
    recent = pairs[-max_turns:] if max_turns > 0 else []
    messages = []
    for user_entry, assistant_entry in recent:
        user_text = _TIMESTAMP_SUFFIX_RE.sub("", user_entry["content"])
        assistant_text = _TIMESTAMP_SUFFIX_RE.sub("", assistant_entry["content"])
        assistant_text = _THINK_RE.sub("", assistant_text).strip()
        messages.append({"role": "user", "content": user_text})
        messages.append({"role": "assistant", "content": assistant_text})
    return messages


def _generation_in_progress():
    """Return True while this session's previous streaming worker thread is still alive."""
    worker = st.session_state.get("stream_worker")
    return worker is not None and worker.is_alive()


user_input = st.chat_input("Your message...")
if user_input:
    # "Busy" is decided by the worker thread itself. A run interrupted by a
    # rerun/stop never reached the end where the flag was cleared, which used
    # to lock the chat permanently. A worker that is still running must also
    # not share the model with a new request.
    if _generation_in_progress():
        st.warning("Please wait until the current response is finished.")
    else:
        prior_history = list(st.session_state.chat_history)
        # Append user message with timestamp
        timestamp = time.strftime("%H:%M")
        st.session_state.chat_history.append({"role": "user", "content": f"{user_input}\n\n<span class='message-time'>{timestamp}</span>"})
        with st.chat_message("user"):
            st.markdown(f"<div class='chat-user'>{user_input}</div>", unsafe_allow_html=True)
        st.session_state.pending_response = True
        try:
            # Retrieve web search context if enabled
            retrieved_context = ""
            if enable_search:
                retrieved_context = retrieve_context(user_input, max_results=max_results, max_chars_per_result=max_chars_per_result)
                with st.sidebar:
                    st.markdown("### Retrieved Context")
                    # Non-empty but collapsed label: avoids Streamlit's empty-label warning.
                    st.text_area(
                        "Retrieved Context",
                        value=retrieved_context or "No context found.",
                        height=150,
                        label_visibility="collapsed",
                    )

            # Augment the user prompt with the system prompt and optional web context
            if enable_search and retrieved_context:
                augmented_user_input = (
                    f"{system_prompt_base.strip()}\n\n"
                    f"Use the following recent web search context to help answer the query:\n\n"
                    f"{retrieved_context}\n\n"
                    f"User Query: {user_input}"
                )
            else:
                augmented_user_input = f"{system_prompt_base.strip()}\n\nUser Query: {user_input}"

            messages = _history_to_messages(prior_history, MAX_TURNS)
            messages.append({"role": "user", "content": augmented_user_input})

            # Placeholder for the streaming response and a queue fed by the worker thread
            visible_placeholder = st.empty()
            progress_bar = st.progress(0)
            response_queue = queue.Queue()
            stream_thread = threading.Thread(
                target=stream_response,
                args=(llm, messages, max_tokens, temperature, top_k, top_p, repeat_penalty, response_queue),
                daemon=True,
            )
            # Remembered so a later rerun can tell whether generation is still running.
            st.session_state.stream_worker = stream_thread
            stream_thread.start()

            # Poll the queue to update the UI with incremental tokens and progress
            final_response = ""
            received = 0
            timeout = 300  # seconds without any new output before giving up
            last_update = time.monotonic()
            while True:
                try:
                    update = response_queue.get(timeout=0.1)
                except queue.Empty:
                    if time.monotonic() - last_update > timeout:
                        st.error("Response generation timed out.")
                        break
                    continue
                if update is None:
                    break
                final_response += update
                received += 1
                # Hide reasoning, including a <think> block that is still open.
                visible_response = _THINK_RE.sub("", final_response)
                visible_placeholder.markdown(f"<div class='chat-assistant'>{visible_response}</div>", unsafe_allow_html=True)
                # Each queued item is roughly one token, so measure against max_tokens.
                progress_bar.progress(min(received / max(max_tokens, 1), 1.0))
                last_update = time.monotonic()

            # Append assistant response with timestamp
            timestamp = time.strftime("%H:%M")
            st.session_state.chat_history.append({"role": "assistant", "content": f"{final_response}\n\n<span class='message-time'>{timestamp}</span>"})
            progress_bar.empty()  # Clear progress bar
        finally:
            # Runs even when Streamlit interrupts the script for a rerun/stop.
            st.session_state.pending_response = False
        gc.collect()