|
import gradio as gr |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
import torch |
|
import time |
|
import spaces |
|
|
|
|
|
# Models already loaded by load_model_gpu, keyed by Hub repository id.  At most
# one entry is kept, so switching models never holds two sets of weights.
# NOTE(review): whether this cache survives between calls depends on how
# `spaces.GPU` executes the decorated function (e.g. in a separate worker
# process on ZeroGPU) - confirm on the target hardware; if it does not persist,
# load models at module level instead.
_MODEL_CACHE = {}


@spaces.GPU
def load_model_gpu(model_id):
    """Load (or reuse) a causal-LM checkpoint and its tokenizer on ZeroGPU.

    Weights are loaded in float16 with ``device_map="auto"`` and
    ``trust_remote_code=True``.  The most recently loaded model is cached in
    ``_MODEL_CACHE`` so consecutive chat turns with the same model do not
    reload it from scratch.

    Args:
        model_id: Hugging Face Hub repository id, e.g. ``"Spestly/Athena-3-3B"``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    cached = _MODEL_CACHE.get(model_id)
    if cached is not None:
        return cached

    if _MODEL_CACHE:
        # A different model is cached: drop it before loading the new one so
        # peak memory stays at one model.  gc.collect() breaks reference
        # cycles so empty_cache() can actually return the freed blocks.
        import gc

        _MODEL_CACHE.clear()
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print(f"π Loading {model_id} on ZeroGPU...")
    start_time = time.perf_counter()

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )

    load_time = time.perf_counter() - start_time
    print(f"✅ Model loaded in {load_time:.2f}s")

    _MODEL_CACHE[model_id] = (model, tokenizer)
    return model, tokenizer
|
|
|
@spaces.GPU
def generate_response(model, tokenizer, prompt, max_length=512, temperature=0.7, top_p=0.9):
    """Generate a continuation of ``prompt`` on ZeroGPU and time it.

    Args:
        model: Causal LM to generate with; inputs are moved to the device of
            its first parameter.
        tokenizer: Tokenizer matching ``model``.
        prompt: Full text prompt to continue.
        max_length: Maximum number of *new* tokens (passed as
            ``max_new_tokens``).  Coerced to ``int`` because UI sliders may
            deliver float values.
        temperature: Sampling temperature.  Values ``<= 0`` switch to greedy
            decoding, because sampling requires a strictly positive
            temperature; ``None`` is passed through unchanged.
        top_p: Nucleus-sampling threshold, used only when sampling.

    Returns:
        ``(output_text, generation_time)``: the decoded continuation (prompt
        tokens excluded, special tokens skipped, surrounding whitespace
        stripped) and the seconds spent inside ``model.generate``.
    """
    device = next(model.parameters()).device
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    prompt_len = inputs["input_ids"].shape[-1]

    gen_kwargs = {
        "max_new_tokens": int(max_length),
        # Reuse EOS as padding, as the original did, so generate() does not
        # warn about a missing pad token.
        "pad_token_id": tokenizer.eos_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    if temperature is None or temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        gen_kwargs["do_sample"] = False

    start_time = time.perf_counter()
    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)
    generation_time = time.perf_counter() - start_time

    # Decode only the newly generated tokens, not the echoed prompt.
    output_text = tokenizer.decode(
        outputs[0][prompt_len:],
        skip_special_tokens=True,
    ).strip()

    return output_text, generation_time
|
|
|
|
|
# (dropdown label, Hugging Face Hub repository id) pairs.  Order matters: the
# UI builds the model dropdown from list(MODELS.keys()).
_MODEL_CATALOG = (
    ("Athena-R3X 8B", "Spestly/Athena-R3X-8B"),
    ("Athena-R3X 4B", "Spestly/Athena-R3X-4B"),
    ("Athena-R3 7B", "Spestly/Athena-R3-7B"),
    ("Athena-3 3B", "Spestly/Athena-3-3B"),
    ("Athena-3 7B", "Spestly/Athena-3-7B"),
    ("Athena-3 14B", "Spestly/Athena-3-14B"),
    ("Athena-2 1.5B", "Spestly/Athena-2-1.5B"),
    ("Athena-1 3B", "Spestly/Athena-1-3B"),
    ("Athena-1 7B", "Spestly/Athena-1-7B"),
)

# Label -> repository id lookup used by the chat handler and the dropdown.
MODELS = dict(_MODEL_CATALOG)
|
|
|
def _build_prompt(history, user_message):
    """Render completed turns plus the new user message as a plain-text transcript."""
    parts = [f"User: {user_msg}\nAthena: {assistant_msg}\n" for user_msg, assistant_msg in history]
    parts.append(f"User: {user_message}\nAthena:")
    return "".join(parts)


def _strip_continuation(text):
    """Cut a reply where the model starts writing the user's next turn itself."""
    # The transcript prompt has no stop sequence, so the model may keep going
    # with "\nUser: ..." lines; everything from there on is not Athena's reply.
    reply, _, _ = text.partition("\nUser:")
    return reply.strip()


def chatbot(conversation, user_message, model_name, max_length=512, temperature=0.7):
    """Handle one chat turn: generate Athena's reply and update the history.

    Args:
        conversation: Chat history as a list of ``[user_text, assistant_text]``
            pairs (the Gradio Chatbot value), or ``None`` for an empty chat.
        user_message: Text the user just submitted.
        model_name: Key into ``MODELS``; unknown names fall back to
            ``"Athena-R3X 8B"``.
        max_length: Maximum number of new tokens to generate.
        temperature: Sampling temperature forwarded to the generator.

    Returns:
        ``(conversation, "", status)``: the updated history, an empty string
        that clears the input box, and a stats or error line.  Exceptions from
        loading or generation are not raised; they are reported in both the
        new turn and the status line.
    """
    if not user_message or not user_message.strip():
        return conversation, "", "Please enter a message"

    if conversation is None:
        conversation = []

    model_id = MODELS.get(model_name, MODELS["Athena-R3X 8B"])

    # Build the prompt from the completed turns before adding the new one.
    prompt = _build_prompt(conversation, user_message)

    # Append the new turn up-front so that, on failure, the error is attached
    # to *this* turn rather than overwriting the previous assistant reply.
    conversation.append([user_message, ""])

    try:
        model, tokenizer = load_model_gpu(model_id)
        output_text, generation_time = generate_response(
            model, tokenizer, prompt, max_length, temperature
        )
    except Exception as e:  # UI boundary: surface any failure to the user
        conversation[-1][1] = f"Error: {e}"
        return conversation, "", f"β Error occurred: {e}"

    conversation[-1][1] = _strip_continuation(output_text)

    stats = f"β‘ Generated in {generation_time:.2f}s | Model: {model_name} | Temp: {temperature}"
    return conversation, "", stats
|
|
|
def clear_chat():
    """Produce reset values for the chat history, message box and stats box.

    Returns a brand-new empty history list plus two empty strings, in that
    order, matching the outputs wired to the clear button.
    """
    history, message, stats = [], "", ""
    return history, message, stats
|
|
|
|
|
# Extra CSS for the Blocks page: a fixed-height chat area (targeted through
# elem_id="chatbot" below) and padded, rounded message bubbles.
css = """
#chatbot {
    height: 600px;
}
.message {
    padding: 10px;
    margin: 5px;
    border-radius: 10px;
}
"""


with gr.Blocks(title="Athena Playground Chat", css=css) as demo:
    gr.Markdown("# π Athena Playground Chat")
    gr.Markdown("*Powered by HuggingFace ZeroGPU*")

    with gr.Row():
        # Left column: model selection and generation settings.
        with gr.Column(scale=1):
            model_choice = gr.Dropdown(
                label="π± Model",
                choices=list(MODELS.keys()),
                value="Athena-R3X 8B",
                info="Select which Athena model to use",
            )
            max_length = gr.Slider(
                32, 2048, value=512,
                label="π Max Tokens",
                info="Maximum number of tokens to generate",
            )
            temperature = gr.Slider(
                0.1, 2.0, value=0.7,
                label="π¨ Creativity",
                info="Higher values = more creative responses",
            )
            clear_btn = gr.Button("ποΈ Clear Chat", variant="secondary")

        # Right column: conversation, message box, send button and stats.
        with gr.Column(scale=3):
            # avatar_images expects image file paths or URLs.  The emoji
            # strings previously passed here are not image files, so the
            # default (no avatars) is used instead.
            chat_history = gr.Chatbot(
                elem_id="chatbot",
                show_label=False,
            )
            user_input = gr.Textbox(
                placeholder="Ask Athena anything...",
                label="Your message",
                lines=2,
                max_lines=10,
            )
            with gr.Row():
                submit_btn = gr.Button("π€ Send", variant="primary")
                stats_output = gr.Textbox(
                    label="Stats",
                    interactive=False,
                    show_label=False,
                    placeholder="Stats will appear here...",
                )

    # Clicking "Send" and pressing Enter in the textbox run the same handler.
    submit_btn.click(
        chatbot,
        inputs=[chat_history, user_input, model_choice, max_length, temperature],
        outputs=[chat_history, user_input, stats_output],
    )

    user_input.submit(
        chatbot,
        inputs=[chat_history, user_input, model_choice, max_length, temperature],
        outputs=[chat_history, user_input, stats_output],
    )

    clear_btn.click(
        clear_chat,
        inputs=[],
        outputs=[chat_history, user_input, stats_output],
    )


if __name__ == "__main__":
    demo.launch()