Spaces: Running on Zero
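The example below is a Gradio chat app for the Athena models that runs on ZeroGPU: the model is loaded and generation happens inside a function decorated with @spaces.GPU, so a GPU is attached only while that function executes, and the app reports per-turn load and generation times.

The core pattern is small. A minimal sketch, assuming a ZeroGPU Space (the function name and duration value here are illustrative, not part of the app):

import spaces
import torch

@spaces.GPU(duration=60)  # GPU is attached only while this decorated function runs; duration is illustrative
def gpu_demo():
    # All CUDA work must happen inside the decorated function
    return str(torch.zeros(1).to("cuda").device)

The full demo app: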
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import time
import spaces
# Model configurations: dropdown label -> Hugging Face Hub repo ID
MODELS = {
    "Athena-R3X 8B": "Spestly/Athena-R3X-8B",
    "Athena-R3X 4B": "Spestly/Athena-R3X-4B",
    "Athena-R3 7B": "Spestly/Athena-R3-7B",
    "Athena-3 3B": "Spestly/Athena-3-3B",
    "Athena-3 7B": "Spestly/Athena-3-7B",
    "Athena-3 14B": "Spestly/Athena-3-14B",
    "Athena-2 1.5B": "Spestly/Athena-2-1.5B",
    "Athena-1 3B": "Spestly/Athena-1-3B",
    "Athena-1 7B": "Spestly/Athena-1-7B"
}
# ZeroGPU: the @spaces.GPU decorator requests a GPU only for the duration of this call
@spaces.GPU
def generate_response(model_id, conversation, user_message, max_length=512, temperature=0.7):
    """Generate a response using ZeroGPU - all CUDA operations happen here"""
    # Load model and tokenizer inside the GPU function
    print(f"Loading {model_id}...")
    start_time = time.time()
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True
    )
    load_time = time.time() - start_time
    print(f"Model loaded in {load_time:.2f}s")
    # Build conversation history
    conversation_history = []
    for user_msg, assistant_msg in conversation:
        if user_msg:
            conversation_history.append(f"User: {user_msg}")
        if assistant_msg:
            conversation_history.append(f"Athena: {assistant_msg}")
    # Add current user message
    conversation_history.append(f"User: {user_message}")
    conversation_history.append("Athena:")
    # Create prompt
    prompt = "\n".join(conversation_history)
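    # Alternative (sketch): checkpoints that ship a chat template generally expect
    # tokenizer.apply_chat_template rather than a hand-built "User:/Athena:" prompt, e.g.:
    #   messages = [{"role": "user", "content": user_message}]
    #   prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)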
    # Tokenize and move to GPU
    inputs = tokenizer(prompt, return_tensors="pt")
    # Move inputs to the same device as the model
    device = next(model.parameters()).device
    inputs = {k: v.to(device) for k, v in inputs.items()}
    generation_start = time.time()
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_length,
            temperature=temperature,
            do_sample=True,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )
    generation_time = time.time() - generation_start
    # Decode only the newly generated tokens (everything after the prompt)
    response = tokenizer.decode(
        outputs[0][inputs['input_ids'].shape[-1]:],
        skip_special_tokens=True
    ).strip()
    return response, load_time, generation_time
def chatbot(conversation, user_message, model_name, max_length=512, temperature=0.7):
    """Handle one chat turn; `conversation` is a list of [user, assistant] pairs (Gradio Chatbot format)"""
    if not user_message.strip():
        return conversation, "", "Please enter a message"
    if conversation is None:
        conversation = []
    # Get model ID
    model_id = MODELS.get(model_name, MODELS["Athena-R3X 8B"])
    try:
        # Add user message to conversation
        conversation.append([user_message, ""])
        # Generate response using ZeroGPU
        response, load_time, generation_time = generate_response(
            model_id, conversation[:-1], user_message, max_length, temperature
        )
        # Update the conversation with the response
        conversation[-1][1] = response
        stats = f"Load: {load_time:.1f}s | Gen: {generation_time:.1f}s | Model: {model_name}"
        return conversation, "", stats
    except Exception as e:
        error_msg = f"Error: {str(e)}"
        if conversation:
            conversation[-1][1] = error_msg
        else:
            conversation = [[user_message, error_msg]]
        return conversation, "", f"Error: {str(e)}"

def clear_chat():
    return [], "", ""
# CSS for better styling
css = """
#chatbot {
    height: 600px;
}
.message {
    padding: 10px;
    margin: 5px;
    border-radius: 10px;
}
"""
# Create Gradio interface
with gr.Blocks(title="Athena Playground Chat", css=css) as demo:
    gr.Markdown("# Athena Playground Chat")
    gr.Markdown("*Powered by HuggingFace ZeroGPU*")
    with gr.Row():
        with gr.Column(scale=1):
            model_choice = gr.Dropdown(
                label="Model",
                choices=list(MODELS.keys()),
                value="Athena-R3X 8B",
                info="Select which Athena model to use"
            )
            max_length = gr.Slider(
                32, 2048, value=512,
                label="Max Tokens",
                info="Maximum number of tokens to generate"
            )
            temperature = gr.Slider(
                0.1, 2.0, value=0.7,
                label="Creativity",
                info="Higher values = more creative responses"
            )
            clear_btn = gr.Button("Clear Chat", variant="secondary")
        with gr.Column(scale=3):
            chat_history = gr.Chatbot(
                elem_id="chatbot",
                show_label=False
            )
            user_input = gr.Textbox(
                placeholder="Ask Athena anything...",
                label="Your message",
                lines=2,
                max_lines=10
            )
            with gr.Row():
                submit_btn = gr.Button("Send", variant="primary")
                stats_output = gr.Textbox(
                    label="Stats",
                    interactive=False,
                    show_label=False,
                    placeholder="Stats will appear here..."
                )
    # Event handlers
    submit_btn.click(
        chatbot,
        inputs=[chat_history, user_input, model_choice, max_length, temperature],
        outputs=[chat_history, user_input, stats_output]
    )
    user_input.submit(
        chatbot,
        inputs=[chat_history, user_input, model_choice, max_length, temperature],
        outputs=[chat_history, user_input, stats_output]
    )
    clear_btn.click(
        clear_chat,
        inputs=[],
        outputs=[chat_history, user_input, stats_output]
    )

if __name__ == "__main__":
    demo.launch()
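To run this as a Space, select ZeroGPU hardware in the Space settings and declare the dependencies in requirements.txt. A plausible minimal set is sketched below (versions omitted; accelerate is assumed because of device_map="auto", and some of these packages may already be preinstalled on Gradio/ZeroGPU Spaces):

gradio
transformers
torch
accelerate
spaces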