import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
import threading
import queue
import time
import spaces
import sys
from io import StringIO
# Model configuration
model_name = "HelpingAI/Dhanishtha-2.0-preview"

# Global variables for model and tokenizer
model = None
tokenizer = None


def load_model():
    """Load the model and tokenizer"""
    global model, tokenizer

    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    print("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto",
        trust_remote_code=True
    )
    print("Model loaded successfully!")
class StreamCapture:
    """Capture streaming output from TextStreamer via a file-like interface"""

    def __init__(self):
        self.text_queue = queue.Queue()
        self.captured_text = ""

    def write(self, text):
        """Capture written text"""
        # Keep whitespace-only chunks as well, so spacing and newlines coming
        # from the streamer are not lost between words.
        if text:
            self.captured_text += text
            self.text_queue.put(text)
        return len(text)

    def flush(self):
        """Flush method for file-object compatibility"""
        pass

    def get_text(self):
        """Get all captured text"""
        return self.captured_text

    def reset(self):
        """Reset the capture"""
        self.captured_text = ""
        while not self.text_queue.empty():
            try:
                self.text_queue.get_nowait()
            except queue.Empty:
                break
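
# Illustration (hypothetical, not executed by the app): StreamCapture mimics a
# writable file, so anything print()-ed while sys.stdout points at an instance
# lands in both captured_text and text_queue:
#
#     cap = StreamCapture()
#     old_stdout, sys.stdout = sys.stdout, cap
#     print("hello", end="")          # TextStreamer emits tokens like this
#     sys.stdout = old_stdout
#     cap.text_queue.get_nowait()     # -> "hello"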
# This Space runs on ZeroGPU ("Running on Zero"), where the GPU is attached
# per call via the @spaces.GPU decorator (the reason `spaces` is imported);
# the decorator is a no-op on non-ZeroGPU hardware.
@spaces.GPU
def generate_response(message, history, max_tokens, temperature, top_p):
    """Generate streaming response"""
    global model, tokenizer

    if model is None or tokenizer is None:
        yield "Model is still loading. Please wait..."
        return

    # Prepare conversation history
    messages = []
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})

    # Add current message
    messages.append({"role": "user", "content": message})

    # Apply chat template
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    # Tokenize input
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    # Create stream capture
    stream_capture = StreamCapture()

    # Create TextStreamer; it prints decoded tokens to stdout as they arrive
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # The streamer writes to stdout, so we temporarily redirect stdout into StreamCapture
    original_stdout = sys.stdout

    # Generation parameters
    generation_kwargs = {
        **model_inputs,
        "max_new_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "do_sample": True,
        "pad_token_id": tokenizer.eos_token_id,
        "streamer": streamer,
    }

    # Start generation in a separate thread
    def generate():
        try:
            # Redirect stdout to capture the streamer's output
            # (note: the redirection is process-wide while generation runs)
            sys.stdout = stream_capture
            with torch.no_grad():
                model.generate(**generation_kwargs)
        except Exception as e:
            stream_capture.text_queue.put(f"Error: {str(e)}")
        finally:
            # Restore stdout
            sys.stdout = original_stdout
            stream_capture.text_queue.put(None)  # Signal end of stream

    thread = threading.Thread(target=generate)
    thread.start()

    # Stream the results back to the UI
    generated_text = ""
    while True:
        try:
            new_text = stream_capture.text_queue.get(timeout=30)
            if new_text is None:
                break
            generated_text += new_text
            yield generated_text
        except queue.Empty:
            break

    thread.join(timeout=1)

    # Final yield with complete text
    if generated_text:
        yield generated_text
    else:
        yield "No response generated."
def chat_interface(message, history, max_tokens, temperature, top_p):
    """Main chat interface"""
    if not message.strip():
        # Yield (rather than return) so Gradio still receives the current state
        yield history, ""
        return

    # Add user message to history
    history.append([message, ""])

    # Generate response, updating the last history entry as tokens stream in
    for partial_response in generate_response(message, history[:-1], max_tokens, temperature, top_p):
        history[-1][1] = partial_response
        yield history, ""


# Load model on startup
print("Initializing model...")
load_model()
# Create Gradio interface
with gr.Blocks(title="Dhanishtha-2.0-preview Chat", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# 🤖 Dhanishtha-2.0-preview Chat

Chat with the **HelpingAI/Dhanishtha-2.0-preview** model!

Dhanishtha 2.0 is the world's first LLM designed to think in between its responses. Unlike other reasoning LLMs, which think only once, Dhanishtha can think, rethink, self-evaluate, and refine throughout a response using multiple `<think>` blocks.
        """
    )
    with gr.Row():
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(
                [],
                elem_id="chatbot",
                bubble_full_width=False,
                height=500,
                show_copy_button=True
            )

            with gr.Row():
                msg = gr.Textbox(
                    container=False,
                    placeholder="Type your message here...",
                    label="Message",
                    autofocus=True,
                    scale=7
                )
                send_btn = gr.Button("Send", variant="primary", scale=1)

        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Parameters")

            max_tokens = gr.Slider(
                minimum=1,
                maximum=40960,
                value=2048,
                step=1,
                label="Max Tokens",
                info="Maximum number of tokens to generate"
            )

            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=0.7,
                step=0.1,
                label="Temperature",
                info="Controls randomness in generation"
            )

            top_p = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.9,
                step=0.05,
                label="Top-p",
                info="Controls diversity of generation"
            )

            clear_btn = gr.Button("🗑️ Clear Chat", variant="secondary")

    # Event handlers
    msg.submit(
        chat_interface,
        inputs=[msg, chatbot, max_tokens, temperature, top_p],
        outputs=[chatbot, msg],
        concurrency_limit=1
    )

    send_btn.click(
        chat_interface,
        inputs=[msg, chatbot, max_tokens, temperature, top_p],
        outputs=[chatbot, msg],
        concurrency_limit=1
    )

    clear_btn.click(
        lambda: ([], ""),
        outputs=[chatbot, msg]
    )

    # Example prompts
    gr.Examples(
        examples=[
            ["Hello! Who are you?"],
            ["Explain quantum computing in simple terms"],
            ["Write a short story about a robot learning to paint"],
            ["What are the benefits of renewable energy?"],
            ["Help me write a Python function to sort a list"]
        ],
        inputs=msg,
        label="💡 Example Prompts"
    )

if __name__ == "__main__":
    demo.queue(max_size=20).launch()