import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import time
import spaces
import re

# Model configurations
MODELS = {
    "Athena-R3X 8B": "Spestly/Athena-R3X-8B",
    "Athena-R3X 4B": "Spestly/Athena-R3X-4B",
    "Athena-R3 7B": "Spestly/Athena-R3-7B",
    "Athena-3 3B": "Spestly/Athena-3-3B",
    "Athena-3 7B": "Spestly/Athena-3-7B",
    "Athena-3 14B": "Spestly/Athena-3-14B",
    "Athena-2 1.5B": "Spestly/Athena-2-1.5B",
    "Athena-1 3B": "Spestly/Athena-1-3B",
    "Athena-1 7B": "Spestly/Athena-1-7B"
}

# Models that need the enable_thinking parameter
THINKING_ENABLED_MODELS = ["Spestly/Athena-R3X-4B"]


@spaces.GPU
def generate_response(model_id, conversation, user_message, max_length=512, temperature=0.7):
    """Generate response using ZeroGPU - all CUDA operations happen here"""
    print(f"🚀 Loading {model_id}...")
    start_time = time.time()

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True
    )
    load_time = time.time() - start_time
    print(f"✅ Model loaded in {load_time:.2f}s")

    # Build messages in proper chat format (OpenAI-style messages)
    messages = []
    system_prompt = (
        "You are Athena, a helpful, harmless, and honest AI assistant. "
        "You provide clear, accurate, and concise responses to user questions. "
        "You are knowledgeable across many domains and always aim to be respectful and helpful. "
        "You are finetuned by Aayan Mishra"
    )
    messages.append({"role": "system", "content": system_prompt})

    # Add conversation history
    for msg in conversation:
        messages.append(msg)

    # Add current user message
    messages.append({"role": "user", "content": user_message})

    # Check if this model needs the enable_thinking parameter
    if model_id in THINKING_ENABLED_MODELS:
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=True
        )
    else:
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

    inputs = tokenizer(prompt, return_tensors="pt")
    device = next(model.parameters()).device
    inputs = {k: v.to(device) for k, v in inputs.items()}

    generation_start = time.time()
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_length,
            temperature=temperature,
            do_sample=True,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )
    generation_time = time.time() - generation_start

    # Decode only the newly generated tokens (everything past the prompt)
    response = tokenizer.decode(
        outputs[0][inputs['input_ids'].shape[-1]:],
        skip_special_tokens=True
    ).strip()
    print(f"Generation time: {generation_time:.2f}s")
    return response, load_time, generation_time
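
# --- Illustrative sketch only (not wired into the UI) ---
# A minimal way to exercise generate_response outside Gradio, e.g. on a
# ZeroGPU Space. The model ID comes from the MODELS table above; the prompt
# and token budget are arbitrary assumptions for demonstration. Calling this
# downloads the model, so the app never invokes it.
def _smoke_test_generate(model_id="Spestly/Athena-R3X-4B"):
    reply, load_s, gen_s = generate_response(
        model_id,
        conversation=[],  # no prior chat history
        user_message="Say hello in one sentence.",
        max_length=64,
        temperature=0.7,
    )
    print(f"loaded in {load_s:.1f}s, generated in {gen_s:.1f}s:\n{reply}")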

def format_response_with_thinking(response):
    """Format response to handle <think> tags"""
    # Check if response contains thinking tags
    if '<think>' in response and '</think>' in response:
        # Split the response into parts
        pattern = r'(.*?)(<think>(.*?)</think>)(.*)'
        match = re.search(pattern, response, re.DOTALL)
        if match:
            before_thinking = match.group(1).strip()
            thinking_content = match.group(3).strip()
            after_thinking = match.group(4).strip()

            # Create HTML with collapsible thinking section
            html = f"{before_thinking}\n"
            html += f'<div class="thinking-container">'
            html += f'<button class="thinking-toggle"><div class="thinking-icon"></div>Thinking completed <span class="dropdown-arrow">▼</span></button>'
            html += f'<div class="thinking-content hidden">{thinking_content}</div>'
            html += f'</div>\n'
            html += after_thinking
            return html
    # If no thinking tags, return the original response
    return response
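
# --- Illustrative example of the transformation above (assumed input) ---
# A raw R3X-style response such as:
#   "<think>The user wants a greeting; keep it short.</think>Hello there!"
# becomes a collapsible "Thinking completed" block (built from the
# .thinking-container / .thinking-content classes styled in `css` below)
# followed by the visible answer "Hello there!", with the reasoning hidden
# until the button is clicked.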

def chat_submit(message, history, conversation_state, model_name, max_length, temperature):
    """Process a new message and update the chat history"""
    # For debugging - print when the function is called
    print(f"chat_submit function called with message: '{message}'")
    if not message or not message.strip():
        print("Empty message, returning without processing")
        return "", history, conversation_state

    model_id = MODELS.get(model_name, MODELS["Athena-R3X 4B"])
    try:
        response, load_time, generation_time = generate_response(
            model_id, conversation_state, message, max_length, temperature
        )
        # Update the conversation state with the raw response
        conversation_state.append({"role": "user", "content": message})
        conversation_state.append({"role": "assistant", "content": response})

        # Format the response for display
        formatted_response = format_response_with_thinking(response)

        # Update the visible chat history
        history.append((message, formatted_response))
        print(f"Response added to history. Current length: {len(history)}")
        return "", history, conversation_state
    except Exception as e:
        import traceback
        print(f"Error in chat_submit: {str(e)}")
        print(traceback.format_exc())
        error_message = f"Error: {str(e)}"
        history.append((message, error_message))
        return "", history, conversation_state


css = """
.message { padding: 10px; margin: 5px; border-radius: 10px; }
.thinking-container { margin: 10px 0; }
.thinking-toggle {
    background-color: rgba(30, 30, 40, 0.8);
    border: none;
    border-radius: 25px;
    padding: 8px 15px;
    cursor: pointer;
    font-size: 0.95em;
    margin-bottom: 8px;
    color: white;
    display: flex;
    align-items: center;
    gap: 8px;
    box-shadow: 0 2px 5px rgba(0,0,0,0.2);
    transition: background-color 0.2s;
    width: auto;
    max-width: 280px;
}
.thinking-toggle:hover { background-color: rgba(40, 40, 50, 0.9); }
.thinking-icon {
    width: 16px;
    height: 16px;
    border-radius: 50%;
    background-color: #6366f1;
    position: relative;
    overflow: hidden;
}
.thinking-icon::after {
    content: "";
    position: absolute;
    top: 50%;
    left: 50%;
    width: 60%;
    height: 60%;
    background-color: #a5b4fc;
    transform: translate(-50%, -50%);
    border-radius: 50%;
}
.dropdown-arrow {
    font-size: 0.7em;
    margin-left: auto;
    transition: transform 0.3s;
}
.thinking-content {
    background-color: rgba(30, 30, 40, 0.8);
    border-left: 2px solid #6366f1;
    padding: 15px;
    margin-top: 5px;
    margin-bottom: 15px;
    font-size: 0.95em;
    color: #e2e8f0;
    font-family: monospace;
    white-space: pre-wrap;
    overflow-x: auto;
    border-radius: 5px;
    line-height: 1.5;
}
.hidden { display: none; }
"""

# Add JavaScript to make the thinking buttons work
js = """
function setupThinkingToggle() {
    document.querySelectorAll('.thinking-toggle').forEach(button => {
        if (!button.hasEventListener) {
            button.addEventListener('click', function() {
                const content = this.nextElementSibling;
                content.classList.toggle('hidden');
                const arrow = this.querySelector('.dropdown-arrow');
                if (content.classList.contains('hidden')) {
                    arrow.textContent = '▼';
                    arrow.style.transform = '';
                } else {
                    arrow.textContent = '▲';
                    arrow.style.transform = 'rotate(0deg)';
                }
            });
            button.hasEventListener = true;
        }
    });
}

// Setup a mutation observer to watch for changes in the DOM
const observer = new MutationObserver(function(mutations) {
    setupThinkingToggle();
});

// Start observing after DOM is loaded
document.addEventListener('DOMContentLoaded', () => {
    setupThinkingToggle();
    setTimeout(() => {
        const chatbot = document.querySelector('.chatbot');
        if (chatbot) {
            observer.observe(chatbot, { childList: true, subtree: true, characterData: true });
        } else {
            observer.observe(document.body, { childList: true, subtree: true });
        }
    }, 1000);
});
"""
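
# Note: the `js` string is handed to gr.Blocks below, which runs it once on
# page load; the MutationObserver then re-binds the toggle handler whenever
# the chatbot re-renders its messages.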

# Create Gradio interface
with gr.Blocks(title="Athena Playground Chat", css=css, js=js) as demo:
    gr.Markdown("# 🚀 Athena Playground Chat")
    gr.Markdown("*Powered by HuggingFace ZeroGPU*")

    # State to keep track of the conversation for the model
    conversation_state = gr.State([])

    # Chatbot component
    chatbot = gr.Chatbot(
        height=500,
        label="Athena",
        render_markdown=True,
        elem_classes=["chatbot"]
    )

    # Input and send button row
    with gr.Row():
        user_input = gr.Textbox(
            label="Your message",
            scale=8,
            autofocus=True,
            placeholder="Type your message here...",
            lines=2
        )
        send_btn = gr.Button(
            value="Send",
            scale=1,
            variant="primary"
        )

    # Clear button
    clear_btn = gr.Button("Clear Conversation")

    # Configuration controls
    gr.Markdown("### ⚙️ Model & Generation Settings")
    with gr.Row():
        model_choice = gr.Dropdown(
            label="📱 Model",
            choices=list(MODELS.keys()),
            value="Athena-R3X 4B",
            info="Select which Athena model to use"
        )
        max_length = gr.Slider(
            32, 8192, value=512,
            label="📝 Max Tokens",
            info="Maximum number of tokens to generate"
        )
        temperature = gr.Slider(
            0.1, 2.0, value=0.7,
            label="🎨 Creativity",
            info="Higher values = more creative responses"
        )

    # Function to clear the conversation
    def clear_conversation():
        return [], []

    # Connect the interface components with explicit handlers
    submit_click = user_input.submit(
        fn=chat_submit,
        inputs=[user_input, chatbot, conversation_state, model_choice, max_length, temperature],
        outputs=[user_input, chatbot, conversation_state]
    )

    # Connect send button explicitly
    send_click = send_btn.click(
        fn=chat_submit,
        inputs=[user_input, chatbot, conversation_state, model_choice, max_length, temperature],
        outputs=[user_input, chatbot, conversation_state]
    )

    # Clear conversation
    clear_btn.click(
        fn=clear_conversation,
        outputs=[chatbot, conversation_state]
    )

    # Examples
    gr.Examples(
        examples=[
            "What is artificial intelligence?",
            "Can you explain quantum computing?",
            "Write a short poem about technology",
            "What are some ethical concerns about AI?"
        ],
        inputs=user_input
    )

    gr.Markdown("""
    ### About the Thinking Tags
    Some Athena models (particularly the R3X series) include reasoning in `<think>` tags.
    Click on "Thinking completed" to view the model's thought process behind its answers.
    """)

if __name__ == "__main__":
    # Enable queue and debugging
    demo.queue()
    demo.launch(debug=True)
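
# For local development one might instead expose the app on the LAN, e.g.
# (an optional variant, not needed when running on Spaces):
#   demo.launch(debug=True, server_name="0.0.0.0")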