import os
import time
import gc
import threading
from itertools import islice
from datetime import datetime

import gradio as gr
import torch
from transformers import pipeline, TextIteratorStreamer
from duckduckgo_search import DDGS
import spaces  # Import spaces early to enable ZeroGPU support

# Optional: Disable GPU visibility if you wish to force CPU usage
# os.environ["CUDA_VISIBLE_DEVICES"] = ""

# ------------------------------
# Global Cancellation Event
# ------------------------------
cancel_event = threading.Event()

# ------------------------------
# Torch-Compatible Model Definitions with Adjusted Descriptions
# ------------------------------
MODELS = {
    "Gemma-3-4B-IT": {"repo_id": "unsloth/gemma-3-4b-it", "description": "Gemma-3-4B-IT"},
    "SmolLM2-135M-Instruct-TaiwanChat": {"repo_id": "Luigi/SmolLM2-135M-Instruct-TaiwanChat", "description": "SmolLM2‑135M Instruct fine-tuned on TaiwanChat"},
    "SmolLM2-135M-Instruct": {"repo_id": "HuggingFaceTB/SmolLM2-135M-Instruct", "description": "Original SmolLM2‑135M Instruct"},
    "Llama-3.2-Taiwan-3B-Instruct": {"repo_id": "lianghsun/Llama-3.2-Taiwan-3B-Instruct", "description": "Llama-3.2-Taiwan-3B-Instruct"},
    "MiniCPM3-4B": {"repo_id": "openbmb/MiniCPM3-4B", "description": "MiniCPM3-4B"},
    "Qwen2.5-3B-Instruct": {"repo_id": "Qwen/Qwen2.5-3B-Instruct", "description": "Qwen2.5-3B-Instruct"},
    "Qwen2.5-7B-Instruct": {"repo_id": "Qwen/Qwen2.5-7B-Instruct", "description": "Qwen2.5-7B-Instruct"},
    "Phi-4-mini-Instruct": {"repo_id": "unsloth/Phi-4-mini-instruct", "description": "Phi-4-mini-Instruct"},
    "Meta-Llama-3.1-8B-Instruct": {"repo_id": "MaziyarPanahi/Meta-Llama-3.1-8B-Instruct", "description": "Meta-Llama-3.1-8B-Instruct"},
    "DeepSeek-R1-Distill-Llama-8B": {"repo_id": "unsloth/DeepSeek-R1-Distill-Llama-8B", "description": "DeepSeek-R1-Distill-Llama-8B"},
    "Mistral-7B-Instruct-v0.3": {"repo_id": "MaziyarPanahi/Mistral-7B-Instruct-v0.3", "description": "Mistral-7B-Instruct-v0.3"},
    "Qwen2.5-Coder-7B-Instruct": {"repo_id": "Qwen/Qwen2.5-Coder-7B-Instruct", "description": "Qwen2.5-Coder-7B-Instruct"},
}

# Global cache for pipelines to avoid re-loading.
PIPELINES = {}


def load_pipeline(model_name):
    """
    Load and cache a transformers pipeline for text generation.
    Tries bfloat16, falls back to float16 or float32 if unsupported.
    """
    global PIPELINES
    if model_name in PIPELINES:
        return PIPELINES[model_name]
    repo = MODELS[model_name]["repo_id"]
    for dtype in (torch.bfloat16, torch.float16, torch.float32):
        try:
            pipe = pipeline(
                task="text-generation",
                model=repo,
                tokenizer=repo,
                trust_remote_code=True,
                torch_dtype=dtype,
                device_map="auto",
            )
            PIPELINES[model_name] = pipe
            return pipe
        except Exception:
            continue
    # Final fallback: let transformers choose the default dtype.
    pipe = pipeline(
        task="text-generation",
        model=repo,
        tokenizer=repo,
        trust_remote_code=True,
        device_map="auto",
    )
    PIPELINES[model_name] = pipe
    return pipe


def retrieve_context(query, max_results=6, max_chars=600):
    """
    Retrieve search snippets from DuckDuckGo (runs in a background thread).
    Returns a list of numbered result strings, or an empty list on failure.
    """
    try:
        with DDGS() as ddgs:
            return [
                f"{i+1}. {r.get('title', 'No Title')} - {r.get('body', '')[:max_chars]}"
                for i, r in enumerate(
                    islice(ddgs.text(query, region="wt-wt", safesearch="off", timelimit="y"), max_results)
                )
            ]
    except Exception:
        return []
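
# Illustrative sketch of what retrieve_context() yields on success; the query
# and snippet text below are hypothetical, and an empty list is returned when
# the search fails or raises:
#
#   retrieve_context("taiwan weather today", max_results=2, max_chars=60)
#   # -> ["1. Taiwan Weather - Cloudy with occasional showers across the ...",
#   #     "2. CWA Forecast - Temperatures ranging between 22 and 28 degr..."]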
""" prompt = system_prompt.strip() + "\n" for msg in history: if msg['role'] == 'user': prompt += "User: " + msg['content'].strip() + "\n" elif msg['role'] == 'assistant': prompt += "Assistant: " + msg['content'].strip() + "\n" else: prompt += msg['content'].strip() + "\n" if not prompt.strip().endswith("Assistant:"): prompt += "Assistant: " return prompt @spaces.GPU(duration=60) def chat_response(user_msg, chat_history, system_prompt, enable_search, max_results, max_chars, model_name, max_tokens, temperature, top_k, top_p, repeat_penalty): """ Generates streaming chat responses, optionally with background web search. """ cancel_event.clear() history = list(chat_history or []) history.append({'role': 'user', 'content': user_msg}) # Launch web search if enabled debug = '' search_results = [] if enable_search: debug = 'Search task started.' thread_search = threading.Thread( target=lambda: search_results.extend( retrieve_context(user_msg, int(max_results), int(max_chars)) ) ) thread_search.daemon = True thread_search.start() else: debug = 'Web search disabled.' # Prepare assistant placeholder history.append({'role': 'assistant', 'content': ''}) try: # merge any fetched search results into the system prompt if search_results: enriched = system_prompt.strip() + "\n\nRelevant context:\n" + "\n".join(search_results) else: enriched = system_prompt # wait up to 1s for snippets, then replace debug with them if enable_search: thread_search.join(timeout=1.0) if search_results: debug = "### Search results merged into prompt\n\n" + "\n".join( f"- {r}" for r in search_results ) else: debug = "*No web search results found.*" # merge fetched snippets into the system prompt if search_results: enriched = system_prompt.strip() + "\n\nRelevant context:\n" + "\n".join(search_results) else: enriched = system_prompt prompt = format_conversation(history, enriched) pipe = load_pipeline(model_name) streamer = TextIteratorStreamer(pipe.tokenizer, skip_prompt=True, skip_special_tokens=True) gen_thread = threading.Thread( target=pipe, args=(prompt,), kwargs={ 'max_new_tokens': max_tokens, 'temperature': temperature, 'top_k': top_k, 'top_p': top_p, 'repetition_penalty': repeat_penalty, 'streamer': streamer, 'return_full_text': False } ) gen_thread.start() assistant_text = '' for chunk in streamer: if cancel_event.is_set(): break assistant_text += chunk history[-1]['content'] = assistant_text # Show debug only once yield history, debug gen_thread.join() except Exception as e: history[-1]['content'] = f"Error: {e}" yield history, debug finally: gc.collect() def cancel_generation(): cancel_event.set() return 'Generation cancelled.' def update_default_prompt(enable_search): today = datetime.now().strftime('%Y-%m-%d') return f"You are a helpful assistant. Today is {today}." # ------------------------------ # Gradio UI # ------------------------------ with gr.Blocks(title="LLM Inference with ZeroGPU") as demo: gr.Markdown("## 🧠 ZeroGPU LLM Inference with Web Search") gr.Markdown("Interact with the model. 
Select parameters and chat below.") with gr.Row(): with gr.Column(scale=3): model_dd = gr.Dropdown(label="Select Model", choices=list(MODELS.keys()), value=list(MODELS.keys())[0]) search_chk = gr.Checkbox(label="Enable Web Search", value=True) sys_prompt = gr.Textbox(label="System Prompt", lines=3, value=update_default_prompt(search_chk.value)) gr.Markdown("### Generation Parameters") max_tok = gr.Slider(64, 1024, value=512, step=32, label="Max Tokens") temp = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature") k = gr.Slider(1, 100, value=40, step=1, label="Top-K") p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P") rp = gr.Slider(1.0, 2.0, value=1.1, step=0.1, label="Repetition Penalty") gr.Markdown("### Web Search Settings") mr = gr.Number(value=6, precision=0, label="Max Results") mc = gr.Number(value=600, precision=0, label="Max Chars/Result") clr = gr.Button("Clear Chat") cnl = gr.Button("Cancel Generation") with gr.Column(scale=7): chat = gr.Chatbot(type="messages") txt = gr.Textbox(placeholder="Type your message and press Enter...") dbg = gr.Markdown() search_chk.change(fn=update_default_prompt, inputs=search_chk, outputs=sys_prompt) clr.click(fn=lambda: ([], "", ""), outputs=[chat, txt, dbg]) cnl.click(fn=cancel_generation, outputs=dbg) txt.submit(fn=chat_response, inputs=[txt, chat, sys_prompt, search_chk, mr, mc, model_dd, max_tok, temp, k, p, rp], outputs=[chat, dbg]) demo.launch()