# Source: Hugging Face Space "GPT-OSS", file app.py (5.99 kB),
# last commit e6695b6 "Update app.py" by Spestly.
import functools
import time

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Memoise the most recently loaded model so consecutive messages for the same
# model do not re-download and re-initialise it on every turn. maxsize=1 keeps
# at most one model alive, which bounds memory (the catalogue goes up to 14B).
# lru_cache does not cache raised exceptions, so a failed load is retried.
@functools.lru_cache(maxsize=1)
@spaces.GPU  # ZeroGPU decorator for GPU-intensive functions
def load_model_gpu(model_id):
    """Load a causal language model and its tokenizer from the Hugging Face Hub.

    Results are cached per ``model_id`` (a single entry): only the first
    request for a model pays the load cost, and selecting a different model
    evicts the previously cached one.

    Args:
        model_id: Hub repository id, e.g. ``"Spestly/Athena-R3X-8B"``.

    Returns:
        A ``(model, tokenizer)`` tuple. The model is loaded in float16 with
        ``device_map="auto"`` and ``trust_remote_code=True``.
    """
    print(f"🚀 Loading {model_id} on ZeroGPU...")
    start_time = time.perf_counter()
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,  # half precision to reduce weight memory
        device_map="auto",
        trust_remote_code=True,
    )
    load_time = time.perf_counter() - start_time
    print(f"✅ Model loaded in {load_time:.2f}s")
    return model, tokenizer
@spaces.GPU
def generate_response(model, tokenizer, prompt, max_length=512, temperature=0.7):
    """Sample a completion for ``prompt`` using ZeroGPU.

    Args:
        model: Causal LM, as returned by ``load_model_gpu``.
        tokenizer: Tokenizer matching ``model``.
        prompt: Full text prompt (the rendered conversation transcript).
        max_length: Maximum number of *new* tokens to generate. It is passed
            as ``max_new_tokens``, so the prompt length is not counted.
        temperature: Sampling temperature. Sampling also uses top_p=0.9.

    Returns:
        ``(text, seconds)``: the decoded continuation (special tokens removed,
        surrounding whitespace stripped) and the wall-clock time spent inside
        ``model.generate``.
    """
    # Inputs go to the device holding the model's first parameter; with
    # device_map="auto" the weights can be placed by accelerate.
    device = next(model.parameters()).device
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    prompt_len = inputs["input_ids"].shape[-1]

    start_time = time.perf_counter()
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            # UI sliders may deliver floats; the token budget must be an int.
            max_new_tokens=int(max_length),
            temperature=float(temperature),
            do_sample=True,
            top_p=0.9,
            # EOS doubles as the pad id (presumably because some of these
            # tokenizers define no pad token; single sequence, so harmless).
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    generation_time = time.perf_counter() - start_time

    # generate() returns prompt + continuation; decode only the new tokens.
    output_text = tokenizer.decode(
        outputs[0][prompt_len:],
        skip_special_tokens=True,
    ).strip()
    return output_text, generation_time
# Model configurations: dropdown label -> Hugging Face Hub repository id.
# Every checkpoint lives under the "Spestly" namespace, so only the repo slug
# is listed per entry. Dict insertion order is the order of the pairs below.
MODELS = {
    label: f"Spestly/{slug}"
    for label, slug in (
        ("Athena-R3X 8B", "Athena-R3X-8B"),
        ("Athena-R3X 4B", "Athena-R3X-4B"),
        ("Athena-R3 7B", "Athena-R3-7B"),
        ("Athena-3 3B", "Athena-3-3B"),
        ("Athena-3 7B", "Athena-3-7B"),
        ("Athena-3 14B", "Athena-3-14B"),
        ("Athena-2 1.5B", "Athena-2-1.5B"),
        ("Athena-1 3B", "Athena-1-3B"),
        ("Athena-1 7B", "Athena-1-7B"),
    )
}
def _build_prompt(history, user_message):
    """Render earlier [user, assistant] turns plus the new message as a transcript."""
    parts = [
        # A None slot would otherwise be rendered as the literal text "None".
        f"User: {user_msg or ''}\nAthena: {assistant_msg or ''}\n"
        for user_msg, assistant_msg in history
    ]
    parts.append(f"User: {user_message}\nAthena:")
    return "".join(parts)


def _trim_reply(text):
    """Cut a reply at the first follow-up "User:" turn the model may have invented."""
    # The prompt is a plain "User:/Athena:" transcript and generation only stops
    # at EOS, so the model can keep writing the dialogue; keep only its own turn.
    reply, _, _ = text.partition("\nUser:")
    return reply.strip()


def chatbot(conversation, user_message, model_name, max_length=512, temperature=0.7):
    """Handle one chat turn from the UI.

    Args:
        conversation: Chat history as a list of ``[user, assistant]`` pairs,
            or ``None`` for an empty chat. The list is appended to in place.
        user_message: Text entered by the user.
        model_name: Key into ``MODELS``; unknown names fall back to
            ``"Athena-R3X 8B"``.
        max_length: Maximum number of new tokens to generate.
        temperature: Sampling temperature.

    Returns:
        ``(conversation, "", status)``: the updated history, an empty string
        that clears the input box, and a status line. Errors from loading or
        generation are reported in the chat and in the status line rather
        than raised.
    """
    if not user_message or not user_message.strip():
        return conversation, "", "Please enter a message"
    if conversation is None:
        conversation = []

    # Get model ID
    model_id = MODELS.get(model_name, MODELS["Athena-R3X 8B"])

    # Build the prompt from the prior turns, then record this turn *before*
    # anything can fail, so an error is attached to this message instead of
    # overwriting the previous assistant reply.
    prompt = _build_prompt(conversation, user_message)
    conversation.append([user_message, ""])

    try:
        model, tokenizer = load_model_gpu(model_id)
        output_text, generation_time = generate_response(
            model, tokenizer, prompt, max_length, temperature
        )
    except Exception as e:  # UI boundary: report instead of crashing the handler
        conversation[-1][1] = f"Error: {e}"
        return conversation, "", f"❌ Error occurred: {e}"

    conversation[-1][1] = _trim_reply(output_text)
    stats = f"⚡ Generated in {generation_time:.2f}s | Model: {model_name} | Temp: {temperature}"
    return conversation, "", stats
def clear_chat():
    """Reset the UI: empty chat history, empty input box, empty stats line."""
    fresh_history = []
    blank = ""
    return fresh_history, blank, blank
# Custom stylesheet passed to gr.Blocks(css=...) below.
# "#chatbot" targets the component created with elem_id="chatbot" and fixes
# its height at 600px; ".message" adds padding, margin and rounded corners.
# NOTE(review): whether the rendered chat bubbles carry a "message" class
# depends on the installed Gradio version — confirm the rule still matches.
css = """
#chatbot {
height: 600px;
}
.message {
padding: 10px;
margin: 5px;
border-radius: 10px;
}
"""
# Create Gradio interface: a settings column (model, token budget,
# temperature, clear button) beside the chat column.
with gr.Blocks(title="Athena Playground Chat", css=css) as demo:
    gr.Markdown("# 🚀 Athena Playground Chat")
    gr.Markdown("*Powered by HuggingFace ZeroGPU*")
    with gr.Row():
        with gr.Column(scale=1):
            model_choice = gr.Dropdown(
                label="📱 Model",
                choices=list(MODELS.keys()),
                value="Athena-R3X 8B",
                info="Select which Athena model to use",
            )
            max_length = gr.Slider(
                32, 2048, value=512,
                label="📏 Max Tokens",
                info="Maximum number of tokens to generate",
            )
            temperature = gr.Slider(
                0.1, 2.0, value=0.7,
                label="🎨 Creativity",
                info="Higher values = more creative responses",
            )
            clear_btn = gr.Button("🗑️ Clear Chat", variant="secondary")
        with gr.Column(scale=3):
            # avatar_images expects image file paths/URLs, not emoji strings,
            # so it is not passed here.
            chat_history = gr.Chatbot(
                elem_id="chatbot",
                show_label=False,
            )
            user_input = gr.Textbox(
                placeholder="Ask Athena anything...",
                label="Your message",
                lines=2,
                max_lines=10,
            )
            with gr.Row():
                submit_btn = gr.Button("📤 Send", variant="primary")
            stats_output = gr.Textbox(
                label="Stats",
                interactive=False,
                show_label=False,
                placeholder="Stats will appear here...",
            )

    # Event handlers: the Send button and pressing Enter in the textbox
    # trigger the same chat turn with the same inputs and outputs.
    chat_inputs = [chat_history, user_input, model_choice, max_length, temperature]
    chat_outputs = [chat_history, user_input, stats_output]
    submit_btn.click(chatbot, inputs=chat_inputs, outputs=chat_outputs)
    user_input.submit(chatbot, inputs=chat_inputs, outputs=chat_outputs)
    clear_btn.click(
        clear_chat,
        inputs=[],
        outputs=[chat_history, user_input, stats_output],
    )

if __name__ == "__main__":
    demo.launch()