import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import spaces
import re

# Model configuration
model_name = "HelpingAI/Dhanishtha-2.0-preview"

# Global variables for model and tokenizer
model = None
tokenizer = None

def load_model():
    """Load the model and tokenizer"""
    global model, tokenizer

    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Ensure pad token is set
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto",
        trust_remote_code=True
    )
    print("Model loaded successfully!")

def format_thinking_text(text):
    """Format text to properly display <think> tags in Gradio with better styling"""
    if not text:
        return text

    formatted_text = text

    # Replace <think>...</think> blocks with styled markdown
    thinking_pattern = r'<think>(.*?)</think>'

    def replace_thinking_block(match):
        thinking_content = match.group(1).strip()
        return f"\n\n💭 **Thinking Process:**\n\n```\n{thinking_content}\n```\n\n"

    formatted_text = re.sub(thinking_pattern, replace_thinking_block, formatted_text, flags=re.DOTALL)

    # Clean up any remaining raw tags that might not have been caught
    formatted_text = re.sub(r'</?think>', '', formatted_text)

    return formatted_text.strip()
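
# Example (hypothetical response text) of the transformation performed above:
#   format_thinking_text("<think>Check the units first.</think>The answer is 36.")
# returns the thinking content in a fenced code block under a
# "💭 **Thinking Process:**" heading, followed by "The answer is 36."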

# On ZeroGPU Spaces, GPU access is granted while a @spaces.GPU-decorated
# function runs, so generation is wrapped accordingly.
@spaces.GPU
def generate_response(message, history, max_tokens, temperature, top_p):
    """Generate streaming response without threading"""
    global model, tokenizer

    if model is None or tokenizer is None:
        yield "Model is still loading. Please wait..."
        return

    # Prepare conversation history
    messages = []
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})

    # Add current message
    messages.append({"role": "user", "content": message})

    # Apply chat template
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    # Tokenize input
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    try:
        with torch.no_grad():
            # Manual token-by-token generation loop so partial text can be
            # formatted and streamed back to the UI as it is produced
            generated_text = ""
            current_input_ids = model_inputs["input_ids"]
            current_attention_mask = model_inputs["attention_mask"]

            for _ in range(max_tokens):
                # Run a forward pass over the current sequence
                outputs = model(
                    input_ids=current_input_ids,
                    attention_mask=current_attention_mask,
                    use_cache=True
                )

                # Get logits for the last position
                logits = outputs.logits[0, -1, :]

                # Apply temperature
                if temperature != 1.0:
                    logits = logits / temperature

                # Apply top-p (nucleus) sampling: mask out the lowest-probability
                # tokens once the cumulative probability exceeds top_p
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

                    sorted_indices_to_remove = cumulative_probs > top_p
                    sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
                    sorted_indices_to_remove[0] = 0

                    indices_to_remove = sorted_indices[sorted_indices_to_remove]
                    logits[indices_to_remove] = float('-inf')

                # Sample the next token
                probs = torch.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

                # Stop at the EOS token
                if next_token.item() == tokenizer.eos_token_id:
                    break

                # Decode the new token (preserve special tokens like <think>)
                new_token_text = tokenizer.decode(next_token, skip_special_tokens=False)
                generated_text += new_token_text

                # Format and yield the current text
                formatted_text = format_thinking_text(generated_text)
                yield formatted_text

                # Append the new token to the inputs for the next iteration
                current_input_ids = torch.cat([current_input_ids, next_token.unsqueeze(0)], dim=-1)
                current_attention_mask = torch.cat(
                    [current_attention_mask,
                     torch.ones((1, 1), dtype=current_attention_mask.dtype, device=model.device)],
                    dim=-1
                )

    except Exception as e:
        yield f"Error generating response: {str(e)}"
        return

    # Final yield with complete formatted text
    final_text = format_thinking_text(generated_text) if generated_text else "No response generated."
    yield final_text
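
# Note on the loop above: each step re-runs the model over the full prefix
# (outputs.past_key_values is never reused), which keeps the streaming logic
# simple and thread-free at the cost of speed. A minimal sketch of the same
# loop with key/value caching (greedy sampling only, other details omitted):
#
#   past = None
#   ids = model_inputs["input_ids"]
#   for _ in range(max_tokens):
#       out = model(input_ids=ids, past_key_values=past, use_cache=True)
#       past = out.past_key_values
#       next_token = out.logits[0, -1, :].argmax()
#       if next_token.item() == tokenizer.eos_token_id:
#           break
#       ids = next_token.view(1, 1)  # only the new token is fed on the next step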

def chat_interface(message, history, max_tokens, temperature, top_p):
    """Main chat interface with improved streaming"""
    if not message.strip():
        yield history, ""
        return

    # Add user message to history
    history.append([message, ""])

    # Generate response with streaming
    for partial_response in generate_response(message, history[:-1], max_tokens, temperature, top_p):
        history[-1][1] = partial_response
        yield history, ""
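
# chat_interface is a generator: every yielded (history, "") pair streams the
# partially generated reply into the last [user_message, assistant_message]
# entry of the Chatbot history, and the empty string clears the input box.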

# Load model on startup
print("Initializing model...")
load_model()

# Custom CSS for better styling and thinking blocks
custom_css = """
/* Main chatbot styling */
.chatbot {
    font-size: 14px;
    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}

/* Thinking block styling */
.thinking-block {
    background: linear-gradient(135deg, #f0f8ff 0%, #e6f3ff 100%);
    border-left: 4px solid #4a90e2;
    border-radius: 8px;
    padding: 12px 16px;
    margin: 12px 0;
    font-family: 'Consolas', 'Monaco', 'Courier New', monospace;
    box-shadow: 0 2px 4px rgba(0,0,0,0.1);
    position: relative;
}

.thinking-block::before {
    content: "🤔";
    position: absolute;
    top: -8px;
    left: 12px;
    background: white;
    padding: 0 4px;
    font-size: 16px;
}

/* Message styling */
.message {
    padding: 10px 14px;
    margin: 6px 0;
    border-radius: 12px;
    line-height: 1.5;
}

.user-message {
    background: linear-gradient(135deg, #e3f2fd 0%, #bbdefb 100%);
    margin-left: 15%;
    border-bottom-right-radius: 4px;
}

.assistant-message {
    background: linear-gradient(135deg, #f5f5f5 0%, #eeeeee 100%);
    margin-right: 15%;
    border-bottom-left-radius: 4px;
}

/* Code block styling */
pre {
    background-color: #f8f9fa;
    border: 1px solid #e9ecef;
    border-radius: 6px;
    padding: 12px;
    overflow-x: auto;
    font-family: 'Consolas', 'Monaco', 'Courier New', monospace;
    font-size: 13px;
    line-height: 1.4;
}

/* Button styling */
.gradio-button {
    border-radius: 8px;
    font-weight: 500;
    transition: all 0.2s ease;
}

.gradio-button:hover {
    transform: translateY(-1px);
    box-shadow: 0 4px 8px rgba(0,0,0,0.15);
}

/* Input styling */
.gradio-textbox {
    border-radius: 8px;
    border: 2px solid #e0e0e0;
    transition: border-color 0.2s ease;
}

.gradio-textbox:focus {
    border-color: #4a90e2;
    box-shadow: 0 0 0 3px rgba(74, 144, 226, 0.1);
}

/* Slider styling */
.gradio-slider {
    margin: 8px 0;
}

/* Examples styling */
.gradio-examples {
    margin-top: 16px;
}

.gradio-examples .gradio-button {
    background: linear-gradient(135deg, #f8f9fa 0%, #e9ecef 100%);
    border: 1px solid #dee2e6;
    color: #495057;
    font-size: 13px;
    padding: 8px 12px;
}

.gradio-examples .gradio-button:hover {
    background: linear-gradient(135deg, #e9ecef 0%, #dee2e6 100%);
    color: #212529;
}
"""

# Create Gradio interface
with gr.Blocks(
    title="🤖 Dhanishtha-2.0-preview Chat",
    theme=gr.themes.Soft(),
    css=custom_css
) as demo:
    gr.Markdown(
        """
        # 🤖 Dhanishtha-2.0-preview Chat

        Chat with the **HelpingAI/Dhanishtha-2.0-preview** model - the world's first LLM designed to think between responses!

        ### ✨ Key Features:
        - 🧠 **Multi-step Reasoning**: Unlike other LLMs that think once, Dhanishtha can think, rethink, self-evaluate, and refine using multiple `<think>` blocks
        - 🔄 **Iterative Thinking**: Watch the model's thought process unfold in real time
        - 💡 **Enhanced Problem Solving**: Better reasoning through structured thinking

        **Note**: The `<think>` blocks show the model's internal reasoning process and are displayed in a formatted way below.
        """
    )

    with gr.Row():
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(
                [],
                elem_id="chatbot",
                bubble_full_width=False,
                height=600,
                show_copy_button=True,
                show_share_button=True,
                avatar_images=("👤", "🤖"),
                render_markdown=True,
                latex_delimiters=[
                    {"left": "$$", "right": "$$", "display": True},
                    {"left": "$", "right": "$", "display": False}
                ]
            )

            with gr.Row():
                msg = gr.Textbox(
                    container=False,
                    placeholder="Ask me anything! I'll show you my thinking process...",
                    label="Message",
                    autofocus=True,
                    scale=8,
                    lines=1,
                    max_lines=5
                )
                send_btn = gr.Button("🚀 Send", variant="primary", scale=1, size="lg")

        with gr.Column(scale=1, min_width=300):
            gr.Markdown("### ⚙️ Generation Parameters")

            max_tokens = gr.Slider(
                minimum=50,
                maximum=8192,
                value=2048,
                step=50,
                label="🎯 Max Tokens",
                info="Maximum number of tokens to generate"
            )

            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=0.7,
                step=0.1,
                label="🌡️ Temperature",
                info="Higher = more creative, Lower = more focused"
            )

            top_p = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.9,
                step=0.05,
                label="🎲 Top-p",
                info="Nucleus sampling threshold"
            )

            with gr.Row():
                clear_btn = gr.Button("🗑️ Clear Chat", variant="secondary", scale=1)
                stop_btn = gr.Button("⏹️ Stop", variant="stop", scale=1)

            gr.Markdown("### 📊 Model Info")
            gr.Markdown(
                """
                **Model**: HelpingAI/Dhanishtha-2.0-preview
                **Type**: Reasoning LLM with thinking blocks
                **Features**: Multi-step reasoning, self-evaluation
                """
            )

    # Event handlers
    def submit_message(message, history, max_tokens, temperature, top_p):
        """Handle message submission (delegates to the streaming chat_interface)"""
        yield from chat_interface(message, history, max_tokens, temperature, top_p)

    def clear_chat():
        """Clear the chat history"""
        return [], ""

    # Message submission events
    submit_event = msg.submit(
        submit_message,
        inputs=[msg, chatbot, max_tokens, temperature, top_p],
        outputs=[chatbot, msg],
        concurrency_limit=1,
        show_progress="minimal"
    )

    click_event = send_btn.click(
        submit_message,
        inputs=[msg, chatbot, max_tokens, temperature, top_p],
        outputs=[chatbot, msg],
        concurrency_limit=1,
        show_progress="minimal"
    )

    # Clear chat event
    clear_btn.click(
        clear_chat,
        outputs=[chatbot, msg],
        show_progress="hidden"
    )

    # Stop button cancels any in-flight generation
    stop_btn.click(None, None, None, cancels=[submit_event, click_event])

    # Example prompts section
    with gr.Row():
        gr.Examples(
            examples=[
                ["Hello! Can you introduce yourself and show me how you think?"],
                ["Solve this step by step: What is 15% of 240?"],
                ["Explain quantum entanglement in simple terms"],
                ["Write a short Python function to find the factorial of a number"],
                ["What are the pros and cons of renewable energy?"],
                ["Help me understand the difference between AI and machine learning"],
                ["Create a haiku about artificial intelligence"],
                ["Explain why the sky is blue using physics principles"]
            ],
            inputs=msg,
            label="💡 Example Prompts - Try these to see the thinking process!",
            examples_per_page=4
        )

    # Footer with information
    gr.Markdown(
        """
        ---
        ### 🔧 Technical Details

        - **Model**: HelpingAI/Dhanishtha-2.0-preview
        - **Framework**: Transformers + Gradio
        - **Features**: Real-time streaming, thinking process visualization, custom sampling
        - **Reasoning**: Multi-step thinking with `<think>` blocks for transparent AI reasoning

        **Note**: This interface streams responses token by token and formats thinking blocks for better readability.
        The model's internal reasoning process is displayed in formatted code blocks.

        ---
        *Built with ❤️ using Gradio and Transformers*
        """
    )

if __name__ == "__main__":
    demo.queue(
        max_size=20,
        default_concurrency_limit=1
    ).launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
        quiet=False
    )