import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import time
import spaces
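
# `spaces` provides the @spaces.GPU decorator used on Hugging Face ZeroGPU
# Spaces: the decorated function is executed on an on-demand GPU worker.
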
# Model configurations
MODELS = {
"Athena-R3X 8B": "Spestly/Athena-R3X-8B",
"Athena-R3X 4B": "Spestly/Athena-R3X-4B",
"Athena-R3 7B": "Spestly/Athena-R3-7B",
"Athena-3 3B": "Spestly/Athena-3-3B",
"Athena-3 7B": "Spestly/Athena-3-7B",
"Athena-3 14B": "Spestly/Athena-3-14B",
"Athena-2 1.5B": "Spestly/Athena-2-1.5B",
"Athena-1 3B": "Spestly/Athena-1-3B",
"Athena-1 7B": "Spestly/Athena-1-7B"
}
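
# MODELS maps UI display names to Hugging Face Hub repo IDs. Adding a model is
# a one-line change, e.g. (hypothetical entry, not a real repo):
#   MODELS["Athena-X 9B"] = "Spestly/Athena-X-9B"
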
@spaces.GPU
def generate_response(model_id, conversation, user_message, max_length=512, temperature=0.7):
    """Generate a response on ZeroGPU; all CUDA operations must run inside this function."""
# Load model and tokenizer inside the GPU function
print(f"πŸš€ Loading {model_id}...")
start_time = time.time()
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.float16,
device_map="auto",
trust_remote_code=True
)
load_time = time.time() - start_time
print(f"βœ… Model loaded in {load_time:.2f}s")
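    # Note: weights are reloaded on every call. ZeroGPU workers are not
    # guaranteed to persist between requests, so a module-level cache would
    # only help when the same worker process happens to be reused.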
# Build messages in proper chat format
messages = []
# Add system prompt first
    system_prompt = (
        "You are Athena, a helpful, harmless, and honest AI assistant. "
        "You provide clear, accurate, and concise responses to user questions. "
        "You are knowledgeable across many domains and always aim to be "
        "respectful and helpful. You were fine-tuned by Aayan Mishra."
    )
messages.append({"role": "system", "content": system_prompt})
# Add conversation history
for user_msg, assistant_msg in conversation:
if user_msg:
messages.append({"role": "user", "content": user_msg})
if assistant_msg:
messages.append({"role": "assistant", "content": assistant_msg})
# Add current user message
messages.append({"role": "user", "content": user_message})
# Apply chat template
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
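    # The rendered prompt is model-specific. For a ChatML-style template it
    # looks roughly like "<|im_start|>system\n...<|im_end|>\n<|im_start|>user\n..."
    # (the exact markup comes from the tokenizer's chat_template).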
    # Tokenize, then move inputs to the model's device
    inputs = tokenizer(prompt, return_tensors="pt")
    device = next(model.parameters()).device
    inputs = {k: v.to(device) for k, v in inputs.items()}
generation_start = time.time()
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=max_length,
temperature=temperature,
do_sample=True,
top_p=0.9,
pad_token_id=tokenizer.eos_token_id,
eos_token_id=tokenizer.eos_token_id
)
generation_time = time.time() - generation_start
# Decode response
response = tokenizer.decode(
outputs[0][inputs['input_ids'].shape[-1]:],
skip_special_tokens=True
).strip()
return response, load_time, generation_time
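
# Hypothetical smoke test for generate_response, bypassing the UI (the model
# choice and message here are illustrative only):
#
#   reply, t_load, t_gen = generate_response(
#       MODELS["Athena-2 1.5B"],   # smallest model in the registry
#       conversation=[],           # no prior turns
#       user_message="What is 2 + 2?",
#       max_length=64,
#       temperature=0.2,
#   )
#   print(f"{reply!r} (load {t_load:.1f}s, gen {t_gen:.1f}s)")
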
def chatbot(conversation, user_message, model_name, max_length=512, temperature=0.7):
    # Normalize the state first so an empty submission never returns None
    if conversation is None:
        conversation = []
    if not user_message.strip():
        return conversation, "", "Please enter a message"
    # Resolve the display name to a Hub model ID (fall back to the default)
    model_id = MODELS.get(model_name, MODELS["Athena-R3X 8B"])
try:
# Add user message to conversation
conversation.append([user_message, ""])
# Generate response using ZeroGPU
response, load_time, generation_time = generate_response(
model_id, conversation[:-1], user_message, max_length, temperature
)
# Update the conversation with the response
conversation[-1][1] = response
stats = f"⚑ Load: {load_time:.1f}s | Gen: {generation_time:.1f}s | Model: {model_name}"
return conversation, "", stats
    except Exception as e:
        # The user turn was already appended above, so attach the error to it
        error_msg = f"Error: {str(e)}"
        conversation[-1][1] = error_msg
        return conversation, "", f"❌ {error_msg}"
def clear_chat():
return [], "", ""
# CSS for better styling
css = """
#chatbot {
height: 600px;
}
.message {
padding: 10px;
margin: 5px;
border-radius: 10px;
}
"""
# Create Gradio interface
with gr.Blocks(title="Athena Playground Chat", css=css) as demo:
gr.Markdown("# πŸš€ Athena Playground Chat")
gr.Markdown("*Powered by HuggingFace ZeroGPU*")
# Main chat interface
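    # The chat state uses the legacy list-of-[user, assistant] pairs format;
    # newer Gradio releases prefer gr.Chatbot(type="messages"), but the
    # handlers above assume pairs, so the default is kept here.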
chat_history = gr.Chatbot(
elem_id="chatbot",
show_label=False,
show_share_button=False,
container=False
)
user_input = gr.Textbox(
placeholder="Ask Athena anything...",
label="Your message",
lines=2,
max_lines=10
)
with gr.Row():
submit_btn = gr.Button("πŸ“€ Send", variant="primary")
clear_btn = gr.Button("πŸ—‘οΈ Clear Chat", variant="secondary")
stats_output = gr.Textbox(
label="Stats",
interactive=False,
show_label=False,
placeholder="Stats will appear here..."
)
# Configuration settings at the bottom
gr.Markdown("---")
gr.Markdown("## βš™οΈ Configuration")
with gr.Row():
with gr.Column():
model_choice = gr.Dropdown(
label="πŸ“± Model",
choices=list(MODELS.keys()),
value="Athena-R3X 8B",
info="Select which Athena model to use"
)
with gr.Column():
            max_length = gr.Slider(
                32, 2048, value=512, step=1,
                label="📏 Max Tokens",
                info="Maximum number of tokens to generate"
            )
with gr.Column():
temperature = gr.Slider(
0.1, 2.0, value=0.7,
label="🎨 Creativity",
info="Higher values = more creative responses"
)
# Event handlers
submit_btn.click(
chatbot,
inputs=[chat_history, user_input, model_choice, max_length, temperature],
outputs=[chat_history, user_input, stats_output]
)
user_input.submit(
chatbot,
inputs=[chat_history, user_input, model_choice, max_length, temperature],
outputs=[chat_history, user_input, stats_output]
)
clear_btn.click(
clear_chat,
inputs=[],
outputs=[chat_history, user_input, stats_output]
)
if __name__ == "__main__":
demo.launch()
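
# Under concurrent traffic, enabling Gradio's request queue before launching is
# a common pattern (a sketch with default settings):
#   demo.queue().launch()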