import os
import time
import gc
import threading
from itertools import islice
from datetime import datetime
import gradio as gr
import torch
from transformers import pipeline
from duckduckgo_search import DDGS
import spaces # Import spaces early to enable ZeroGPU support
# Optional: Disable GPU visibility if you wish to force CPU usage
# os.environ["CUDA_VISIBLE_DEVICES"] = ""
# ------------------------------
# Global Cancellation Event
# ------------------------------
cancel_event = threading.Event()
# ------------------------------
# Torch-Compatible Model Definitions with Adjusted Descriptions
# ------------------------------
MODELS = {
"Taiwan-tinyllama-v1.0-chat": {
"repo_id": "DavidLanz/Taiwan-tinyllama-v1.0-chat",
"description": "Taiwan-tinyllama-v1.0-chat"
},
"Llama-3.2-Taiwan-3B-Instruct": {
"repo_id": "https://huggingface.co/lianghsun/Llama-3.2-Taiwan-3B-Instruct",
"description": "Llama-3.2-Taiwan-3B-Instruct"
},
"MiniCPM3-4B": {
"repo_id": "openbmb/MiniCPM3-4B",
"description": "MiniCPM3-4B"
},
"Qwen2.5-3B-Instruct": {
"repo_id": "Qwen/Qwen2.5-3B-Instruct",
"description": "Qwen2.5-3B-Instruct"
},
"Qwen2.5-7B-Instruct": {
"repo_id": "Qwen/Qwen2.5-7B-Instruct",
"description": "Qwen2.5-7B-Instruct"
},
"Gemma-3-4B-IT": {
"repo_id": "unsloth/gemma-3-4b-it",
"description": "Gemma-3-4B-IT"
},
"Phi-4-mini-Instruct": {
"repo_id": "unsloth/Phi-4-mini-instruct",
"description": "Phi-4-mini-Instruct"
},
"Meta-Llama-3.1-8B-Instruct": {
"repo_id": "MaziyarPanahi/Meta-Llama-3.1-8B-Instruct",
"description": "Meta-Llama-3.1-8B-Instruct"
},
"DeepSeek-R1-Distill-Llama-8B": {
"repo_id": "unsloth/DeepSeek-R1-Distill-Llama-8B",
"description": "DeepSeek-R1-Distill-Llama-8B"
},
"Mistral-7B-Instruct-v0.3": {
"repo_id": "MaziyarPanahi/Mistral-7B-Instruct-v0.3",
"description": "Mistral-7B-Instruct-v0.3"
},
"Qwen2.5-Coder-7B-Instruct": {
"repo_id": "Qwen/Qwen2.5-Coder-7B-Instruct",
"description": "Qwen2.5-Coder-7B-Instruct"
},
}
# Global cache for pipelines to avoid re-loading.
PIPELINES = {}
def load_pipeline(model_name):
"""
Load and cache a transformers pipeline for chat/text-generation.
Uses the model's repo_id from MODELS and caches the pipeline for future use.
"""
global PIPELINES
if model_name in PIPELINES:
return PIPELINES[model_name]
selected_model = MODELS[model_name]
# Create a chat-style text-generation pipeline.
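    # Note: device_map="auto" relies on the `accelerate` package being installed, and
    # torch.bfloat16 assumes the target GPU supports it; float16/float32 may be needed
    # on hardware that does not.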
pipe = pipeline(
task="text-generation",
model=selected_model["repo_id"],
tokenizer=selected_model["repo_id"],
trust_remote_code=True,
torch_dtype=torch.bfloat16,
device_map="auto"
)
PIPELINES[model_name] = pipe
return pipe
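# Minimal usage sketch for the cached loader (illustrative only; works on recent
# transformers versions, where chat-style calls return the updated message list):
#   pipe = load_pipeline("Qwen2.5-3B-Instruct")
#   out = pipe([{"role": "user", "content": "Hello"}], max_new_tokens=32)
#   reply = out[0]["generated_text"][-1]["content"]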
def retrieve_context(query, max_results=6, max_chars_per_result=600):
"""
Retrieve recent web search context for the given query using DuckDuckGo.
Returns a formatted string with search results.
"""
try:
with DDGS() as ddgs:
results = list(islice(ddgs.text(query, region="wt-wt", safesearch="off", timelimit="y"), max_results))
context = ""
for i, result in enumerate(results, start=1):
title = result.get("title", "No Title")
snippet = result.get("body", "")[:max_chars_per_result]
context += f"Result {i}:\nTitle: {title}\nSnippet: {snippet}\n\n"
return context.strip()
except Exception:
return ""
# ------------------------------
# Chat Response Generation with ZeroGPU using Pipeline
# ------------------------------
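# @spaces.GPU requests a ZeroGPU device for each decorated call; duration=60 asks
# for an allocation of up to roughly 60 seconds per invocation.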
@spaces.GPU(duration=60)
def chat_response(user_message, chat_history, system_prompt, enable_search,
max_results, max_chars, model_name, max_tokens, temperature, top_k, top_p, repeat_penalty):
"""
Generate a chat response by utilizing a transformers pipeline.
- Appends the user's message to the conversation history.
- Optionally retrieves web search context and inserts it as an additional system message.
- Uses a cached pipeline (loaded via load_pipeline) to generate a response.
- Returns the updated conversation history and a debug message.
"""
cancel_event.clear()
    # Build the conversation list: system prompt first, then prior history, then the new user message.
    conversation = list(chat_history) if chat_history else []
    # Ensure the system prompt leads the conversation (it may already be present
    # if a previous turn's history is passed back through chat_history).
    if not conversation or conversation[0].get("role") != "system":
        conversation.insert(0, {"role": "system", "content": system_prompt})
    conversation.append({"role": "user", "content": user_message})
# Retrieve web search context if enabled.
debug_message = ""
retrieved_context = ""
if enable_search:
debug_message = "Initiating web search..."
yield conversation, debug_message
search_result = [""]
def do_search():
search_result[0] = retrieve_context(user_message, max_results, max_chars)
search_thread = threading.Thread(target=do_search)
search_thread.start()
search_thread.join(timeout=2)
retrieved_context = search_result[0]
if retrieved_context:
debug_message = f"Web search results:\n\n{retrieved_context}"
# Insert the search context as a system-level message immediately after the original system prompt.
conversation.insert(1, {"role": "system", "content": f"Web search context:\n{retrieved_context}"})
else:
debug_message = "Web search returned no results or timed out."
else:
debug_message = "Web search disabled."
# Append a placeholder for the assistant's response.
conversation.append({"role": "assistant", "content": ""})
try:
# Load the pipeline (cached) for the selected model.
pipe = load_pipeline(model_name)
# Use the pipeline directly with conversation history.
# Note: Many chat pipelines use internal chat templating to properly format the conversation.
        response = pipe(
            conversation,
            max_new_tokens=max_tokens,
            do_sample=True,  # enable sampling so temperature/top-k/top-p take effect
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repeat_penalty,
        )
# Extract the assistant's reply.
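        # When chat messages are passed in, recent transformers versions return the updated
        # message list under "generated_text"; older versions (or plain string prompts)
        # return text directly, hence the fallback below.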
try:
assistant_text = response[0]["generated_text"][-1]["content"]
except (KeyError, IndexError, TypeError):
assistant_text = response[0]["generated_text"]
# Update the conversation history.
conversation[-1]["content"] = assistant_text
# Yield the complete conversation history and the debug message.
yield conversation, debug_message
except Exception as e:
conversation[-1]["content"] = f"Error: {e}"
yield conversation, debug_message
finally:
gc.collect()
# ------------------------------
# Cancel Function
# ------------------------------
def cancel_generation():
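    # Note: the pipeline call in chat_response is blocking and does not poll this event,
    # so setting it will not interrupt a generation that is already in progress.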
cancel_event.set()
return "Cancellation requested."
# ------------------------------
# Gradio UI Definition
# ------------------------------
with gr.Blocks(title="LLM Inference with ZeroGPU") as demo:
gr.Markdown("## 🧠 ZeroGPU LLM Inference with Web Search")
gr.Markdown("Interact with the model. Select your model, set your system prompt, and adjust parameters on the left.")
with gr.Row():
with gr.Column(scale=3):
default_model = list(MODELS.keys())[0] if MODELS else "No models available"
model_dropdown = gr.Dropdown(
label="Select Model",
choices=list(MODELS.keys()) if MODELS else [],
value=default_model,
info="Choose from available models."
)
today = datetime.now().strftime('%Y-%m-%d')
default_prompt = f"You are a helpful assistant. Today is {today}. Please leverage the latest web data when responding to queries."
system_prompt_text = gr.Textbox(label="System Prompt",
value=default_prompt,
lines=3,
info="Define the base context for the AI's responses.")
gr.Markdown("### Generation Parameters")
max_tokens_slider = gr.Slider(label="Max Tokens", minimum=64, maximum=1024, value=1024, step=32,
info="Maximum tokens for the response.")
temperature_slider = gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, value=0.7, step=0.1,
info="Controls the randomness of the output.")
top_k_slider = gr.Slider(label="Top-K", minimum=1, maximum=100, value=40, step=1,
info="Limits token candidates to the top-k tokens.")
top_p_slider = gr.Slider(label="Top-P (Nucleus Sampling)", minimum=0.1, maximum=1.0, value=0.95, step=0.05,
info="Limits token candidates to a cumulative probability threshold.")
repeat_penalty_slider = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.1, step=0.1,
info="Penalizes token repetition to improve diversity.")
gr.Markdown("### Web Search Settings")
enable_search_checkbox = gr.Checkbox(label="Enable Web Search", value=False,
info="Include recent search context to improve answers.")
max_results_number = gr.Number(label="Max Search Results", value=6, precision=0,
info="Maximum number of search results to retrieve.")
max_chars_number = gr.Number(label="Max Chars per Result", value=600, precision=0,
info="Maximum characters to retrieve per search result.")
clear_button = gr.Button("Clear Chat")
cancel_button = gr.Button("Cancel Generation")
with gr.Column(scale=7):
chatbot = gr.Chatbot(label="Chat", type="messages")
msg_input = gr.Textbox(label="Your Message", placeholder="Enter your message and press Enter")
search_debug = gr.Markdown(label="Web Search Debug")
def clear_chat():
return [], "", ""
clear_button.click(fn=clear_chat, outputs=[chatbot, msg_input, search_debug])
cancel_button.click(fn=cancel_generation, outputs=search_debug)
# Submission: the chat_response function is now used with the Transformers pipeline.
msg_input.submit(
fn=chat_response,
inputs=[msg_input, chatbot, system_prompt_text, enable_search_checkbox,
max_results_number, max_chars_number, model_dropdown,
max_tokens_slider, temperature_slider, top_k_slider, top_p_slider, repeat_penalty_slider],
outputs=[chatbot, search_debug],
)
demo.launch()