# GPT-OSS / app.py
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import time
import spaces
# Model configurations
MODELS = {
"Athena-R3X 8B": "Spestly/Athena-R3X-8B",
"Athena-R3X 4B": "Spestly/Athena-R3X-4B",
"Athena-R3 7B": "Spestly/Athena-R3-7B",
"Athena-3 3B": "Spestly/Athena-3-3B",
"Athena-3 7B": "Spestly/Athena-3-7B",
"Athena-3 14B": "Spestly/Athena-3-14B",
"Athena-2 1.5B": "Spestly/Athena-2-1.5B",
"Athena-1 3B": "Spestly/Athena-1-3B",
"Athena-1 7B": "Spestly/Athena-1-7B"
}
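# ZeroGPU: the @spaces.GPU decorator requests a GPU only for the duration of the
# decorated call, so the model must be loaded and run entirely inside this function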
@spaces.GPU
def generate_response(model_id, conversation, user_message, max_length=512, temperature=0.7):
"""Generate response using ZeroGPU - all CUDA operations happen here"""
print(f"πŸš€ Loading {model_id}...")
start_time = time.time()
tokenizer = AutoTokenizer.from_pretrained(model_id)
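    # Some tokenizers ship without a pad token; fall back to the EOS token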
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True
    )
    load_time = time.time() - start_time
    print(f"✅ Model loaded in {load_time:.2f}s")
    # Build messages in proper chat format (OpenAI-style messages)
    messages = []
    system_prompt = (
        "You are Athena, a helpful, harmless, and honest AI assistant. "
        "You provide clear, accurate, and concise responses to user questions. "
        "You are knowledgeable across many domains and always aim to be respectful and helpful. "
        "You are finetuned by Aayan Mishra"
    )
    messages.append({"role": "system", "content": system_prompt})
    # Add conversation history (OpenAI-style)
    for msg in conversation:
        if msg["role"] in ("user", "assistant"):
            messages.append({"role": msg["role"], "content": msg["content"]})
    # Add current user message
    messages.append({"role": "user", "content": user_message})
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
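    # Tokenize the prompt and move the input tensors to the model's device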
    inputs = tokenizer(prompt, return_tensors="pt")
    device = next(model.parameters()).device
    inputs = {k: v.to(device) for k, v in inputs.items()}
    generation_start = time.time()
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_length,
            temperature=temperature,
            do_sample=True,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )
    generation_time = time.time() - generation_start
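    # Decode only the newly generated tokens (everything after the prompt)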
    response = tokenizer.decode(
        outputs[0][inputs['input_ids'].shape[-1]:],
        skip_special_tokens=True
    ).strip()
    return response, load_time, generation_time
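# With type="messages", gr.ChatInterface passes `history` as a list of
# {"role": ..., "content": ...} dicts, which generate_response forwards as-is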
def respond(message, history, model_name, max_length, temperature):
"""Main function for ChatInterface - simplified signature"""
if not message.strip():
return "Please enter a message"
model_id = MODELS.get(model_name, MODELS["Athena-R3X 8B"])
try:
response, load_time, generation_time = generate_response(
model_id, history, message, max_length, temperature
)
return response
except Exception as e:
return f"Error: {str(e)}"
css = """
.message {
    padding: 10px;
    margin: 5px;
    border-radius: 10px;
}
"""
theme = gr.themes.Monochrome()
with gr.Blocks(title="Athena Playground Chat", css=css, theme=theme) as demo:
gr.Markdown("# πŸš€ Athena Playground Chat")
gr.Markdown("*Powered by HuggingFace ZeroGPU*")
# --- Create config controls first ---
model_choice = gr.Dropdown(
label="πŸ“± Model",
choices=list(MODELS.keys()),
value="Athena-R3X 4B",
info="Select which Athena model to use"
)
    max_length = gr.Slider(
        32, 2048, value=512,
        label="📏 Max Tokens",
        info="Maximum number of tokens to generate"
    )
    temperature = gr.Slider(
        0.1, 2.0, value=0.7,
        label="🎨 Creativity",
        info="Higher values = more creative responses"
    )
    # --- Main chat interface ---
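    # Values of additional_inputs are passed to `respond` after (message, history),
    # in the order listed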
    chat_interface = gr.ChatInterface(
        fn=respond,
        additional_inputs=[model_choice, max_length, temperature],
        title="Chat with Athena",
        description="Ask Athena anything!",
        theme="soft",
        examples=[
            ["Hello! How are you?", "Athena-R3X 8B", 512, 0.7],
            ["What can you help me with?", "Athena-R3X 8B", 512, 0.7],
            ["Tell me about artificial intelligence", "Athena-R3X 8B", 512, 0.7],
            ["Write a short poem about space", "Athena-R3X 8B", 512, 0.7]
        ],
        cache_examples=False,
        chatbot=gr.Chatbot(
            height=500,
            placeholder="Start chatting with Athena...",
            show_share_button=False,
            type="messages"
        ),
        type="messages"
    )
    # --- Configuration controls at the bottom ---
    gr.Markdown("### ⚙️ Model & Generation Settings")
    with gr.Row():
        model_choice.render()
        max_length.render()
        temperature.render()
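# Optional: demo.queue() can be chained before launch() to queue concurrent requests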
if __name__ == "__main__":
    demo.launch()