import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
import threading
import queue
import spaces
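# The spaces package integrates with Hugging Face Spaces; the @spaces.GPU decorator below
# requests GPU hardware (ZeroGPU) for the duration of the decorated call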
# Model configuration
model_name = "HelpingAI/Dhanishtha-2.0-preview"
# Global variables for model and tokenizer
model = None
tokenizer = None
def load_model():
"""Load the model and tokenizer"""
global model, tokenizer
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
print("Loading model...")
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype="auto",
device_map="auto",
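        # trust_remote_code allows custom modeling code shipped with the model repo to be executed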
trust_remote_code=True
)
print("Model loaded successfully!")
class GradioTextStreamer(TextStreamer):
"""Custom TextStreamer for Gradio integration"""
    def __init__(self, tokenizer, skip_prompt=True, skip_special_tokens=True):
        # TextStreamer only accepts decode options as keyword arguments, so pass skip_special_tokens by name
        super().__init__(tokenizer, skip_prompt=skip_prompt, skip_special_tokens=skip_special_tokens)
self.text_queue = queue.Queue()
self.generated_text = ""
def on_finalized_text(self, text: str, stream_end: bool = False):
"""Called when text is finalized"""
self.generated_text += text
self.text_queue.put(text)
if stream_end:
self.text_queue.put(None)
def get_generated_text(self):
"""Get all generated text so far"""
return self.generated_text
def reset(self):
"""Reset the streamer"""
self.generated_text = ""
# Clear the queue
while not self.text_queue.empty():
try:
self.text_queue.get_nowait()
except queue.Empty:
break
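# Note: this queue-plus-background-thread pattern mirrors what transformers' built-in
# TextIteratorStreamer provides; the custom subclass just keeps the full generated text around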
@spaces.GPU()
def generate_response(message, history, max_tokens, temperature, top_p):
"""Generate streaming response"""
global model, tokenizer
if model is None or tokenizer is None:
yield "Model is still loading. Please wait..."
return
# Prepare conversation history
messages = []
for user_msg, assistant_msg in history:
messages.append({"role": "user", "content": user_msg})
if assistant_msg:
messages.append({"role": "assistant", "content": assistant_msg})
# Add current message
messages.append({"role": "user", "content": message})
# Apply chat template
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
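        # add_generation_prompt appends the assistant turn header so the model replies as the assistant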
add_generation_prompt=True
)
# Tokenize input
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
# Create and setup streamer
streamer = GradioTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
streamer.reset()
# Start generation in a separate thread
generation_kwargs = {
**model_inputs,
"max_new_tokens": max_tokens,
"temperature": temperature,
"top_p": top_p,
"do_sample": True,
"pad_token_id": tokenizer.eos_token_id,
"streamer": streamer,
"return_dict_in_generate": True
}
# Run generation in thread
def generate():
try:
with torch.no_grad():
model.generate(**generation_kwargs)
except Exception as e:
streamer.text_queue.put(f"Error: {str(e)}")
streamer.text_queue.put(None)
thread = threading.Thread(target=generate)
thread.start()
# Stream the results
generated_text = ""
while True:
try:
new_text = streamer.text_queue.get(timeout=30)
if new_text is None:
break
generated_text += new_text
yield generated_text
except queue.Empty:
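            # No new text within 30 seconds; assume generation has stalled or finished and stop streaming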
break
thread.join(timeout=1)
# Final yield with complete text
if generated_text:
yield generated_text
else:
yield "No response generated."
def chat_interface(message, history, max_tokens, temperature, top_p):
"""Main chat interface"""
    if not message.strip():
        # chat_interface is a generator; yield the unchanged state so Gradio still receives an update
        yield history, ""
        return
# Add user message to history
history.append([message, ""])
# Generate response
for partial_response in generate_response(message, history[:-1], max_tokens, temperature, top_p):
history[-1][1] = partial_response
yield history, ""
return history, ""
# Load model on startup
print("Initializing model...")
load_model()
# Create Gradio interface
with gr.Blocks(title="Dhanishtha-2.0-preview Chat", theme=gr.themes.Soft()) as demo:
gr.Markdown(
"""
# 🤖 Dhanishtha-2.0-preview Chat
Chat with the **HelpingAI/Dhanishtha-2.0-preview** model!
"""
)
with gr.Row():
with gr.Column(scale=4):
chatbot = gr.Chatbot(
[],
elem_id="chatbot",
bubble_full_width=False,
height=500,
show_copy_button=True
)
with gr.Row():
msg = gr.Textbox(
container=False,
placeholder="Type your message here...",
label="Message",
autofocus=True,
scale=7
)
send_btn = gr.Button("Send", variant="primary", scale=1)
with gr.Column(scale=1):
gr.Markdown("### ⚙️ Parameters")
max_tokens = gr.Slider(
minimum=1,
maximum=4096,
value=2048,
step=1,
label="Max Tokens",
info="Maximum number of tokens to generate"
)
temperature = gr.Slider(
minimum=0.1,
maximum=2.0,
value=0.7,
step=0.1,
label="Temperature",
info="Controls randomness in generation"
)
top_p = gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.9,
step=0.05,
label="Top-p",
info="Controls diversity of generation"
)
clear_btn = gr.Button("🗑️ Clear Chat", variant="secondary")
# Event handlers
msg.submit(
chat_interface,
inputs=[msg, chatbot, max_tokens, temperature, top_p],
outputs=[chatbot, msg],
concurrency_limit=1
)
send_btn.click(
chat_interface,
inputs=[msg, chatbot, max_tokens, temperature, top_p],
outputs=[chatbot, msg],
concurrency_limit=1
)
clear_btn.click(
lambda: ([], ""),
outputs=[chatbot, msg]
)
# Example prompts
gr.Examples(
examples=[
["Hello! Who are you?"],
["Explain quantum computing in simple terms"],
["Write a short story about a robot learning to paint"],
["What are the benefits of renewable energy?"],
["Help me write a Python function to sort a list"]
],
inputs=msg,
label="💡 Example Prompts"
)
if __name__ == "__main__":
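    # max_size caps how many requests may wait in the Gradio event queue at once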
demo.queue(max_size=20).launch()