# app.py: Athena Playground Chat (Hugging Face Space, ZeroGPU)
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import time
import spaces
import re
# Model configurations
MODELS = {
"Athena-R3X 8B": "Spestly/Athena-R3X-8B",
"Athena-R3X 4B": "Spestly/Athena-R3X-4B",
"Athena-R3 7B": "Spestly/Athena-R3-7B",
"Athena-3 3B": "Spestly/Athena-3-3B",
"Athena-3 7B": "Spestly/Athena-3-7B",
"Athena-3 14B": "Spestly/Athena-3-14B",
"Athena-2 1.5B": "Spestly/Athena-2-1.5B",
"Athena-1 3B": "Spestly/Athena-1-3B",
"Athena-1 7B": "Spestly/Athena-1-7B"
}
# Models that need the enable_thinking parameter
THINKING_ENABLED_MODELS = ["Spestly/Athena-R3X-4B"]
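# (Presumably other R3-series checkpoints could be listed here too, if their
# chat templates also accept the enable_thinking flag.)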
@spaces.GPU
def generate_response(model_id, conversation, user_message, max_length=512, temperature=0.7):
"""Generate response using ZeroGPU - all CUDA operations happen here"""
print(f"πŸš€ Loading {model_id}...")
start_time = time.time()
tokenizer = AutoTokenizer.from_pretrained(model_id)
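    # Some checkpoints ship without a pad token; fall back to EOS so padding works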
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.float16,
device_map="auto",
trust_remote_code=True
)
load_time = time.time() - start_time
print(f"βœ… Model loaded in {load_time:.2f}s")
# Build messages in proper chat format (OpenAI-style messages)
messages = []
system_prompt = (
"You are Athena, a helpful, harmless, and honest AI assistant. "
"You provide clear, accurate, and concise responses to user questions. "
"You are knowledgeable across many domains and always aim to be respectful and helpful. "
"You are finetuned by Aayan Mishra"
)
messages.append({"role": "system", "content": system_prompt})
# Add conversation history
for msg in conversation:
messages.append(msg)
# Add current user message
messages.append({"role": "user", "content": user_message})
# Check if this model needs the enable_thinking parameter
if model_id in THINKING_ENABLED_MODELS:
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=True
)
else:
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
inputs = tokenizer(prompt, return_tensors="pt")
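    # device_map="auto" may place (or shard) the model on the GPU, so move the
    # input tensors to the device of the model's first parameter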
device = next(model.parameters()).device
inputs = {k: v.to(device) for k, v in inputs.items()}
generation_start = time.time()
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=max_length,
temperature=temperature,
do_sample=True,
top_p=0.9,
pad_token_id=tokenizer.eos_token_id,
eos_token_id=tokenizer.eos_token_id
)
generation_time = time.time() - generation_start
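    # Decode only the newly generated tokens, slicing off the prompt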
response = tokenizer.decode(
outputs[0][inputs['input_ids'].shape[-1]:],
skip_special_tokens=True
).strip()
print(f"Generation time: {generation_time:.2f}s")
return response, load_time, generation_time
def format_response_with_thinking(response):
"""Format response to handle <think></think> tags"""
# Check if response contains thinking tags
if '<think>' in response and '</think>' in response:
# Split the response into parts
pattern = r'(.*?)(<think>(.*?)</think>)(.*)'
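        # group(1) = text before, group(3) = thinking content, group(4) = text
        # after; note that only the first <think>...</think> block is captured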
match = re.search(pattern, response, re.DOTALL)
if match:
before_thinking = match.group(1).strip()
thinking_content = match.group(3).strip()
after_thinking = match.group(4).strip()
            # Create HTML with a collapsible thinking section
            html = (
                f"{before_thinking}\n"
                '<div class="thinking-container">'
                '<button class="thinking-toggle"><div class="thinking-icon"></div> '
                'Thinking completed <span class="dropdown-arrow">▼</span></button>'
                f'<div class="thinking-content hidden">{thinking_content}</div>'
                '</div>\n'
                f"{after_thinking}"
            )
            return html
# If no thinking tags, return the original response
return response
def chat_submit(message, history, conversation_state, model_name, max_length, temperature):
"""Process a new message and update the chat history"""
if not message or not message.strip():
return "", history, conversation_state
# Debug print to check function execution
print(f"Processing message in chat_submit: {message}")
model_id = MODELS.get(model_name, MODELS["Athena-R3X 4B"])
try:
response, load_time, generation_time = generate_response(
model_id, conversation_state, message, max_length, temperature
)
# Update the conversation state with the raw response
conversation_state.append({"role": "user", "content": message})
conversation_state.append({"role": "assistant", "content": response})
# Format the response for display
formatted_response = format_response_with_thinking(response)
# Update the visible chat history
history.append((message, formatted_response))
print(f"Response added to history. Current length: {len(history)}")
return "", history, conversation_state
except Exception as e:
import traceback
print(f"Error in chat_submit: {str(e)}")
print(traceback.format_exc())
error_message = f"Error: {str(e)}"
history.append((message, error_message))
return "", history, conversation_state
css = """
.message {
padding: 10px;
margin: 5px;
border-radius: 10px;
}
.thinking-container {
margin: 10px 0;
}
.thinking-toggle {
background-color: rgba(30, 30, 40, 0.8);
border: none;
border-radius: 25px;
padding: 8px 15px;
cursor: pointer;
font-size: 0.95em;
margin-bottom: 8px;
color: white;
display: flex;
align-items: center;
gap: 8px;
box-shadow: 0 2px 5px rgba(0,0,0,0.2);
transition: background-color 0.2s;
width: auto;
max-width: 280px;
}
.thinking-toggle:hover {
background-color: rgba(40, 40, 50, 0.9);
}
.thinking-icon {
width: 16px;
height: 16px;
border-radius: 50%;
background-color: #6366f1;
position: relative;
overflow: hidden;
}
.thinking-icon::after {
content: "";
position: absolute;
top: 50%;
left: 50%;
width: 60%;
height: 60%;
background-color: #a5b4fc;
transform: translate(-50%, -50%);
border-radius: 50%;
}
.dropdown-arrow {
font-size: 0.7em;
margin-left: auto;
transition: transform 0.3s;
}
.thinking-content {
background-color: rgba(30, 30, 40, 0.8);
border-left: 2px solid #6366f1;
padding: 15px;
margin-top: 5px;
margin-bottom: 15px;
font-size: 0.95em;
color: #e2e8f0;
font-family: monospace;
white-space: pre-wrap;
overflow-x: auto;
border-radius: 5px;
line-height: 1.5;
}
.hidden {
display: none;
}
"""
# Add JavaScript to make the thinking buttons work and fix enter key issues
js = """
// Function to handle thinking toggle buttons
function setupThinkingToggle() {
document.querySelectorAll('.thinking-toggle').forEach(button => {
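        // Guard with a custom flag so re-renders don't attach duplicate listeners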
if (!button.hasEventListener) {
button.addEventListener('click', function() {
const content = this.nextElementSibling;
content.classList.toggle('hidden');
const arrow = this.querySelector('.dropdown-arrow');
                // Swap the glyph to reflect the open/closed state
                arrow.textContent = content.classList.contains('hidden') ? '▼' : '▲';
});
button.hasEventListener = true;
}
});
}
// Setup a mutation observer to watch for changes in the DOM
const observer = new MutationObserver(function(mutations) {
setupThinkingToggle();
});
// Function to ensure the textbox and submit button work correctly.
// Select them via the elem_id values assigned in the Python code, which is
// more robust than relying on Gradio's internal data-testid attributes.
function fixChatInputs() {
    const textbox = document.querySelector('#chat-input textarea');
    const submitBtn = document.querySelector('#send-btn');
    if (textbox && submitBtn && !textbox.hasEnterListener) {
console.log("Setting up enter key handler");
textbox.addEventListener('keydown', function(e) {
if (e.key === 'Enter' && !e.shiftKey) {
e.preventDefault();
if (textbox.value.trim() !== '') {
submitBtn.click();
}
}
});
textbox.hasEnterListener = true;
}
if (submitBtn && !submitBtn.hasClickFix) {
console.log("Enhancing submit button");
submitBtn.addEventListener('click', function() {
console.log("Submit button clicked");
});
submitBtn.hasClickFix = true;
}
}
// Function to run all UI fixes
function setupUI() {
setupThinkingToggle();
fixChatInputs();
}
// Initial setup: this wrapper function runs after the page has loaded, so
// DOMContentLoaded has already fired and we can set up directly.
console.log("Page loaded, setting up UI");
setTimeout(setupUI, 1000);
// Set up the observer after a delay
setTimeout(() => {
    const chatbot = document.querySelector('.chatbot');
    if (chatbot) {
        observer.observe(chatbot, {
            childList: true,
            subtree: true,
            characterData: true
        });
    } else {
        // If the chatbot container is not found, observe the body instead
        observer.observe(document.body, {
            childList: true,
            subtree: true
        });
    }
    // Re-run the UI fixes periodically as a safety net for re-renders
    setInterval(setupUI, 2000);
}, 1000);
}
"""
theme = gr.themes.Soft()
with gr.Blocks(title="Athena Playground Chat", css=css, theme=theme, js=js) as demo:
gr.Markdown("# πŸš€ Athena Playground Chat")
gr.Markdown("*Powered by HuggingFace ZeroGPU*")
# State to keep track of the conversation for the model
conversation_state = gr.State([])
chatbot = gr.Chatbot(height=500, label="Athena", render_markdown=True, elem_classes=["chatbot"])
with gr.Row():
user_input = gr.Textbox(
label="Your message",
scale=8,
autofocus=True,
placeholder="Type your message here...",
elem_id="chat-input",
lines=2,
max_lines=10,
)
send_btn = gr.Button(
value="Send",
scale=1,
variant="primary",
elem_id="send-btn",
min_width=100
)
# Clear button for resetting the conversation
clear_btn = gr.Button("Clear Conversation")
# Configuration controls
gr.Markdown("### βš™οΈ Model & Generation Settings")
with gr.Row():
model_choice = gr.Dropdown(
label="πŸ“± Model",
choices=list(MODELS.keys()),
value="Athena-R3X 4B",
info="Select which Athena model to use"
)
max_length = gr.Slider(
32, 8192, value=512,
label="πŸ“ Max Tokens",
info="Maximum number of tokens to generate"
)
temperature = gr.Slider(
0.1, 2.0, value=0.7,
label="🎨 Creativity",
info="Higher values = more creative responses"
)
# Function to clear the conversation
def clear_conversation():
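        # Reset both the visible chat history and the model-side message list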
return [], []
# Connect the interface components - note the specific ordering
user_input.submit(
fn=chat_submit,
inputs=[user_input, chatbot, conversation_state, model_choice, max_length, temperature],
outputs=[user_input, chatbot, conversation_state],
api_name="submit_message"
)
# Make sure send button uses the exact same function with the same parameter ordering
send_btn.click(
fn=chat_submit,
inputs=[user_input, chatbot, conversation_state, model_choice, max_length, temperature],
outputs=[user_input, chatbot, conversation_state],
api_name="send_message"
)
# Connect clear button
clear_btn.click(
fn=clear_conversation,
outputs=[chatbot, conversation_state],
api_name="clear_chat"
)
# Add examples if desired
gr.Examples(
examples=[
"What is artificial intelligence?",
"Can you explain quantum computing?",
"Write a short poem about technology",
"What are some ethical concerns about AI?"
],
inputs=[user_input]
)
gr.Markdown("""
### About the Thinking Tags
Some Athena models (particularly R3X series) include reasoning in `<think></think>` tags.
Click on "Thinking completed" to view the model's thought process behind its answers.
""")
if __name__ == "__main__":
demo.queue() # Enable queuing for smoother experience
demo.launch(debug=True) # Enable debug mode for better error reporting