import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import time
import spaces
# Model configurations
MODELS = {
"Athena-R3X 8B": "Spestly/Athena-R3X-8B",
"Athena-R3X 4B": "Spestly/Athena-R3X-4B",
"Athena-R3 7B": "Spestly/Athena-R3-7B",
"Athena-3 3B": "Spestly/Athena-3-3B",
"Athena-3 7B": "Spestly/Athena-3-7B",
"Athena-3 14B": "Spestly/Athena-3-14B",
"Athena-2 1.5B": "Spestly/Athena-2-1.5B",
"Athena-1 3B": "Spestly/Athena-1-3B",
"Athena-1 7B": "Spestly/Athena-1-7B"
}
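# @spaces.GPU requests a ZeroGPU slice only for the duration of the decorated call;
# everything outside the decorated function runs on CPU.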
@spaces.GPU
def generate_response(model_id, conversation, user_message, max_length=512, temperature=0.7):
"""Generate response using ZeroGPU - all CUDA operations happen here"""
# Load model and tokenizer inside the GPU function
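    # (on ZeroGPU the GPU exists only while this decorated function runs, so the model
    # cannot be moved to CUDA at import time; note this also means the model is
    # re-loaded from cache on every request)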
print(f"π Loading {model_id}...")
start_time = time.time()
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
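    # float16 halves memory use vs. float32; device_map="auto" lets accelerate place the
    # weights on whatever GPU ZeroGPU has allocated for this call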
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.float16,
device_map="auto",
trust_remote_code=True
)
load_time = time.time() - start_time
print(f"β
Model loaded in {load_time:.2f}s")
# Build conversation history
conversation_history = []
for user_msg, assistant_msg in conversation:
if user_msg:
conversation_history.append(f"User: {user_msg}")
if assistant_msg:
conversation_history.append(f"Athena: {assistant_msg}")
# Add current user message
conversation_history.append(f"User: {user_message}")
conversation_history.append("Athena:")
# Create prompt
prompt = "\n".join(conversation_history)
# Tokenize and generate
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)  # keep inputs on the same device as the model
generation_start = time.time()
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=max_length,
temperature=temperature,
do_sample=True,
top_p=0.9,
pad_token_id=tokenizer.eos_token_id,
eos_token_id=tokenizer.eos_token_id
)
generation_time = time.time() - generation_start
# Decode response
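    # outputs[0] holds prompt + completion tokens; slicing from the prompt length keeps only the newly generated text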
response = tokenizer.decode(
outputs[0][inputs['input_ids'].shape[-1]:],
skip_special_tokens=True
).strip()
return response, load_time, generation_time
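# Hypothetical direct call (outside the Gradio UI), handy for quick local testing:
#   reply, t_load, t_gen = generate_response(MODELS["Athena-2 1.5B"], [], "Hello!", max_length=64)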
def chatbot(conversation, user_message, model_name, max_length=512, temperature=0.7):
    if conversation is None:
        conversation = []
    if not user_message.strip():
        return conversation, "", "Please enter a message"
# Get model ID
model_id = MODELS.get(model_name, MODELS["Athena-R3X 8B"])
try:
# Add user message to conversation
conversation.append([user_message, ""])
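        # gr.Chatbot history is a list of [user, assistant] pairs; the empty string is a
        # placeholder that is filled in below once generation finishes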
# Generate response using ZeroGPU
response, load_time, generation_time = generate_response(
model_id, conversation[:-1], user_message, max_length, temperature
)
# Update the conversation with the response
conversation[-1][1] = response
stats = f"β‘ Load: {load_time:.1f}s | Gen: {generation_time:.1f}s | Model: {model_name}"
return conversation, "", stats
except Exception as e:
error_msg = f"Error: {str(e)}"
if conversation:
conversation[-1][1] = error_msg
else:
conversation = [[user_message, error_msg]]
return conversation, "", f"β Error: {str(e)}"
def clear_chat():
return [], "", ""
# CSS for better styling
css = """
#chatbot {
height: 600px;
}
.message {
padding: 10px;
margin: 5px;
border-radius: 10px;
}
"""
# Create Gradio interface
with gr.Blocks(title="Athena Playground Chat", css=css) as demo:
gr.Markdown("# π Athena Playground Chat")
gr.Markdown("*Powered by HuggingFace ZeroGPU*")
with gr.Row():
with gr.Column(scale=1):
model_choice = gr.Dropdown(
label="π± Model",
choices=list(MODELS.keys()),
value="Athena-R3X 8B",
info="Select which Athena model to use"
)
max_length = gr.Slider(
32, 2048, value=512,
label="π Max Tokens",
info="Maximum number of tokens to generate"
)
temperature = gr.Slider(
0.1, 2.0, value=0.7,
label="π¨ Creativity",
info="Higher values = more creative responses"
)
            clear_btn = gr.Button("Clear Chat", variant="secondary")
with gr.Column(scale=3):
chat_history = gr.Chatbot(
elem_id="chatbot",
show_label=False,
)
user_input = gr.Textbox(
placeholder="Ask Athena anything...",
label="Your message",
lines=2,
max_lines=10
)
with gr.Row():
                submit_btn = gr.Button("Send", variant="primary")
stats_output = gr.Textbox(
label="Stats",
interactive=False,
show_label=False,
placeholder="Stats will appear here..."
)
# Event handlers
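    # Both the Send button and pressing Enter in the textbox run the same chatbot() handler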
submit_btn.click(
chatbot,
inputs=[chat_history, user_input, model_choice, max_length, temperature],
outputs=[chat_history, user_input, stats_output]
)
user_input.submit(
chatbot,
inputs=[chat_history, user_input, model_choice, max_length, temperature],
outputs=[chat_history, user_input, stats_output]
)
clear_btn.click(
clear_chat,
inputs=[],
outputs=[chat_history, user_input, stats_output]
)
if __name__ == "__main__":
    demo.launch()