Spaces: Running on Zero
Upload app.py
app.py ADDED
@@ -0,0 +1,177 @@
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import os

# --- Configuration ---
# IMPORTANT: Replace with the path to your locally downloaded model or a Hugging Face model ID.
# Examples:
# LOCAL_MODEL_PATH = "/path/to/your/downloaded/qwen-1.5b-instruct"
# HUGGINGFACE_MODEL_ID = "Qwen/Qwen1.5-1.8B-Chat"  # A smaller Qwen chat model for local testing
HUGGINGFACE_MODEL_ID = "."  # "." loads the model files from the current directory (this Space's repository root)

# You might need to adjust TORCH_DTYPE based on your GPU and model support.
# torch.float16 (FP16) is common for inference; torch.bfloat16 suits newer GPUs.
TORCH_DTYPE = torch.float16  # or torch.bfloat16 or torch.float32

# Generation parameters (can be adjusted for different response styles)
MAX_NEW_TOKENS = 512
DO_SAMPLE = True
TEMPERATURE = 0.7
TOP_K = 50
TOP_P = 0.95

# --- Global variables for the model and tokenizer ---
tokenizer = None
model = None

# --- Load Model and Tokenizer ---
def load_model_and_tokenizer():
    """
    Loads the language model and tokenizer from the Hugging Face Hub or a local path.
    This function is called once when the Gradio app starts up.
    """
    global tokenizer, model

    if tokenizer is not None and model is not None:
        print("Model and tokenizer already loaded.")
        return

    print(f"Loading tokenizer from: {HUGGINGFACE_MODEL_ID}")
    try:
        tokenizer = AutoTokenizer.from_pretrained(HUGGINGFACE_MODEL_ID)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            print(f"Set tokenizer.pad_token to tokenizer.eos_token ({tokenizer.pad_token_id})")

        print(f"Loading model from: {HUGGINGFACE_MODEL_ID}...")
        model = AutoModelForCausalLM.from_pretrained(
            HUGGINGFACE_MODEL_ID,
            torch_dtype=TORCH_DTYPE,
            device_map="auto"  # Maps the model to GPU if available, else CPU (requires the accelerate package)
        )
        model.eval()  # Set model to evaluation mode
        print("Model loaded successfully.")
    except Exception as e:
        print(f"Error loading model or tokenizer: {e}")
        print("Please ensure the model ID is correct and you have an internet connection for the initial download, or that the local path is valid.")
        tokenizer = None
        model = None
        raise RuntimeError("Failed to load model. Check your model ID/path and internet connection.")


# --- Generate Response ---
def generate_response(
    message: str,  # Current user message
    history: list  # Gradio Chatbot history (list of dicts with 'role' and 'content')
) -> list:  # Returns the updated history for the Chatbot
    """
    Generates a text response from the loaded model based on the user input and chat history.
    """
    global tokenizer, model

    # Initialize the model if not already loaded
    if tokenizer is None or model is None:
        load_model_and_tokenizer()

    if tokenizer is None or model is None:  # Check again in case loading failed
        # With type='messages' history, append role/content dictionaries
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": "Error: Chatbot model not loaded. Please check logs."})
        return history

    # Format messages for the model's chat template (e.g., for Instruct models).
    # With type='messages', 'history' already contains role/content dictionaries.
    messages = history  # Use history directly as it is already in the correct format
    messages.append({"role": "user", "content": message})  # Add the current user message

    # Apply the chat template and tokenize
    try:
        input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    except Exception as e:
        print(f"Error applying chat template: {e}")
        # Fallback for models without an explicit chat template:
        # rebuild the prompt from the history, excluding the current message
        # (it was already appended above), then add it once at the end.
        input_text = ""
        for item in history[:-1]:
            if item["role"] == "user":
                input_text += f"User: {item['content']}\n"
            elif item["role"] == "assistant":
                input_text += f"Assistant: {item['content']}\n"
        input_text += f"User: {message}\nAssistant:"

    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(model.device)

    # Generate response
    with torch.no_grad():  # Disable gradient calculations for inference
        output_ids = model.generate(
            input_ids,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=DO_SAMPLE,
            temperature=TEMPERATURE,
            top_k=TOP_K,
            top_p=TOP_P,
            pad_token_id=tokenizer.eos_token_id  # Important for generation to stop cleanly
        )

    # Decode the generated text, excluding the input prompt part
    generated_token_ids = output_ids[0][input_ids.shape[-1]:]
    generated_text = tokenizer.decode(generated_token_ids, skip_special_tokens=True).strip()

    # --- Update Chat History ---
    # Append the latest generated response to the history with its role
    history.append({"role": "assistant", "content": generated_text})

    return history

# --- Gradio Interface ---
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Local Chatbot Powered by Hugging Face Transformers
        Type your message below and chat with the model loaded locally on your machine!
        """
    )

    # Set type='messages' so the chatbot uses OpenAI-style dictionaries
    chatbot = gr.Chatbot(label="Conversation", type='messages')
    with gr.Row():
        text_input = gr.Textbox(
            label="Your message",
            placeholder="Type your message here...",
            scale=4
        )
        submit_button = gr.Button("Send", scale=1)

    # Link the text input and button to the generation function.
    # 'inputs' are the current message and the full history (as 'messages' type);
    # 'outputs' is the updated full history.
    submit_button.click(
        fn=generate_response,
        inputs=[text_input, chatbot],  # text_input is the new message, chatbot is the history
        outputs=[chatbot],
        queue=True  # Queue requests for better concurrency
    )
    text_input.submit(  # Also trigger on Enter key
        fn=generate_response,
        inputs=[text_input, chatbot],
        outputs=[chatbot],
        queue=True
    )

    # Clear button
    def clear_chat():
        # With type='messages', clearing returns an empty list for the history
        # and an empty string for the text input.
        return [], ""
    clear_button = gr.Button("Clear Chat")
    clear_button.click(clear_chat, inputs=None, outputs=[chatbot, text_input])


# Load the model when the app starts so it is ready for the first request.
load_model_and_tokenizer()

# Launch the Gradio app
# demo.queue().launch()  # For local development
demo.queue().launch(server_name="0.0.0.0")
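For context on what tokenizer.apply_chat_template produces in the code above, the following short sketch (not part of app.py) prints the formatted prompt for a small conversation. It assumes the Qwen/Qwen1.5-1.8B-Chat tokenizer mentioned in the configuration comments; any chat model with a template works, each wrapping the turns in its own special tokens.

# Minimal sketch: inspect the prompt that apply_chat_template builds
# (assumes Qwen/Qwen1.5-1.8B-Chat; substitute the model you actually use).
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-1.8B-Chat")
msgs = [
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi! How can I help?"},
    {"role": "user", "content": "Tell me a joke."},
]
# tokenize=False returns the formatted prompt string; add_generation_prompt=True
# leaves an open assistant turn for the model to complete.
print(tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True))

The fallback branch in generate_response approximates the same structure with plain "User:"/"Assistant:" prefixes when a model ships without a chat template.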