import os
import time
import gc
import threading
from itertools import islice
from datetime import datetime
import re # for parsing <think> blocks
import gradio as gr
import torch
from transformers import pipeline, TextIteratorStreamer
from transformers import AutoTokenizer
from duckduckgo_search import DDGS
import spaces # Import spaces early to enable ZeroGPU support
# Optional: Disable GPU visibility if you wish to force CPU usage
# os.environ["CUDA_VISIBLE_DEVICES"] = ""
# ------------------------------
# Global Cancellation Event
# ------------------------------
cancel_event = threading.Event()
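# Set by the "Cancel Generation" button; the token-streaming loop in
# chat_response() checks this between chunks and stops updating the UI.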
# ------------------------------
# Torch-Compatible Model Definitions with Adjusted Descriptions
# ------------------------------
MODELS = {
"Taiwan-ELM-1_1B-Instruct": {"repo_id": "liswei/Taiwan-ELM-1_1B-Instruct", "description": "Taiwan-ELM-1_1B-Instruct"},
"Taiwan-ELM-270M-Instruct": {"repo_id": "liswei/Taiwan-ELM-270M-Instruct", "description": "Taiwan-ELM-270M-Instruct"},
# "Granite-4.0-Tiny-Preview": {"repo_id": "ibm-granite/granite-4.0-tiny-preview", "description": "Granite-4.0-Tiny-Preview"},
"Qwen3-0.6B": {"repo_id":"Qwen/Qwen3-0.6B","description":"Dense causal language model with 0.6 B total parameters (0.44 B non-embedding), 28 transformer layers, 16 query heads & 8 KV heads, native 32 768-token context window, dual-mode generation, full multilingual & agentic capabilities."},
"Qwen3-1.7B": {"repo_id":"Qwen/Qwen3-1.7B","description":"Dense causal language model with 1.7 B total parameters (1.4 B non-embedding), 28 layers, 16 query heads & 8 KV heads, 32 768-token context, stronger reasoning vs. 0.6 B variant, dual-mode inference, instruction following across 100+ languages."},
"Qwen3-4B": {"repo_id":"Qwen/Qwen3-4B","description":"Dense causal language model with 4.0 B total parameters (3.6 B non-embedding), 36 layers, 32 query heads & 8 KV heads, native 32 768-token context (extendable to 131 072 via YaRN), balanced mid-range capacity & long-context reasoning."},
"Qwen3-8B": {"repo_id":"Qwen/Qwen3-8B","description":"Dense causal language model with 8.2 B total parameters (6.95 B non-embedding), 36 layers, 32 query heads & 8 KV heads, 32 768-token context (131 072 via YaRN), excels at multilingual instruction following & zero-shot tasks."},
"Qwen3-14B": {"repo_id":"Qwen/Qwen3-14B","description":"Dense causal language model with 14.8 B total parameters (13.2 B non-embedding), 40 layers, 40 query heads & 8 KV heads, 32 768-token context (131 072 via YaRN), enhanced human preference alignment & advanced agent integration."},
# "Qwen3-32B": {"repo_id":"Qwen/Qwen3-32B","description":"Dense causal language model with 32.8 B total parameters (31.2 B non-embedding), 64 layers, 64 query heads & 8 KV heads, 32 768-token context (131 072 via YaRN), flagship variant delivering state-of-the-art reasoning & instruction following."},
# "Qwen3-30B-A3B": {"repo_id":"Qwen/Qwen3-30B-A3B","description":"Mixture-of-Experts model with 30.5 B total parameters (29.9 B non-embedding, 3.3 B activated per token), 48 layers, 128 experts (8 activated per token), 32 query heads & 4 KV heads, 32 768-token context (131 072 via YaRN), MoE routing for scalable specialized reasoning."},
# "Qwen3-235B-A22B":{"repo_id":"Qwen/Qwen3-235B-A22B","description":"Mixture-of-Experts model with 235 B total parameters (234 B non-embedding, 22 B activated per token), 94 layers, 128 experts (8 activated per token), 64 query heads & 4 KV heads, 32 768-token context (131 072 via YaRN), ultra-scale reasoning & agentic workflows."},
"Gemma-3-4B-IT": {"repo_id": "unsloth/gemma-3-4b-it", "description": "Gemma-3-4B-IT"},
"SmolLM2_135M_Grpo_Gsm8k":{"repo_id":"prithivMLmods/SmolLM2_135M_Grpo_Gsm8k", "desscription":"SmolLM2_135M_Grpo_Gsm8k"},
"SmolLM2-135M-Instruct-TaiwanChat": {"repo_id": "Luigi/SmolLM2-135M-Instruct-TaiwanChat", "description": "SmolLM2‑135M Instruct fine-tuned on TaiwanChat"},
"SmolLM2-135M-Instruct": {"repo_id": "HuggingFaceTB/SmolLM2-135M-Instruct", "description": "Original SmolLM2‑135M Instruct"},
"SmolLM2-360M-Instruct-TaiwanChat": {"repo_id": "Luigi/SmolLM2-360M-Instruct-TaiwanChat", "description": "SmolLM2‑360M Instruct fine-tuned on TaiwanChat"},
"SmolLM2-360M-Instruct": {"repo_id": "HuggingFaceTB/SmolLM2-360M-Instruct", "description": "Original SmolLM2‑360M Instruct"},
"Llama-3.2-Taiwan-3B-Instruct": {"repo_id": "lianghsun/Llama-3.2-Taiwan-3B-Instruct", "description": "Llama-3.2-Taiwan-3B-Instruct"},
"MiniCPM3-4B": {"repo_id": "openbmb/MiniCPM3-4B", "description": "MiniCPM3-4B"},
"Qwen2.5-3B-Instruct": {"repo_id": "Qwen/Qwen2.5-3B-Instruct", "description": "Qwen2.5-3B-Instruct"},
"Qwen2.5-7B-Instruct": {"repo_id": "Qwen/Qwen2.5-7B-Instruct", "description": "Qwen2.5-7B-Instruct"},
"Phi-4-mini-Reasoning": {"repo_id": "microsoft/Phi-4-mini-reasoning", "description": "Phi-4-mini-Reasoning"},
# "Phi-4-Reasoning": {"repo_id": "microsoft/Phi-4-reasoning", "description": "Phi-4-Reasoning"},
"Phi-4-mini-Instruct": {"repo_id": "microsoft/Phi-4-mini-instruct", "description": "Phi-4-mini-Instruct"},
"Meta-Llama-3.1-8B-Instruct": {"repo_id": "MaziyarPanahi/Meta-Llama-3.1-8B-Instruct", "description": "Meta-Llama-3.1-8B-Instruct"},
"DeepSeek-R1-Distill-Llama-8B": {"repo_id": "unsloth/DeepSeek-R1-Distill-Llama-8B", "description": "DeepSeek-R1-Distill-Llama-8B"},
"Mistral-7B-Instruct-v0.3": {"repo_id": "MaziyarPanahi/Mistral-7B-Instruct-v0.3", "description": "Mistral-7B-Instruct-v0.3"},
"Qwen2.5-Coder-7B-Instruct": {"repo_id": "Qwen/Qwen2.5-Coder-7B-Instruct", "description": "Qwen2.5-Coder-7B-Instruct"},
"Qwen2.5-Omni-3B": {"repo_id": "Qwen/Qwen2.5-Omni-3B", "description": "Qwen2.5-Omni-3B"},
"MiMo-7B-RL": {"repo_id": "XiaomiMiMo/MiMo-7B-RL", "description": "MiMo-7B-RL"},
}
# Global cache for pipelines to avoid re-loading.
PIPELINES = {}
def load_pipeline(model_name):
    """
    Load and cache a transformers pipeline for text generation.
    Tries bfloat16, falls back to float16 or float32 if unsupported.
    """
    global PIPELINES
    if model_name in PIPELINES:
        return PIPELINES[model_name]
    repo = MODELS[model_name]["repo_id"]
    tokenizer = AutoTokenizer.from_pretrained(repo)
    for dtype in (torch.bfloat16, torch.float16, torch.float32):
        try:
            pipe = pipeline(
                task="text-generation",
                model=repo,
                tokenizer=tokenizer,
                trust_remote_code=True,
                torch_dtype=dtype,
                device_map="auto"
            )
            PIPELINES[model_name] = pipe
            return pipe
        except Exception:
            continue
    # Final fallback
    pipe = pipeline(
        task="text-generation",
        model=repo,
        tokenizer=tokenizer,
        trust_remote_code=True,
        device_map="auto"
    )
    PIPELINES[model_name] = pipe
    return pipe
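# Hypothetical usage (not executed here): pre-warm the first model in MODELS so
# the initial request does not pay the full load cost. Left commented out to
# keep ZeroGPU start-up light.
# load_pipeline(list(MODELS.keys())[0])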
def retrieve_context(query, max_results=6, max_chars=600):
    """
    Retrieve search snippets from DuckDuckGo (runs in background).
    Returns a list of result strings.
    """
    try:
        with DDGS() as ddgs:
            return [f"{i+1}. {r.get('title','No Title')} - {r.get('body','')[:max_chars]}"
                    for i, r in enumerate(islice(ddgs.text(query, region="wt-wt", safesearch="off", timelimit="y"), max_results))]
    except Exception:
        return []
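# Hypothetical example of the returned strings (actual results depend on DDGS):
#   retrieve_context("gradio streaming chatbot", max_results=2, max_chars=80)
#   -> ["1. Gradio Docs - Build a streaming chatbot ...",
#       "2. ..."]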
def format_conversation(history, system_prompt, tokenizer):
    if hasattr(tokenizer, "chat_template") and tokenizer.chat_template:
        messages = [{"role": "system", "content": system_prompt.strip()}] + history
        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, enable_thinking=True)
    else:
        # Fallback for base LMs without chat template
        prompt = system_prompt.strip() + "\n"
        for msg in history:
            if msg['role'] == 'user':
                prompt += "User: " + msg['content'].strip() + "\n"
            elif msg['role'] == 'assistant':
                prompt += "Assistant: " + msg['content'].strip() + "\n"
        if not prompt.strip().endswith("Assistant:"):
            prompt += "Assistant: "
        return prompt
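# With the fallback path, a two-turn history renders roughly as:
#   <system prompt>
#   User: <first message>
#   Assistant: <first reply>
#   User: <second message>
#   Assistant:
# which leaves the model positioned to continue the assistant turn.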
@spaces.GPU(duration=60)
def chat_response(user_msg, chat_history, system_prompt,
                  enable_search, max_results, max_chars,
                  model_name, max_tokens, temperature,
                  top_k, top_p, repeat_penalty, search_timeout):
    """
    Generates streaming chat responses, optionally with background web search.
    """
    cancel_event.clear()
    history = list(chat_history or [])
    history.append({'role': 'user', 'content': user_msg})
    # Launch web search if enabled
    debug = ''
    search_results = []
    if enable_search:
        debug = 'Search task started.'
        thread_search = threading.Thread(
            target=lambda: search_results.extend(
                retrieve_context(user_msg, int(max_results), int(max_chars))
            )
        )
        thread_search.daemon = True
        thread_search.start()
    else:
        debug = 'Web search disabled.'
    try:
        # wait up to `search_timeout` seconds for snippets, then surface them in the debug pane
        if enable_search:
            thread_search.join(timeout=float(search_timeout))
            if search_results:
                debug = "### Search results merged into prompt\n\n" + "\n".join(
                    f"- {r}" for r in search_results
                )
            else:
                debug = "*No web search results found.*"
        # merge any fetched snippets into the system prompt
        if search_results:
            enriched = system_prompt.strip() + "\n\nRelevant context:\n" + "\n".join(search_results)
        else:
            enriched = system_prompt
        pipe = load_pipeline(model_name)
        prompt = format_conversation(history, enriched, pipe.tokenizer)
        prompt_debug = f"\n\n--- Prompt Preview ---\n```\n{prompt}\n```"
        streamer = TextIteratorStreamer(pipe.tokenizer,
                                        skip_prompt=True,
                                        skip_special_tokens=True)
        gen_thread = threading.Thread(
            target=pipe,
            args=(prompt,),
            kwargs={
                'max_new_tokens': max_tokens,
                'temperature': temperature,
                'top_k': top_k,
                'top_p': top_p,
                'repetition_penalty': repeat_penalty,
                'streamer': streamer,
                'return_full_text': False,
            }
        )
        gen_thread.start()
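        # The pipeline call runs in gen_thread while this generator drains the
        # TextIteratorStreamer below, so partial output can be yielded to Gradio
        # as soon as tokens are decoded.
        # Note: temperature/top_k/top_p only take effect when sampling is enabled
        # (do_sample=True in the model's generation_config); many instruct models
        # enable it by default.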
        # Buffers for thought vs answer
        thought_buf = ''
        answer_buf = ''
        in_thought = False
        # Stream tokens
        for chunk in streamer:
            if cancel_event.is_set():
                break
            text = chunk
            # Detect start of thinking
            if not in_thought and '<think>' in text:
                in_thought = True
                # Insert thought placeholder
                history.append({
                    'role': 'assistant',
                    'content': '',
                    'metadata': {'title': '💭 Thought'}
                })
                # Capture after opening tag
                after = text.split('<think>', 1)[1]
                thought_buf += after
                # If closing tag in same chunk
                if '</think>' in thought_buf:
                    before, after2 = thought_buf.split('</think>', 1)
                    history[-1]['content'] = before.strip()
                    in_thought = False
                    # Start answer buffer; only append the answer message once there
                    # is text, otherwise the answer-streaming branch below creates it
                    answer_buf = after2
                    if answer_buf:
                        history.append({'role': 'assistant', 'content': answer_buf})
                else:
                    history[-1]['content'] = thought_buf
                yield history, debug
                continue
            # Continue thought streaming
            if in_thought:
                thought_buf += text
                if '</think>' in thought_buf:
                    before, after2 = thought_buf.split('</think>', 1)
                    history[-1]['content'] = before.strip()
                    in_thought = False
                    # Start answer buffer; only append the answer message once there is text
                    answer_buf = after2
                    if answer_buf:
                        history.append({'role': 'assistant', 'content': answer_buf})
                else:
                    history[-1]['content'] = thought_buf
                yield history, debug
                continue
            # Stream answer
            if not answer_buf:
                history.append({'role': 'assistant', 'content': ''})
            answer_buf += text
            history[-1]['content'] = answer_buf
            yield history, debug
        gen_thread.join()
        yield history, debug + prompt_debug
    except Exception as e:
        history.append({'role': 'assistant', 'content': f"Error: {e}"})
        yield history, debug
    finally:
        gc.collect()
def cancel_generation():
    cancel_event.set()
    return 'Generation cancelled.'
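# Cancellation is cooperative: it only stops the streaming loop from pushing
# further updates; the generate call inside gen_thread still runs to completion,
# since no StoppingCriteria is attached to cancel_event.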
def update_default_prompt(enable_search):
    today = datetime.now().strftime('%Y-%m-%d')
    return f"You are a helpful assistant. Today is {today}."
# ------------------------------
# Gradio UI
# ------------------------------
with gr.Blocks(title="LLM Inference with ZeroGPU") as demo:
    gr.Markdown("## 🧠 ZeroGPU LLM Inference with Web Search")
    gr.Markdown("Interact with the model. Select parameters and chat below.")
    with gr.Row():
        with gr.Column(scale=3):
            model_dd = gr.Dropdown(label="Select Model", choices=list(MODELS.keys()), value=list(MODELS.keys())[0])
            search_chk = gr.Checkbox(label="Enable Web Search", value=True)
            sys_prompt = gr.Textbox(label="System Prompt", lines=3, value=update_default_prompt(search_chk.value))
            gr.Markdown("### Generation Parameters")
            max_tok = gr.Slider(64, 16384, value=2048, step=32, label="Max Tokens")
            temp = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
            k = gr.Slider(1, 100, value=40, step=1, label="Top-K")
            p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P")
            rp = gr.Slider(1.0, 2.0, value=1.5, step=0.1, label="Repetition Penalty")
            gr.Markdown("### Web Search Settings")
            mr = gr.Number(value=6, precision=0, label="Max Results")
            mc = gr.Number(value=600, precision=0, label="Max Chars/Result")
            st = gr.Slider(minimum=0.0, maximum=30.0, step=0.5, value=5.0, label="Search Timeout (s)")
            clr = gr.Button("Clear Chat")
            cnl = gr.Button("Cancel Generation")
        with gr.Column(scale=7):
            chat = gr.Chatbot(type="messages")
            txt = gr.Textbox(placeholder="Type your message and press Enter...")
            dbg = gr.Markdown()
    search_chk.change(fn=update_default_prompt, inputs=search_chk, outputs=sys_prompt)
    clr.click(fn=lambda: ([], "", ""), outputs=[chat, txt, dbg])
    cnl.click(fn=cancel_generation, outputs=dbg)
    txt.submit(fn=chat_response,
               inputs=[txt, chat, sys_prompt, search_chk, mr, mc,
                       model_dd, max_tok, temp, k, p, rp, st],
               outputs=[chat, dbg])
demo.launch()