import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import time
import spaces
import re
# Model configurations: display name -> Hugging Face Hub repo id
MODELS = {
    "Athena-R3X 8B": "Spestly/Athena-R3X-8B",
    "Athena-R3X 4B": "Spestly/Athena-R3X-4B",
    "Athena-R3 7B": "Spestly/Athena-R3-7B",
    "Athena-3 3B": "Spestly/Athena-3-3B",
    "Athena-3 7B": "Spestly/Athena-3-7B",
    "Athena-3 14B": "Spestly/Athena-3-14B",
    "Athena-2 1.5B": "Spestly/Athena-2-1.5B",
    "Athena-1 3B": "Spestly/Athena-1-3B",
    "Athena-1 7B": "Spestly/Athena-1-7B"
}
# Models that need the enable_thinking parameter
THINKING_ENABLED_MODELS = ["Spestly/Athena-R3X-4B"]
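# ZeroGPU: this decorator requests a GPU only for the duration of each call;
# without it, the CUDA work inside this function would fail on a ZeroGPU Space.
@spaces.GPU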
def generate_response(model_id, conversation, user_message, max_length=512, temperature=0.7):
    """Generate response using ZeroGPU - all CUDA operations happen here"""
    print(f"🚀 Loading {model_id}...")
    start_time = time.time()
    tokenizer = AutoTokenizer.from_pretrained(model_id)
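    # Some checkpoints ship without a pad token; fall back to EOS so padding works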
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True
    )
    load_time = time.time() - start_time
    print(f"✅ Model loaded in {load_time:.2f}s")
    # Build messages in proper chat format (OpenAI-style messages)
    messages = []
    system_prompt = (
        "You are Athena, a helpful, harmless, and honest AI assistant. "
        "You provide clear, accurate, and concise responses to user questions. "
        "You are knowledgeable across many domains and always aim to be respectful and helpful. "
        "You are finetuned by Aayan Mishra"
    )
    messages.append({"role": "system", "content": system_prompt})
    # Add conversation history
    messages.extend(conversation)
    # Add current user message
    messages.append({"role": "user", "content": user_message})
    # Check if this model needs the enable_thinking parameter
    if model_id in THINKING_ENABLED_MODELS:
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=True
        )
    else:
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
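    # Note: enable_thinking is a chat-template flag (as in Qwen3-style templates) that
    # asks the template to emit a <think>...</think> reasoning block; it is rendered
    # as a collapsible section by format_response_with_thinking() below.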
    inputs = tokenizer(prompt, return_tensors="pt")
    device = next(model.parameters()).device
    inputs = {k: v.to(device) for k, v in inputs.items()}
    generation_start = time.time()
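    # Stochastic decoding: temperature scaling plus nucleus sampling (top_p=0.9)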
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_length,
            temperature=temperature,
            do_sample=True,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )
    generation_time = time.time() - generation_start
    response = tokenizer.decode(
        outputs[0][inputs['input_ids'].shape[-1]:],
        skip_special_tokens=True
    ).strip()
    print(f"Generation time: {generation_time:.2f}s")
    return response, load_time, generation_time
def format_response_with_thinking(response):
    """Format response to handle <think></think> tags"""
    # Check if response contains thinking tags
    if '<think>' in response and '</think>' in response:
        # Split the response into parts
        pattern = r'(.*?)(<think>(.*?)</think>)(.*)'
        match = re.search(pattern, response, re.DOTALL)
        if match:
            before_thinking = match.group(1).strip()
            thinking_content = match.group(3).strip()
            after_thinking = match.group(4).strip()
            # Create HTML with collapsible thinking section
            html = f"{before_thinking}\n"
            html += '<div class="thinking-container">'
            html += '<button class="thinking-toggle"><div class="thinking-icon"></div> Thinking completed <span class="dropdown-arrow">▼</span></button>'
            html += f'<div class="thinking-content hidden">{thinking_content}</div>'
            html += '</div>\n'
            html += after_thinking
            return html
    # If no thinking tags, return the original response
    return response
def chat_submit(message, history, conversation_state, model_name, max_length, temperature):
    """Process a new message and update the chat history"""
    # For debugging - print when the function is called
    print(f"chat_submit function called with message: '{message}'")
    if not message or not message.strip():
        print("Empty message, returning without processing")
        return "", history, conversation_state
    model_id = MODELS.get(model_name, MODELS["Athena-R3X 4B"])
    try:
        response, load_time, generation_time = generate_response(
            model_id, conversation_state, message, max_length, temperature
        )
        # Update the conversation state with the raw response
        conversation_state.append({"role": "user", "content": message})
        conversation_state.append({"role": "assistant", "content": response})
        # Format the response for display
        formatted_response = format_response_with_thinking(response)
        # Update the visible chat history
        history.append((message, formatted_response))
        print(f"Response added to history. Current length: {len(history)}")
        return "", history, conversation_state
    except Exception as e:
        import traceback
        print(f"Error in chat_submit: {str(e)}")
        print(traceback.format_exc())
        error_message = f"Error: {str(e)}"
        history.append((message, error_message))
        return "", history, conversation_state
css = """ | |
.message {
    padding: 10px;
    margin: 5px;
    border-radius: 10px;
}
.thinking-container {
    margin: 10px 0;
}
.thinking-toggle {
    background-color: rgba(30, 30, 40, 0.8);
    border: none;
    border-radius: 25px;
    padding: 8px 15px;
    cursor: pointer;
    font-size: 0.95em;
    margin-bottom: 8px;
    color: white;
    display: flex;
    align-items: center;
    gap: 8px;
    box-shadow: 0 2px 5px rgba(0,0,0,0.2);
    transition: background-color 0.2s;
    width: auto;
    max-width: 280px;
}
.thinking-toggle:hover {
    background-color: rgba(40, 40, 50, 0.9);
}
.thinking-icon {
    width: 16px;
    height: 16px;
    border-radius: 50%;
    background-color: #6366f1;
    position: relative;
    overflow: hidden;
}
.thinking-icon::after {
    content: "";
    position: absolute;
    top: 50%;
    left: 50%;
    width: 60%;
    height: 60%;
    background-color: #a5b4fc;
    transform: translate(-50%, -50%);
    border-radius: 50%;
}
.dropdown-arrow {
    font-size: 0.7em;
    margin-left: auto;
    transition: transform 0.3s;
}
.thinking-content {
    background-color: rgba(30, 30, 40, 0.8);
    border-left: 2px solid #6366f1;
    padding: 15px;
    margin-top: 5px;
    margin-bottom: 15px;
    font-size: 0.95em;
    color: #e2e8f0;
    font-family: monospace;
    white-space: pre-wrap;
    overflow-x: auto;
    border-radius: 5px;
    line-height: 1.5;
}
.hidden {
    display: none;
}
"""
# Add JavaScript to make the thinking buttons work
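# Gradio's Blocks(js=...) expects a single function that runs once the app has loaded,
# so the setup is wrapped in one function (DOMContentLoaded has already fired by then).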
js = """ | |
function setupThinkingToggle() { | |
document.querySelectorAll('.thinking-toggle').forEach(button => { | |
if (!button.hasEventListener) { | |
button.addEventListener('click', function() { | |
const content = this.nextElementSibling; | |
content.classList.toggle('hidden'); | |
const arrow = this.querySelector('.dropdown-arrow'); | |
if (content.classList.contains('hidden')) { | |
arrow.textContent = 'βΌ'; | |
arrow.style.transform = ''; | |
} else { | |
arrow.textContent = 'β²'; | |
arrow.style.transform = 'rotate(0deg)'; | |
} | |
}); | |
button.hasEventListener = true; | |
} | |
}); | |
} | |
// Setup a mutation observer to watch for changes in the DOM | |
const observer = new MutationObserver(function(mutations) { | |
setupThinkingToggle(); | |
}); | |
// Start observing after DOM is loaded | |
document.addEventListener('DOMContentLoaded', () => { | |
setupThinkingToggle(); | |
setTimeout(() => { | |
const chatbot = document.querySelector('.chatbot'); | |
if (chatbot) { | |
observer.observe(chatbot, { | |
childList: true, | |
subtree: true, | |
characterData: true | |
}); | |
} else { | |
observer.observe(document.body, { | |
childList: true, | |
subtree: true | |
}); | |
} | |
}, 1000); | |
}); | |
""" | |
# Create Gradio interface
with gr.Blocks(title="Athena Playground Chat", css=css, js=js) as demo:
    gr.Markdown("# 🚀 Athena Playground Chat")
    gr.Markdown("*Powered by HuggingFace ZeroGPU*")
    # State to keep track of the conversation for the model
    conversation_state = gr.State([])
    # Chatbot component
    chatbot = gr.Chatbot(
        height=500,
        label="Athena",
        render_markdown=True,
        elem_classes=["chatbot"]
    )
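    # Default (tuple) message format: each history entry is a (user, assistant) pair,
    # matching the tuples appended in chat_submit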
    # Input and send button row
    with gr.Row():
        user_input = gr.Textbox(
            label="Your message",
            scale=8,
            autofocus=True,
            placeholder="Type your message here...",
            lines=2
        )
        send_btn = gr.Button(
            value="Send",
            scale=1,
            variant="primary"
        )
    # Clear button
    clear_btn = gr.Button("Clear Conversation")
    # Configuration controls
    gr.Markdown("### ⚙️ Model & Generation Settings")
    with gr.Row():
        model_choice = gr.Dropdown(
            label="📱 Model",
            choices=list(MODELS.keys()),
            value="Athena-R3X 4B",
            info="Select which Athena model to use"
        )
        max_length = gr.Slider(
            32, 8192, value=512,
            label="📏 Max Tokens",
            info="Maximum number of tokens to generate"
        )
        temperature = gr.Slider(
            0.1, 2.0, value=0.7,
            label="🎨 Creativity",
            info="Higher values = more creative responses"
        )
    # Function to clear the conversation
    def clear_conversation():
        return [], []
    # Connect the interface components with explicit handlers
    submit_click = user_input.submit(
        fn=chat_submit,
        inputs=[user_input, chatbot, conversation_state, model_choice, max_length, temperature],
        outputs=[user_input, chatbot, conversation_state]
    )
    # Connect send button explicitly
    send_click = send_btn.click(
        fn=chat_submit,
        inputs=[user_input, chatbot, conversation_state, model_choice, max_length, temperature],
        outputs=[user_input, chatbot, conversation_state]
    )
    # Clear conversation
    clear_btn.click(
        fn=clear_conversation,
        outputs=[chatbot, conversation_state]
    )
    # Examples
    gr.Examples(
        examples=[
            "What is artificial intelligence?",
            "Can you explain quantum computing?",
            "Write a short poem about technology",
            "What are some ethical concerns about AI?"
        ],
        inputs=user_input
    )
gr.Markdown(""" | |
### About the Thinking Tags | |
Some Athena models (particularly R3X series) include reasoning in `<think></think>` tags. | |
Click on "Thinking completed" to view the model's thought process behind its answers. | |
""") | |
if __name__ == "__main__":
    # Enable the request queue (used by ZeroGPU to schedule GPU access) and debug logging
    demo.queue()
    demo.launch(debug=True)