"""Gradio chat playground for the Athena model family on HuggingFace ZeroGPU."""

import time

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Model configurations: display name shown in the dropdown -> Hub repository id.
MODELS = {
    "Athena-R3X 8B": "Spestly/Athena-R3X-8B",
    "Athena-R3X 4B": "Spestly/Athena-R3X-4B",
    "Athena-R3 7B": "Spestly/Athena-R3-7B",
    "Athena-3 3B": "Spestly/Athena-3-3B",
    "Athena-3 7B": "Spestly/Athena-3-7B",
    "Athena-3 14B": "Spestly/Athena-3-14B",
    "Athena-2 1.5B": "Spestly/Athena-2-1.5B",
    "Athena-1 3B": "Spestly/Athena-1-3B",
    "Athena-1 7B": "Spestly/Athena-1-7B",
}

# Used both as the dropdown's initial value and as the fallback for unknown names.
DEFAULT_MODEL_NAME = "Athena-R3X 8B"


def _build_prompt(conversation, user_message):
    """Render prior turns plus the new user message as a plain-text prompt.

    Args:
        conversation: Iterable of ``(user_msg, assistant_msg)`` pairs; empty or
            ``None`` entries within a pair are skipped.
        user_message: The message the model should answer next.

    Returns:
        A newline-joined transcript ending in ``"Athena:"`` so that the model
        continues with the assistant's turn.
    """
    lines = []
    for user_msg, assistant_msg in conversation:
        if user_msg:
            lines.append(f"User: {user_msg}")
        if assistant_msg:
            lines.append(f"Athena: {assistant_msg}")
    lines.append(f"User: {user_message}")
    lines.append("Athena:")
    return "\n".join(lines)


@spaces.GPU
def generate_response(model_id, conversation, user_message, max_length=512, temperature=0.7):
    """Generate response using ZeroGPU - all CUDA operations happen here.

    The tokenizer and model are loaded inside this function on every call.

    Args:
        model_id: Hub repository id passed to ``from_pretrained``.
        conversation: Previous ``(user_msg, assistant_msg)`` turns.
        user_message: The new user message to answer.
        max_length: Maximum number of new tokens to generate.
        temperature: Sampling temperature.

    Returns:
        A ``(response, load_time, generation_time)`` tuple: the decoded reply
        text and the seconds spent loading the model and generating.
    """
    # Load model and tokenizer inside the GPU function
    print(f"🚀 Loading {model_id}...")
    start_time = time.time()

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )

    load_time = time.time() - start_time
    print(f"✅ Model loaded in {load_time:.2f}s")

    prompt = _build_prompt(conversation, user_message)

    # Tokenize, then move every tensor to the device the model was placed on.
    device = next(model.parameters()).device
    inputs = {
        name: tensor.to(device)
        for name, tensor in tokenizer(prompt, return_tensors="pt").items()
    }

    generation_start = time.time()
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            # The value comes from a UI slider and may arrive as a float;
            # the token budget has to be an integer.
            max_new_tokens=int(max_length),
            temperature=temperature,
            do_sample=True,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    generation_time = time.time() - generation_start

    # Decode only the tokens produced after the prompt.
    prompt_length = inputs["input_ids"].shape[-1]
    response = tokenizer.decode(
        outputs[0][prompt_length:],
        skip_special_tokens=True,
    ).strip()

    return response, load_time, generation_time


def chatbot(conversation, user_message, model_name, max_length=512, temperature=0.7):
    """Handle one chat submission from the UI.

    Args:
        conversation: Current chat history as ``[user, assistant]`` pairs, or
            ``None`` when the chat is empty.
        user_message: Text entered by the user.
        model_name: Key into ``MODELS``; unknown names fall back to the default.
        max_length: Maximum number of new tokens to generate.
        temperature: Sampling temperature.

    Returns:
        A ``(history, textbox_value, status)`` tuple matching the three output
        components: the updated history, an empty string to clear the input
        box, and a stats or error line.
    """
    # Guard against both a missing and a whitespace-only message.
    if not user_message or not user_message.strip():
        return conversation, "", "Please enter a message"

    # Work on a copy so the caller's list is never modified in place.
    history = list(conversation) if conversation else []
    model_id = MODELS.get(model_name, MODELS[DEFAULT_MODEL_NAME])

    # Record the pending turn before calling the model so that a failure is
    # always attached to this turn rather than to an earlier one.
    turn = [user_message, ""]
    history.append(turn)

    try:
        # Generate response using ZeroGPU
        response, load_time, generation_time = generate_response(
            model_id, history[:-1], user_message, max_length, temperature
        )
    except Exception as e:
        # UI boundary: show the failure in the chat instead of crashing.
        turn[1] = f"Error: {str(e)}"
        return history, "", f"❌ Error: {str(e)}"

    turn[1] = response
    stats = f"⚡ Load: {load_time:.1f}s | Gen: {generation_time:.1f}s | Model: {model_name}"
    return history, "", stats


def clear_chat():
    """Return empty values for the chat history, input box and stats line."""
    return [], "", ""


# CSS for better styling
css = """
#chatbot {
    height: 600px;
}
.message {
    padding: 10px;
    margin: 5px;
    border-radius: 10px;
}
"""

# Create Gradio interface
with gr.Blocks(title="Athena Playground Chat", css=css) as demo:
    gr.Markdown("# 🚀 Athena Playground Chat")
    gr.Markdown("*Powered by HuggingFace ZeroGPU*")

    with gr.Row():
        with gr.Column(scale=1):
            model_choice = gr.Dropdown(
                label="📱 Model",
                choices=list(MODELS.keys()),
                value=DEFAULT_MODEL_NAME,
                info="Select which Athena model to use",
            )
            max_length = gr.Slider(
                32, 2048, value=512,
                label="📝 Max Tokens",
                info="Maximum number of tokens to generate",
            )
            temperature = gr.Slider(
                0.1, 2.0, value=0.7,
                label="🎨 Creativity",
                info="Higher values = more creative responses",
            )
            clear_btn = gr.Button("🗑️ Clear Chat", variant="secondary")

        with gr.Column(scale=3):
            # Gradio documents `avatar_images` as image file paths or URLs;
            # the emoji strings passed previously are neither, so the
            # argument is omitted and the default rendering is used.
            chat_history = gr.Chatbot(
                elem_id="chatbot",
                show_label=False,
            )
            user_input = gr.Textbox(
                placeholder="Ask Athena anything...",
                label="Your message",
                lines=2,
                max_lines=10,
            )
            with gr.Row():
                submit_btn = gr.Button("📤 Send", variant="primary")

            stats_output = gr.Textbox(
                label="Stats",
                interactive=False,
                show_label=False,
                placeholder="Stats will appear here...",
            )

    # Event handlers: the button and the Enter key trigger the same callback.
    chat_inputs = [chat_history, user_input, model_choice, max_length, temperature]
    chat_outputs = [chat_history, user_input, stats_output]

    submit_btn.click(chatbot, inputs=chat_inputs, outputs=chat_outputs)
    user_input.submit(chatbot, inputs=chat_inputs, outputs=chat_outputs)
    clear_btn.click(clear_chat, inputs=[], outputs=chat_outputs)

if __name__ == "__main__":
    demo.launch()