import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import spaces
import re
# Model configuration
model_name = "HelpingAI/Dhanishtha-2.0-preview"
# Global variables for model and tokenizer
model = None
tokenizer = None
def load_model():
"""Load the model and tokenizer"""
global model, tokenizer
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Ensure pad token is set
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
print("Loading model...")
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype="auto",
device_map="auto",
trust_remote_code=True
)
print("Model loaded successfully!")
def format_thinking_text(text):
"""Format text to properly display <think> tags in Gradio with better styling"""
if not text:
return text
# More sophisticated formatting for thinking blocks
# Replace <think> and </think> tags with styled markdown
formatted_text = text
# Handle thinking blocks with proper markdown formatting
thinking_pattern = r'<think>(.*?)</think>'
def replace_thinking_block(match):
thinking_content = match.group(1).strip()
return f"\n\n๐Ÿ’ญ **Thinking Process:**\n\n```\n{thinking_content}\n```\n\n"
formatted_text = re.sub(thinking_pattern, replace_thinking_block, formatted_text, flags=re.DOTALL)
# Clean up any remaining raw tags that might not have been caught
formatted_text = re.sub(r'</?think>', '', formatted_text)
return formatted_text.strip()
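# Illustrative example: format_thinking_text("<think>15% of 240 = 36</think>The answer is 36.")
# returns a "💭 Thinking Process" fenced block containing the reasoning, followed by "The answer is 36."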
@spaces.GPU()
def generate_response(message, history, max_tokens, temperature, top_p):
"""Generate streaming response without threading"""
global model, tokenizer
if model is None or tokenizer is None:
yield "Model is still loading. Please wait..."
return
# Prepare conversation history
messages = []
for user_msg, assistant_msg in history:
messages.append({"role": "user", "content": user_msg})
if assistant_msg:
messages.append({"role": "assistant", "content": assistant_msg})
# Add current message
messages.append({"role": "user", "content": message})
# Apply chat template
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
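    # add_generation_prompt=True appends the assistant turn header to the prompt,
    # so generation continues as the assistant rather than extending the user turn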
# Tokenize input
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
try:
with torch.no_grad():
# Use transformers streaming with custom approach
generated_text = ""
current_input_ids = model_inputs["input_ids"]
current_attention_mask = model_inputs["attention_mask"]
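            # Sample one token per iteration. Note that each forward pass re-encodes the
            # full prefix (the past_key_values produced by use_cache=True are not reused),
            # which keeps the loop simple at the cost of speed.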
for _ in range(max_tokens):
# Generate next token
outputs = model(
input_ids=current_input_ids,
attention_mask=current_attention_mask,
use_cache=True
)
# Get logits for the last token
logits = outputs.logits[0, -1, :]
# Apply temperature
if temperature != 1.0:
logits = logits / temperature
# Apply top-p sampling
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
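                    # Shift the removal mask right by one position so the first token that
                    # crosses the top_p threshold is still kept in the candidate set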
sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
sorted_indices_to_remove[0] = 0
indices_to_remove = sorted_indices[sorted_indices_to_remove]
logits[indices_to_remove] = float('-inf')
# Sample next token
probs = torch.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
# Check for EOS token
if next_token.item() == tokenizer.eos_token_id:
break
# Decode the new token (preserve special tokens like <think>)
new_token_text = tokenizer.decode(next_token, skip_special_tokens=False)
generated_text += new_token_text
# Format and yield the current text
formatted_text = format_thinking_text(generated_text)
yield formatted_text
# Update inputs for next iteration
current_input_ids = torch.cat([current_input_ids, next_token.unsqueeze(0)], dim=-1)
                current_attention_mask = torch.cat(
                    [current_attention_mask, torch.ones((1, 1), dtype=current_attention_mask.dtype, device=model.device)],
                    dim=-1
                )  # keep the mask dtype consistent with the tokenizer output
except Exception as e:
yield f"Error generating response: {str(e)}"
return
# Final yield with complete formatted text
final_text = format_thinking_text(generated_text) if generated_text else "No response generated."
yield final_text
def chat_interface(message, history, max_tokens, temperature, top_p):
"""Main chat interface with improved streaming"""
    if not message.strip():
        # chat_interface is a generator, so yield the unchanged state instead of returning a value
        yield history, ""
        return
# Add user message to history
history.append([message, ""])
# Generate response with streaming
for partial_response in generate_response(message, history[:-1], max_tokens, temperature, top_p):
history[-1][1] = partial_response
yield history, ""
return history, ""
# Load model on startup
print("Initializing model...")
load_model()
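# The model is loaded once at import time, so Space startup includes the weight download.
# On Hugging Face ZeroGPU Spaces, the @spaces.GPU() decorator above attaches a GPU only
# for the duration of each generate_response call.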
# Custom CSS for better styling and thinking blocks
custom_css = """
/* Main chatbot styling */
.chatbot {
font-size: 14px;
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
/* Thinking block styling */
.thinking-block {
background: linear-gradient(135deg, #f0f8ff 0%, #e6f3ff 100%);
border-left: 4px solid #4a90e2;
border-radius: 8px;
padding: 12px 16px;
margin: 12px 0;
font-family: 'Consolas', 'Monaco', 'Courier New', monospace;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
position: relative;
}
.thinking-block::before {
content: "๐Ÿค”";
position: absolute;
top: -8px;
left: 12px;
background: white;
padding: 0 4px;
font-size: 16px;
}
/* Message styling */
.message {
padding: 10px 14px;
margin: 6px 0;
border-radius: 12px;
line-height: 1.5;
}
.user-message {
background: linear-gradient(135deg, #e3f2fd 0%, #bbdefb 100%);
margin-left: 15%;
border-bottom-right-radius: 4px;
}
.assistant-message {
background: linear-gradient(135deg, #f5f5f5 0%, #eeeeee 100%);
margin-right: 15%;
border-bottom-left-radius: 4px;
}
/* Code block styling */
pre {
background-color: #f8f9fa;
border: 1px solid #e9ecef;
border-radius: 6px;
padding: 12px;
overflow-x: auto;
font-family: 'Consolas', 'Monaco', 'Courier New', monospace;
font-size: 13px;
line-height: 1.4;
}
/* Button styling */
.gradio-button {
border-radius: 8px;
font-weight: 500;
transition: all 0.2s ease;
}
.gradio-button:hover {
transform: translateY(-1px);
box-shadow: 0 4px 8px rgba(0,0,0,0.15);
}
/* Input styling */
.gradio-textbox {
border-radius: 8px;
border: 2px solid #e0e0e0;
transition: border-color 0.2s ease;
}
.gradio-textbox:focus {
border-color: #4a90e2;
box-shadow: 0 0 0 3px rgba(74, 144, 226, 0.1);
}
/* Slider styling */
.gradio-slider {
margin: 8px 0;
}
/* Examples styling */
.gradio-examples {
margin-top: 16px;
}
.gradio-examples .gradio-button {
background: linear-gradient(135deg, #f8f9fa 0%, #e9ecef 100%);
border: 1px solid #dee2e6;
color: #495057;
font-size: 13px;
padding: 8px 12px;
}
.gradio-examples .gradio-button:hover {
background: linear-gradient(135deg, #e9ecef 0%, #dee2e6 100%);
color: #212529;
}
"""
# Create Gradio interface
with gr.Blocks(
title="๐Ÿค– Dhanishtha-2.0-preview Chat",
theme=gr.themes.Soft(),
css=custom_css
) as demo:
gr.Markdown(
"""
# ๐Ÿค– Dhanishtha-2.0-preview Chat
Chat with the **HelpingAI/Dhanishtha-2.0-preview** model - The world's first LLM designed to think between responses!
### โœจ Key Features:
- ๐Ÿง  **Multi-step Reasoning**: Unlike other LLMs that think once, Dhanishtha can think, rethink, self-evaluate, and refine using multiple `<think>` blocks
- ๐Ÿ”„ **Iterative Thinking**: Watch the model's thought process unfold in real-time
- ๐Ÿ’ก **Enhanced Problem Solving**: Better reasoning capabilities through structured thinking
**Note**: The `<think>` blocks show the model's internal reasoning process and will be displayed in a formatted way below.
"""
)
with gr.Row():
with gr.Column(scale=4):
chatbot = gr.Chatbot(
[],
elem_id="chatbot",
bubble_full_width=False,
height=600,
show_copy_button=True,
show_share_button=True,
avatar_images=("๐Ÿ‘ค", "๐Ÿค–"),
render_markdown=True,
latex_delimiters=[
{"left": "$$", "right": "$$", "display": True},
{"left": "$", "right": "$", "display": False}
]
)
with gr.Row():
msg = gr.Textbox(
container=False,
placeholder="Ask me anything! I'll show you my thinking process...",
label="Message",
autofocus=True,
scale=8,
lines=1,
max_lines=5
)
send_btn = gr.Button("๐Ÿš€ Send", variant="primary", scale=1, size="lg")
with gr.Column(scale=1, min_width=300):
gr.Markdown("### โš™๏ธ Generation Parameters")
max_tokens = gr.Slider(
minimum=50,
maximum=8192,
value=2048,
step=50,
label="๐ŸŽฏ Max Tokens",
info="Maximum number of tokens to generate"
)
temperature = gr.Slider(
minimum=0.1,
maximum=2.0,
value=0.7,
step=0.1,
label="๐ŸŒก๏ธ Temperature",
info="Higher = more creative, Lower = more focused"
)
top_p = gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.9,
step=0.05,
label="๐ŸŽฒ Top-p",
info="Nucleus sampling threshold"
)
with gr.Row():
clear_btn = gr.Button("๐Ÿ—‘๏ธ Clear Chat", variant="secondary", scale=1)
stop_btn = gr.Button("โน๏ธ Stop", variant="stop", scale=1)
gr.Markdown("### ๐Ÿ“Š Model Info")
gr.Markdown(
"""
**Model**: HelpingAI/Dhanishtha-2.0-preview
**Type**: Reasoning LLM with thinking blocks
**Features**: Multi-step reasoning, self-evaluation
"""
)
# Event handlers
def submit_message(message, history, max_tokens, temperature, top_p):
"""Handle message submission"""
return chat_interface(message, history, max_tokens, temperature, top_p)
def clear_chat():
"""Clear the chat history"""
return [], ""
# Message submission events
msg.submit(
submit_message,
inputs=[msg, chatbot, max_tokens, temperature, top_p],
outputs=[chatbot, msg],
concurrency_limit=1,
show_progress="minimal"
)
send_btn.click(
submit_message,
inputs=[msg, chatbot, max_tokens, temperature, top_p],
outputs=[chatbot, msg],
concurrency_limit=1,
show_progress="minimal"
)
# Clear chat event
clear_btn.click(
clear_chat,
outputs=[chatbot, msg],
        show_progress="hidden"
)
# Example prompts section
with gr.Row():
gr.Examples(
examples=[
["Hello! Can you introduce yourself and show me how you think?"],
["Solve this step by step: What is 15% of 240?"],
["Explain quantum entanglement in simple terms"],
["Write a short Python function to find the factorial of a number"],
["What are the pros and cons of renewable energy?"],
["Help me understand the difference between AI and machine learning"],
["Create a haiku about artificial intelligence"],
["Explain why the sky is blue using physics principles"]
],
inputs=msg,
label="๐Ÿ’ก Example Prompts - Try these to see the thinking process!",
examples_per_page=4
)
# Footer with information
gr.Markdown(
"""
---
### 🔧 Technical Details
- **Model**: HelpingAI/Dhanishtha-2.0-preview
- **Framework**: Transformers + Gradio
- **Features**: Real-time streaming, thinking process visualization, custom sampling
- **Reasoning**: Multi-step thinking with `<think>` blocks for transparent AI reasoning
**Note**: This interface streams responses token by token and formats thinking blocks for better readability.
The model's internal reasoning process is displayed in formatted code blocks.
---
*Built with ❤️ using Gradio and Transformers*
"""
)
if __name__ == "__main__":
demo.queue(
max_size=20,
default_concurrency_limit=1
).launch(
server_name="0.0.0.0",
server_port=7860,
share=False,
show_error=True,
quiet=False
)