ajsbsd's picture
Upload app.py
5995f30 verified
raw
history blame
7.07 kB
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
# --- Configuration ---
# Model location handed to `from_pretrained` for both the tokenizer and the model.
# It may be a Hugging Face Hub model ID or a path to a locally downloaded model
# directory. Examples:
#   HUGGINGFACE_MODEL_ID = "/path/to/your/downloaded/qwen-1.5b-instruct"  # local directory
#   HUGGINGFACE_MODEL_ID = "Qwen/Qwen1.5-1.8B-Chat"  # smaller Qwen model for local testing
# NOTE(review): "." presumably resolves to the current working directory, i.e. the
# model files are expected to sit next to this script -- confirm for your deployment.
HUGGINGFACE_MODEL_ID = "."
# Dtype the model weights are loaded in (passed as `torch_dtype`).
# torch.float16 (FP16) is common for inference, torch.bfloat16 for newer GPUs;
# adjust to whatever your hardware and model support.
TORCH_DTYPE = torch.float16 # or torch.bfloat16 or torch.float32
# Generation parameters forwarded to `model.generate` on every request
# (tune these for different response styles).
MAX_NEW_TOKENS = 512
DO_SAMPLE = True
TEMPERATURE = 0.7
TOP_K = 50
TOP_P = 0.95
# --- Global variables for models and tokenizers ---
# Both start as None and are populated by load_model_and_tokenizer();
# generate_response() reads them on every request.
tokenizer = None
model = None
# --- Load Models and Tokenizers Function ---
def load_model_and_tokenizer():
    """
    Load the tokenizer and causal language model named by HUGGINGFACE_MODEL_ID.

    The loaded objects are stored in the module-level ``tokenizer`` and
    ``model`` globals. If both are already populated the call is a no-op, so
    it is safe to call from every request handler.

    Raises:
        RuntimeError: if the tokenizer or the model cannot be loaded. The
            original exception is attached as ``__cause__`` and both globals
            are reset to None so a later call retries from a clean state.
    """
    global tokenizer, model
    if tokenizer is not None and model is not None:
        print("Model and tokenizer already loaded.")
        return

    print(f"Loading tokenizer from: {HUGGINGFACE_MODEL_ID}")
    try:
        tokenizer = AutoTokenizer.from_pretrained(HUGGINGFACE_MODEL_ID)
        # Some tokenizers ship without a pad token; reuse EOS so padding and
        # generation have a valid id to work with.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            print(f"Set tokenizer.pad_token to tokenizer.eos_token ({tokenizer.pad_token_id})")

        print(f"Loading model from: {HUGGINGFACE_MODEL_ID}...")
        model = AutoModelForCausalLM.from_pretrained(
            HUGGINGFACE_MODEL_ID,
            torch_dtype=TORCH_DTYPE,
            device_map="auto",  # Automatically maps model to GPU if available, else CPU
        )
        model.eval()  # Inference only: switch the model to evaluation mode
        print("Model loaded successfully.")
    except Exception as e:
        print(f"Error loading model or tokenizer: {e}")
        print("Please ensure the model ID is correct and you have an internet connection for initial download, or the local path is valid.")
        # Never leave a half-initialised pair behind (e.g. tokenizer loaded
        # but model failed).
        tokenizer = None
        model = None
        # Chain explicitly so the root cause shows up in the traceback.
        raise RuntimeError("Failed to load model. Check your model ID/path and internet connection.") from e
# --- Generate Response Function ---
def _build_fallback_prompt(messages: list) -> str:
    """Render messages as a plain 'User:/Assistant:' transcript ending in an open assistant turn."""
    role_labels = {"user": "User", "assistant": "Assistant"}
    turns = [
        f"{role_labels[item['role']]}: {item['content']}\n"
        for item in messages
        if item["role"] in role_labels  # other roles are skipped, as before
    ]
    return "".join(turns) + "Assistant:"


def generate_response(
    message: str,  # Current user message
    history: list  # Gradio Chatbot history ('messages' format: dicts with 'role' and 'content')
) -> list:  # Returns updated history for the Chatbot
    """
    Generate a reply from the loaded model for ``message`` given the chat ``history``.

    Args:
        message: Text the user just submitted.
        history: Previous turns as a list of ``{"role": ..., "content": ...}``
            dicts. ``None`` is treated as an empty conversation. The list
            passed in is not modified.

    Returns:
        A new list holding the previous turns, the user's message, and the
        assistant's reply. If the model cannot be loaded, the reply is an
        error message instead.
    """
    # Copy so the caller's list is untouched and the current message appears
    # exactly once in the conversation.
    messages = list(history or [])
    messages.append({"role": "user", "content": message})

    # Load lazily if startup loading did not happen or failed.
    if tokenizer is None or model is None:
        try:
            load_model_and_tokenizer()
        except RuntimeError as e:
            # Report the failure in the chat below instead of crashing the handler.
            print(f"Model loading failed while handling a request: {e}")
    if tokenizer is None or model is None:
        messages.append({"role": "assistant", "content": "Error: Chatbot model not loaded. Please check logs."})
        return messages

    # Build the prompt text, preferring the model's own chat template.
    try:
        input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    except Exception as e:
        print(f"Error applying chat template: {e}")
        # Fallback for tokenizers without a usable chat template.
        input_text = _build_fallback_prompt(messages)
        add_special_tokens = True
    else:
        # Templated text already carries its special tokens; adding them
        # again during tokenization could duplicate BOS/EOS markers.
        add_special_tokens = False

    encoded = tokenizer(
        input_text,
        return_tensors="pt",
        add_special_tokens=add_special_tokens,
    ).to(model.device)
    input_ids = encoded["input_ids"]

    # Generate response
    with torch.no_grad():  # Disable gradient calculations for inference
        output_ids = model.generate(
            input_ids=input_ids,
            # Pass the mask explicitly: it cannot be inferred reliably when
            # the pad token is the same as the EOS token.
            attention_mask=encoded.get("attention_mask"),
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=DO_SAMPLE,
            temperature=TEMPERATURE,
            top_k=TOP_K,
            top_p=TOP_P,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Keep only the newly generated tokens (drop the echoed prompt).
    generated_token_ids = output_ids[0][input_ids.shape[-1]:]
    generated_text = tokenizer.decode(generated_token_ids, skip_special_tokens=True).strip()

    messages.append({"role": "assistant", "content": generated_text})
    return messages
# --- Gradio Interface ---
with gr.Blocks() as demo:
    # Page header (same text as before, spelled out as an explicit literal).
    gr.Markdown(
        "\n"
        "# Local Chatbot Powered by Hugging Face Transformers\n"
        "Type your message below and chat with the model loaded locally on your machine!\n"
    )

    # The conversation pane uses OpenAI-style role/content dictionaries.
    chatbot = gr.Chatbot(label="Conversation", type='messages')

    with gr.Row():
        text_input = gr.Textbox(
            label="Your message",
            placeholder="Type your message here...",
            scale=4,
        )
        submit_button = gr.Button("Send", scale=1)

    # Both the Send button and pressing Enter in the textbox run the same
    # handler: it receives the new message plus the full history and returns
    # the updated history for the chatbot.
    for send_trigger in (submit_button.click, text_input.submit):
        send_trigger(
            fn=generate_response,
            inputs=[text_input, chatbot],
            outputs=[chatbot],
            queue=True,  # Queue requests for better concurrency
        )

    def clear_chat():
        # Empty history for the chatbot, empty string for the textbox.
        return [], ""

    clear_button = gr.Button("Clear Chat")
    clear_button.click(clear_chat, inputs=None, outputs=[chatbot, text_input])

# Load the model up front so it is ready before the first request arrives.
load_model_and_tokenizer()

# Start the app. server_name="0.0.0.0" is passed explicitly; for purely local
# development a plain demo.queue().launch() can be used instead.
demo.queue().launch(server_name="0.0.0.0")