|
import gradio as gr |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
import torch |
|
import time |
|
import spaces |
|
|
|
|
|
# Dropdown display name -> Hugging Face Hub repo id (passed to from_pretrained).
# Insertion order matters: the UI builds the dropdown choices from these keys.
MODELS = {
    # Athena-R3X family
    "Athena-R3X 8B": "Spestly/Athena-R3X-8B",
    "Athena-R3X 4B": "Spestly/Athena-R3X-4B",
    # Athena-R3 family
    "Athena-R3 7B": "Spestly/Athena-R3-7B",
    # Athena-3 family
    "Athena-3 3B": "Spestly/Athena-3-3B",
    "Athena-3 7B": "Spestly/Athena-3-7B",
    "Athena-3 14B": "Spestly/Athena-3-14B",
    # Athena-2 family
    "Athena-2 1.5B": "Spestly/Athena-2-1.5B",
    # Athena-1 family
    "Athena-1 3B": "Spestly/Athena-1-3B",
    "Athena-1 7B": "Spestly/Athena-1-7B",
}
|
|
|
SYSTEM_PROMPT = (
    "You are Athena, a helpful, harmless, and honest AI assistant. "
    "You provide clear, accurate, and concise responses to user questions. "
    "You are knowledgeable across many domains and always aim to be "
    "respectful and helpful. You are finetuned by Aayan Mishra"
)


def _history_to_messages(conversation):
    """Convert Gradio chat history into chat-template message dicts.

    Two history shapes are accepted:

    * ``type="messages"`` history: a list of dicts with ``"role"`` and
      ``"content"`` keys. Only ``user``/``assistant`` turns whose content is a
      non-empty string are kept; non-string content (e.g. file/component
      messages) is skipped because it cannot be fed to the text template.
    * legacy tuple history: a list of ``(user_msg, assistant_msg)`` pairs;
      empty halves are skipped.

    ``None`` is treated as an empty history.
    """
    messages = []
    for turn in conversation or []:
        if isinstance(turn, dict):
            role = turn.get("role")
            content = turn.get("content")
            if role in ("user", "assistant") and isinstance(content, str) and content:
                messages.append({"role": role, "content": content})
            continue
        # Legacy (user, assistant) pair.
        user_msg, assistant_msg = turn
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    return messages


@spaces.GPU
def generate_response(model_id, conversation, user_message, max_length=512, temperature=0.7):
    """Generate one assistant reply using ZeroGPU - all CUDA operations happen here.

    NOTE: the tokenizer and model are loaded from ``model_id`` on every call,
    so the load cost is paid for each message.

    Args:
        model_id: Hugging Face repo id passed to ``from_pretrained``.
        conversation: Prior chat history, either ``{"role", "content"}`` dicts
            or legacy ``(user, assistant)`` pairs (see ``_history_to_messages``).
        user_message: The new user message to answer.
        max_length: Maximum number of *new* tokens to generate.
        temperature: Sampling temperature (sampling is always on, top_p=0.9).

    Returns:
        ``(response, load_time, generation_time)``: the decoded reply with the
        prompt tokens stripped off, plus the model-load and generation
        durations in seconds.
    """
    print(f"Loading {model_id}...")
    start_time = time.perf_counter()

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )
    load_time = time.perf_counter() - start_time
    print(f"✅ Model loaded in {load_time:.2f}s")

    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    messages.extend(_history_to_messages(conversation))
    messages.append({"role": "user", "content": user_message})

    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    # The rendered chat template already contains any special tokens the model
    # expects (e.g. BOS); adding them again while tokenizing would duplicate them.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)

    # With device_map="auto" the model picks its own device; move inputs there.
    device = next(model.parameters()).device
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    prompt_length = inputs["input_ids"].shape[-1]

    generation_start = time.perf_counter()
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            # UI sliders may deliver floats; generate() needs an int token budget.
            max_new_tokens=int(max_length),
            temperature=float(temperature),
            do_sample=True,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    generation_time = time.perf_counter() - generation_start

    # Decode only the newly generated tokens, not the echoed prompt.
    response = tokenizer.decode(
        outputs[0][prompt_length:],
        skip_special_tokens=True,
    ).strip()

    return response, load_time, generation_time
|
|
|
def respond(message, history, model_name, max_length, temperature):
    """Chat callback for gr.ChatInterface.

    Args:
        message: The new user message.
        history: Chat history supplied by the ChatInterface; forwarded as-is
            to ``generate_response``.
        model_name: Key into ``MODELS``; unknown names fall back to
            ``"Athena-R3X 8B"``.
        max_length: Maximum number of new tokens to generate.
        temperature: Sampling temperature.

    Returns:
        The assistant reply. A blank or missing message yields a prompt to
        enter one. Any generation failure yields an ``"Error: ..."`` string
        shown in the chat.
    """
    import traceback

    if not message or not message.strip():
        return "Please enter a message"

    model_id = MODELS.get(model_name, MODELS["Athena-R3X 8B"])

    try:
        response, _load_time, _generation_time = generate_response(
            model_id, history, message, max_length, temperature
        )
    except Exception as e:
        # UI boundary: report the failure in the chat rather than crashing the
        # request, but keep the full traceback in the server logs.
        traceback.print_exc()
        return f"Error: {str(e)}"

    return response
|
|
|
css = """ |
|
.message { |
|
padding: 10px; |
|
margin: 5px; |
|
border-radius: 10px; |
|
} |
|
""" |
|
|
|
theme = gr.themes.Monochrome()


with gr.Blocks(title="Athena Playground Chat", css=css, theme=theme) as demo:
    gr.Markdown("# Athena Playground Chat")
    gr.Markdown("*Powered by HuggingFace ZeroGPU*")

    # Generation settings live in a collapsed accordion. Components must be
    # created inside the layout context (gr.Row is a context manager, it does
    # not accept a list of children). Because they are rendered here,
    # gr.ChatInterface uses them in place as additional inputs instead of
    # rendering its own copies.
    with gr.Accordion("Configurations", open=False):
        gr.Markdown("### Change Model and Generation Settings")
        with gr.Row():
            model_choice = gr.Dropdown(
                label="Model",
                choices=list(MODELS.keys()),
                value="Athena-R3X 8B",
                info="Select which Athena model to use",
            )
            max_length = gr.Slider(
                32, 2048, value=512,
                label="Max Tokens",
                info="Maximum number of tokens to generate",
            )
            temperature = gr.Slider(
                0.1, 2.0, value=0.7,
                label="Creativity",
                info="Higher values = more creative responses",
            )

    # additional_inputs are passed to respond() after (message, history) in
    # this order: model name, max new tokens, temperature. Each example row
    # follows the same order. No per-interface theme is set: the enclosing
    # gr.Blocks already defines the page theme.
    chat_interface = gr.ChatInterface(
        fn=respond,
        additional_inputs=[model_choice, max_length, temperature],
        title="Chat with Athena",
        description="Ask Athena anything!",
        examples=[
            ["Hello! How are you?", "Athena-R3X 8B", 512, 0.7],
            ["What can you help me with?", "Athena-R3X 8B", 512, 0.7],
            ["Tell me about artificial intelligence", "Athena-R3X 8B", 512, 0.7],
            ["Write a short poem about space", "Athena-R3X 8B", 512, 0.7],
        ],
        cache_examples=False,
        chatbot=gr.Chatbot(
            height=500,
            placeholder="Start chatting with Athena...",
            show_share_button=False,
            type="messages",
        ),
        type="messages",
    )


if __name__ == "__main__":
    demo.launch()
|
|