import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import time
import spaces
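
# NOTE: On a ZeroGPU Space, any function that performs CUDA work must be wrapped with
# the `spaces.GPU` decorator so that a GPU is attached for the duration of the call.
# A minimal sketch of the pattern (the optional `duration` argument is a per-call
# time budget in seconds):
#
#     @spaces.GPU(duration=120)
#     def my_gpu_fn(...):
#         ...
#
# `generate_response` below is decorated this way.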

# Model configurations
MODELS = {
    "Athena-R3X 8B": "Spestly/Athena-R3X-8B",
    "Athena-R3X 4B": "Spestly/Athena-R3X-4B",
    "Athena-R3 7B": "Spestly/Athena-R3-7B",
    "Athena-3 3B": "Spestly/Athena-3-3B",
    "Athena-3 7B": "Spestly/Athena-3-7B",
    "Athena-3 14B": "Spestly/Athena-3-14B",
    "Athena-2 1.5B": "Spestly/Athena-2-1.5B",
    "Athena-1 3B": "Spestly/Athena-1-3B",
    "Athena-1 7B": "Spestly/Athena-1-7B"
}
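
# The model dropdown below is populated from these display names; `chatbot()` falls
# back to "Athena-R3X 8B" when it receives a name that is not in this mapping.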

@spaces.GPU
def generate_response(model_id, conversation, user_message, max_length=512, temperature=0.7):
    """Generate response using ZeroGPU - all CUDA operations happen here"""
    # Load model and tokenizer inside the GPU function
    print(f"🚀 Loading {model_id}...")
    start_time = time.time()
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True
    )
    load_time = time.time() - start_time
    print(f"✅ Model loaded in {load_time:.2f}s")
    # Build messages in proper chat format
    messages = []
    # Add system prompt first
    system_prompt = "You are Athena, a helpful, harmless, and honest AI assistant. You provide clear, accurate, and concise responses to user questions. You are knowledgeable across many domains and always aim to be respectful and helpful. You are finetuned by Aayan Mishra"
    messages.append({"role": "system", "content": system_prompt})
    # Add conversation history
    for user_msg, assistant_msg in conversation:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    # Add current user message
    messages.append({"role": "user", "content": user_message})

    # Apply chat template
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
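    # `add_generation_prompt=True` appends the template's assistant-turn header, so the
    # model continues the transcript as the assistant instead of extending the user turn.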

    # Tokenize and move to GPU
    inputs = tokenizer(prompt, return_tensors="pt")
    # Move inputs to the same device as the model
    device = next(model.parameters()).device
    inputs = {k: v.to(device) for k, v in inputs.items()}

    generation_start = time.time()
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=int(max_length),  # cast in case the UI slider passes a float
            temperature=temperature,
            do_sample=True,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )
    generation_time = time.time() - generation_start

    # Decode response: keep only the newly generated tokens after the prompt
    response = tokenizer.decode(
        outputs[0][inputs['input_ids'].shape[-1]:],
        skip_special_tokens=True
    ).strip()
    return response, load_time, generation_time


def chatbot(conversation, user_message, model_name, max_length=512, temperature=0.7):
    if not user_message.strip():
        return conversation, "", "Please enter a message"
    if conversation is None:
        conversation = []
    # Get model ID
    model_id = MODELS.get(model_name, MODELS["Athena-R3X 8B"])
    try:
        # Add user message to conversation
        conversation.append([user_message, ""])
        # Generate response using ZeroGPU
        response, load_time, generation_time = generate_response(
            model_id, conversation[:-1], user_message, max_length, temperature
        )
        # Update the conversation with the response
        conversation[-1][1] = response
        stats = f"⚡ Load: {load_time:.1f}s | Gen: {generation_time:.1f}s | Model: {model_name}"
        return conversation, "", stats
    except Exception as e:
        error_msg = f"Error: {str(e)}"
        if conversation:
            conversation[-1][1] = error_msg
        else:
            conversation = [[user_message, error_msg]]
        return conversation, "", f"❌ Error: {str(e)}"


def clear_chat():
    return [], "", ""


# CSS for better styling
css = """
#chatbot {
    height: 600px;
}
.message {
    padding: 10px;
    margin: 5px;
    border-radius: 10px;
}
"""

# Create Gradio interface
with gr.Blocks(title="Athena Playground Chat", css=css) as demo:
    gr.Markdown("# 🚀 Athena Playground Chat")
    gr.Markdown("*Powered by HuggingFace ZeroGPU*")

    # Main chat interface
    chat_history = gr.Chatbot(
        elem_id="chatbot",
        show_label=False,
        show_share_button=False,
        container=False
    )
    user_input = gr.Textbox(
        placeholder="Ask Athena anything...",
        label="Your message",
        lines=2,
        max_lines=10
    )
    with gr.Row():
        submit_btn = gr.Button("📤 Send", variant="primary")
        clear_btn = gr.Button("🗑️ Clear Chat", variant="secondary")
    stats_output = gr.Textbox(
        label="Stats",
        interactive=False,
        show_label=False,
        placeholder="Stats will appear here..."
    )

    # Configuration settings at the bottom
    gr.Markdown("---")
    gr.Markdown("## ⚙️ Configuration")
    with gr.Row():
        with gr.Column():
            model_choice = gr.Dropdown(
                label="📱 Model",
                choices=list(MODELS.keys()),
                value="Athena-R3X 8B",
                info="Select which Athena model to use"
            )
        with gr.Column():
            max_length = gr.Slider(
                32, 2048, value=512,
                label="📏 Max Tokens",
                info="Maximum number of tokens to generate"
            )
        with gr.Column():
            temperature = gr.Slider(
                0.1, 2.0, value=0.7,
                label="🎨 Creativity",
                info="Higher values = more creative responses"
            )

    # Event handlers
    submit_btn.click(
        chatbot,
        inputs=[chat_history, user_input, model_choice, max_length, temperature],
        outputs=[chat_history, user_input, stats_output]
    )
    user_input.submit(
        chatbot,
        inputs=[chat_history, user_input, model_choice, max_length, temperature],
        outputs=[chat_history, user_input, stats_output]
    )
    clear_btn.click(
        clear_chat,
        inputs=[],
        outputs=[chat_history, user_input, stats_output]
    )

if __name__ == "__main__":
    demo.launch()
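
# Optional: a request queue can cap concurrency on a busy Space, e.g.
#     demo.queue(max_size=20).launch()
# (sketch only; `max_size` bounds how many requests wait in line).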