# Hugging Face Space: a Gradio chatbot running on ZeroGPU ("Running on Zero").
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# --- Configuration ---
# IMPORTANT: Replace with the path to your locally downloaded model or a Hugging Face model ID.
# Examples:
# LOCAL_MODEL_PATH = "/path/to/your/downloaded/qwen-1.5b-instruct"
# HUGGINGFACE_MODEL_ID = "Qwen/Qwen1.5-1.8B-Chat"  # a smaller Qwen model for local testing
HUGGINGFACE_MODEL_ID = "HuggingFaceH4/Qwen2.5-1.5B-Instruct-gkd"

# You may need to adjust TORCH_DTYPE based on your GPU and model support:
# torch.float16 (FP16) is common for inference; torch.bfloat16 suits newer GPUs.
TORCH_DTYPE = torch.float16  # or torch.bfloat16 or torch.float32

# Generation parameters (adjust for different response styles).
MAX_NEW_TOKENS = 512
DO_SAMPLE = True
TEMPERATURE = 0.7
TOP_K = 50
TOP_P = 0.95
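
# Note: with DO_SAMPLE = False, generation is greedy and TEMPERATURE, TOP_K,
# and TOP_P are ignored by transformers; sampling must stay enabled for these
# settings to take effect.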

# --- Global variables for the model and tokenizer ---
tokenizer = None
model = None

# --- Load Model and Tokenizer ---
def load_model_and_tokenizer():
    """
    Loads the language model and tokenizer from the Hugging Face Hub or a local path.
    This function is called once when the Gradio app starts up.
    """
    global tokenizer, model
    if tokenizer is not None and model is not None:
        print("Model and tokenizer already loaded.")
        return
    print(f"Loading tokenizer from: {HUGGINGFACE_MODEL_ID}")
    try:
        tokenizer = AutoTokenizer.from_pretrained(HUGGINGFACE_MODEL_ID)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            print(f"Set tokenizer.pad_token to tokenizer.eos_token ({tokenizer.pad_token_id})")
        print(f"Loading model from: {HUGGINGFACE_MODEL_ID}...")
        model = AutoModelForCausalLM.from_pretrained(
            HUGGINGFACE_MODEL_ID,
            torch_dtype=TORCH_DTYPE,
            device_map="auto"  # Automatically maps the model to GPU if available, else CPU.
        )
        model.eval()  # Set the model to evaluation mode.
        print("Model loaded successfully.")
    except Exception as e:
        print(f"Error loading model or tokenizer: {e}")
        print("Please ensure the model ID is correct and you have an internet connection for the initial download, or that the local path is valid.")
        tokenizer = None
        model = None
        raise RuntimeError("Failed to load model. Check your model ID/path and internet connection.")
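
# Optional alternative (a sketch, not used by this Space): on memory-constrained
# GPUs the model could be loaded in 4-bit instead of FP16. Assumes the
# `bitsandbytes` package is installed:
#
#   from transformers import BitsAndBytesConfig
#   model = AutoModelForCausalLM.from_pretrained(
#       HUGGINGFACE_MODEL_ID,
#       quantization_config=BitsAndBytesConfig(load_in_4bit=True),
#       device_map="auto",
#   )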

# --- Generate Response ---
def generate_response(
    message: str,  # Current user message.
    history: list  # Chat history in Gradio 'messages' format (dicts with 'role' and 'content').
) -> tuple:  # Returns the updated history plus an empty string that clears the textbox.
    """
    Generates a text response from the loaded model based on the user input
    and the chat history.
    """
    global tokenizer, model
    # Load the model lazily if it is not ready yet.
    if tokenizer is None or model is None:
        load_model_and_tokenizer()
        if tokenizer is None or model is None:  # Check again in case loading failed.
            # With type='messages', history entries are role/content dictionaries.
            history.append({"role": "user", "content": message})
            history.append({"role": "assistant", "content": "Error: Chatbot model not loaded. Please check logs."})
            return history, ""
    # With type='messages', 'history' already matches the chat-template format,
    # so append the current user message and pass it to the model directly.
    messages = history
    messages.append({"role": "user", "content": message})

    # Apply the chat template and tokenize.
    try:
        input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    except Exception as e:
        print(f"Error applying chat template: {e}")
        # Fallback for models without a chat template: rebuild the prompt by hand.
        # Skip the last history entry, which is the current message appended above,
        # so it is not duplicated.
        input_text = ""
        for item in history[:-1]:
            if item["role"] == "user":
                input_text += f"User: {item['content']}\n"
            elif item["role"] == "assistant":
                input_text += f"Assistant: {item['content']}\n"
        input_text += f"User: {message}\nAssistant:"
    # Tokenize, keeping the attention mask; passing it to generate() avoids a
    # transformers warning and matters because pad and EOS tokens coincide here.
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

    # Generate a response.
    with torch.no_grad():  # Disable gradient calculations for inference.
        output_ids = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=DO_SAMPLE,
            temperature=TEMPERATURE,
            top_k=TOP_K,
            top_p=TOP_P,
            pad_token_id=tokenizer.eos_token_id  # Important for generation to stop cleanly.
        )

    # Decode the generated tokens, excluding the input prompt.
    generated_token_ids = output_ids[0][inputs.input_ids.shape[-1]:]
    generated_text = tokenizer.decode(generated_token_ids, skip_special_tokens=True).strip()
    # --- Update Chat History ---
    # Append the generated response to the history with its role.
    history.append({"role": "assistant", "content": generated_text})
    return history, ""  # The empty string clears the input textbox.
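
# Optional alternative (a sketch, not part of this Space): for token-by-token
# streaming, transformers' TextIteratorStreamer can replace the blocking
# generate() call. Generation runs in a background thread while the streamer
# yields decoded text chunks:
#
#   from threading import Thread
#   from transformers import TextIteratorStreamer
#
#   streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
#   Thread(target=model.generate, kwargs=dict(
#       input_ids=inputs.input_ids, attention_mask=inputs.attention_mask,
#       max_new_tokens=MAX_NEW_TOKENS, streamer=streamer)).start()
#   for new_text in streamer:
#       ...  # accumulate chunks and yield partial history updates to Gradio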

# --- Gradio Interface ---
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Local Chatbot Powered by Hugging Face Transformers
        Type your message below and chat with the model loaded locally on your machine!
        """
    )
    # type='messages' makes the chatbot use OpenAI-style role/content dictionaries.
    chatbot = gr.Chatbot(label="Conversation", type='messages')
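    # With type='messages' the history is a list of dictionaries, e.g.:
    #   [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]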
    with gr.Row():
        text_input = gr.Textbox(
            label="Your message",
            placeholder="Type your message here...",
            scale=4
        )
        submit_button = gr.Button("Send", scale=1)
    # Wire the text input and button to the generation function.
    # Inputs: the new message and the full history (in 'messages' format).
    # Outputs: the updated history, plus an empty string that clears the textbox.
    submit_button.click(
        fn=generate_response,
        inputs=[text_input, chatbot],  # text_input is the new message, chatbot is the history
        outputs=[chatbot, text_input],
        queue=True  # Queue requests for better concurrency.
    )
    text_input.submit(  # Also trigger on the Enter key.
        fn=generate_response,
        inputs=[text_input, chatbot],
        outputs=[chatbot, text_input],
        queue=True
    )
    # Clear button.
    def clear_chat():
        # With type='messages', clearing returns an empty list for the history
        # and an empty string for the text input.
        return [], ""

    clear_button = gr.Button("Clear Chat")
    clear_button.click(clear_chat, inputs=None, outputs=[chatbot, text_input])

# Load the model when the app starts so it is ready for the first request.
load_model_and_tokenizer()

# Launch the Gradio app.
# demo.queue().launch()  # For local development, use launch().
demo.queue().launch(server_name="0.0.0.0")
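
# To try the app (assuming this file is saved as app.py, the Spaces convention):
#   python app.py
# then open http://localhost:7860 (Gradio's default port) in a browser.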