import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import time
import spaces
# Model configurations
MODELS = {
"Athena-R3X 8B": "Spestly/Athena-R3X-8B",
"Athena-R3X 4B": "Spestly/Athena-R3X-4B",
"Athena-R3 7B": "Spestly/Athena-R3-7B",
"Athena-3 3B": "Spestly/Athena-3-3B",
"Athena-3 7B": "Spestly/Athena-3-7B",
"Athena-3 14B": "Spestly/Athena-3-14B",
"Athena-2 1.5B": "Spestly/Athena-2-1.5B",
"Athena-1 3B": "Spestly/Athena-1-3B",
"Athena-1 7B": "Spestly/Athena-1-7B"
}
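
# @spaces.GPU requests a ZeroGPU device only for the duration of the call, so the
# selected model is loaded on demand inside the decorated function rather than at startup.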
@spaces.GPU
def generate_response(model_id, conversation, user_message, max_length=512, temperature=0.7):
"""Generate response using ZeroGPU - all CUDA operations happen here"""
print(f"π Loading {model_id}...")
start_time = time.time()
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.float16,
device_map="auto",
trust_remote_code=True
)
load_time = time.time() - start_time
print(f"β
Model loaded in {load_time:.2f}s")
# Build messages in proper chat format (OpenAI-style messages)
messages = []
system_prompt = (
"You are Athena, a helpful, harmless, and honest AI assistant. "
"You provide clear, accurate, and concise responses to user questions. "
"You are knowledgeable across many domains and always aim to be respectful and helpful. "
"You are finetuned by Aayan Mishra"
)
messages.append({"role": "system", "content": system_prompt})
# Add conversation history (OpenAI-style)
for msg in conversation:
if msg["role"] in ("user", "assistant"):
messages.append({"role": msg["role"], "content": msg["content"]})
# Add current user message
messages.append({"role": "user", "content": user_message})
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
inputs = tokenizer(prompt, return_tensors="pt")
device = next(model.parameters()).device
inputs = {k: v.to(device) for k, v in inputs.items()}
generation_start = time.time()
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=max_length,
temperature=temperature,
do_sample=True,
top_p=0.9,
pad_token_id=tokenizer.eos_token_id,
eos_token_id=tokenizer.eos_token_id
)
generation_time = time.time() - generation_start
response = tokenizer.decode(
outputs[0][inputs['input_ids'].shape[-1]:],
skip_special_tokens=True
).strip()
return response, load_time, generation_time
def respond(history, message, model_name, max_length, temperature):
"""Main function for custom Chatbot interface"""
if not message.strip():
history = history + [["user", message], ["assistant", "Please enter a message"]]
return history, ""
model_id = MODELS.get(model_name, MODELS["Athena-R3X 8B"])
try:
# Format history for Athena
formatted_history = []
for i in range(0, len(history), 2):
if i < len(history):
user_msg = history[i][1] if history[i][0] == "user" else ""
assistant_msg = history[i+1][1] if i+1 < len(history) and history[i+1][0] == "assistant" else ""
if user_msg:
formatted_history.append({"role": "user", "content": user_msg})
if assistant_msg:
formatted_history.append({"role": "assistant", "content": assistant_msg})
response, load_time, generation_time = generate_response(
model_id, formatted_history, message, max_length, temperature
)
history = history + [["user", message], ["assistant", response]]
return history, ""
except Exception as e:
history = history + [["user", message], ["assistant", f"Error: {str(e)}"]]
return history, ""
css = """
.message {
    padding: 10px;
    margin: 5px;
    border-radius: 10px;
}
"""
theme = gr.themes.Monochrome()
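
# --- UI layout: chat window and input on top, model/generation settings below ---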
with gr.Blocks(title="Athena Playground Chat", css=css, theme=theme) as demo:
gr.Markdown("# π Athena Playground Chat")
gr.Markdown("*Powered by HuggingFace ZeroGPU*")
chatbot = gr.Chatbot(height=500, label="Athena", avatar="π€")
state = gr.State([]) # chat history
    with gr.Row():
        user_input = gr.Textbox(label="Your message", scale=8, autofocus=True)
        send_btn = gr.Button(value="Send", scale=1)
    # --- Configuration controls at the bottom ---
    gr.Markdown("### ⚙️ Model & Generation Settings")
    with gr.Row():
        model_choice = gr.Dropdown(
            label="📱 Model",
            choices=list(MODELS.keys()),
            value="Athena-R3X 4B",
            info="Select which Athena model to use"
        )
        max_length = gr.Slider(
            32, 2048, value=512,
            label="📏 Max Tokens",
            info="Maximum number of tokens to generate"
        )
        temperature = gr.Slider(
            0.1, 2.0, value=0.7,
            label="🎨 Creativity",
            info="Higher values = more creative responses"
        )
    def chat_submit(history, message, model_name, max_length, temperature):
        return respond(history, message, model_name, max_length, temperature)
    send_btn.click(
        chat_submit,
        inputs=[state, user_input, model_choice, max_length, temperature],
        outputs=[chatbot, state, user_input]
    )

if __name__ == "__main__":
    demo.launch()