Spaces: Running on Zero
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
import threading
import queue
import time
import spaces

# Model configuration
model_name = "HelpingAI/Dhanishtha-2.0-preview"

# Global variables for model and tokenizer
model = None
tokenizer = None

def load_model():
    """Load the model and tokenizer."""
    global model, tokenizer

    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    print("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto",
        trust_remote_code=True
    )
    print("Model loaded successfully!")

class GradioTextStreamer(TextStreamer):
    """Custom TextStreamer that collects generated text and feeds it to Gradio via a queue."""

    def __init__(self, tokenizer, skip_prompt=True, skip_special_tokens=True):
        # TextStreamer only accepts skip_prompt positionally; decoding options such as
        # skip_special_tokens must be passed as keyword arguments.
        super().__init__(tokenizer, skip_prompt=skip_prompt, skip_special_tokens=skip_special_tokens)
        self.text_queue = queue.Queue()
        self.generated_text = ""

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Called by TextStreamer whenever a chunk of decoded text is finalized."""
        self.generated_text += text
        self.text_queue.put(text)
        if stream_end:
            self.text_queue.put(None)

    def get_generated_text(self):
        """Return all text generated so far."""
        return self.generated_text

    def reset(self):
        """Reset the streamer and drain any leftover items from the queue."""
        self.generated_text = ""
        while not self.text_queue.empty():
            try:
                self.text_queue.get_nowait()
            except queue.Empty:
                break
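
# Note: transformers also ships TextIteratorStreamer, which implements this same
# queue-based pattern out of the box. A rough sketch of how it could replace the
# custom class above (not what this app uses):
#
#     from transformers import TextIteratorStreamer
#     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
#     thread = threading.Thread(target=model.generate, kwargs={**model_inputs, "streamer": streamer})
#     thread.start()
#     partial_text = ""
#     for chunk in streamer:  # yields text chunks as they are produced
#         partial_text += chunk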

# On a ZeroGPU Space ("Running on Zero"), the function that touches the GPU is
# decorated with @spaces.GPU so a GPU is attached for the duration of the call.
@spaces.GPU
def generate_response(message, history, max_tokens, temperature, top_p):
    """Generate a streaming response for the latest user message."""
    global model, tokenizer

    if model is None or tokenizer is None:
        yield "Model is still loading. Please wait..."
        return

    # Rebuild the conversation history in chat format
    messages = []
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})

    # Add the current message
    messages.append({"role": "user", "content": message})

    # Apply the chat template
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    # Tokenize the input
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    # Create and reset the streamer
    streamer = GradioTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    streamer.reset()

    generation_kwargs = {
        **model_inputs,
        "max_new_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "do_sample": True,
        "pad_token_id": tokenizer.eos_token_id,
        "streamer": streamer,
        "return_dict_in_generate": True
    }

    # Run generation in a background thread so the queue can be drained as tokens arrive
    def generate():
        try:
            with torch.no_grad():
                model.generate(**generation_kwargs)
        except Exception as e:
            streamer.text_queue.put(f"Error: {str(e)}")
            streamer.text_queue.put(None)

    thread = threading.Thread(target=generate)
    thread.start()

    # Stream partial results to the UI
    generated_text = ""
    while True:
        try:
            new_text = streamer.text_queue.get(timeout=30)
            if new_text is None:
                break
            generated_text += new_text
            yield generated_text
        except queue.Empty:
            # No new text for 30 seconds: stop waiting
            break

    thread.join(timeout=1)

    # Final yield with the complete text
    if generated_text:
        yield generated_text
    else:
        yield "No response generated."

def chat_interface(message, history, max_tokens, temperature, top_p):
    """Main chat callback: append the user message and stream the assistant reply."""
    if not message.strip():
        # Nothing to send; leave the history unchanged and clear the textbox
        yield history, ""
        return

    # Add the user message to the history with an empty assistant slot
    history.append([message, ""])

    # Stream the assistant response into the last history entry
    for partial_response in generate_response(message, history[:-1], max_tokens, temperature, top_p):
        history[-1][1] = partial_response
        yield history, ""

# Load model on startup
print("Initializing model...")
load_model()

# Create Gradio interface
with gr.Blocks(title="Dhanishtha-2.0-preview Chat", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🤖 Dhanishtha-2.0-preview Chat

        Chat with the **HelpingAI/Dhanishtha-2.0-preview** model!
        """
    )

    with gr.Row():
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(
                [],
                elem_id="chatbot",
                bubble_full_width=False,
                height=500,
                show_copy_button=True
            )

            with gr.Row():
                msg = gr.Textbox(
                    container=False,
                    placeholder="Type your message here...",
                    label="Message",
                    autofocus=True,
                    scale=7
                )
                send_btn = gr.Button("Send", variant="primary", scale=1)

        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Parameters")

            max_tokens = gr.Slider(
                minimum=1,
                maximum=4096,
                value=2048,
                step=1,
                label="Max Tokens",
                info="Maximum number of tokens to generate"
            )

            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=0.7,
                step=0.1,
                label="Temperature",
                info="Controls randomness in generation"
            )

            top_p = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.9,
                step=0.05,
                label="Top-p",
                info="Controls diversity of generation"
            )

            clear_btn = gr.Button("🗑️ Clear Chat", variant="secondary")

    # Event handlers
    msg.submit(
        chat_interface,
        inputs=[msg, chatbot, max_tokens, temperature, top_p],
        outputs=[chatbot, msg],
        concurrency_limit=1
    )

    send_btn.click(
        chat_interface,
        inputs=[msg, chatbot, max_tokens, temperature, top_p],
        outputs=[chatbot, msg],
        concurrency_limit=1
    )

    clear_btn.click(
        lambda: ([], ""),
        outputs=[chatbot, msg]
    )

    # Example prompts
    gr.Examples(
        examples=[
            ["Hello! Who are you?"],
            ["Explain quantum computing in simple terms"],
            ["Write a short story about a robot learning to paint"],
            ["What are the benefits of renewable energy?"],
            ["Help me write a Python function to sort a list"]
        ],
        inputs=msg,
        label="💡 Example Prompts"
    )

if __name__ == "__main__":
    demo.queue(max_size=20).launch()
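
# The Space's requirements.txt is not shown here; based on the imports above it
# would typically list gradio, torch, transformers, accelerate (needed for
# device_map="auto"), and spaces. Running the app locally is just `python app.py`.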