import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import time
import spaces
# ZeroGPU decorator for GPU-intensive functions
@spaces.GPU
def load_model_gpu(model_id):
"""Load model on ZeroGPU"""
print(f"π Loading {model_id} on ZeroGPU...")
start_time = time.time()
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.float16, # Use float16 for better memory efficiency
device_map="auto",
trust_remote_code=True
)
load_time = time.time() - start_time
print(f"β
Model loaded in {load_time:.2f}s")
return model, tokenizer
@spaces.GPU
def generate_response(model, tokenizer, prompt, max_length=512, temperature=0.7):
"""Generate response using ZeroGPU"""
device = next(model.parameters()).device
inputs = tokenizer(prompt, return_tensors="pt").to(device)
start_time = time.time()
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=max_length,
temperature=temperature,
do_sample=True,
top_p=0.9,
pad_token_id=tokenizer.eos_token_id,
eos_token_id=tokenizer.eos_token_id
)
generation_time = time.time() - start_time
output_text = tokenizer.decode(
outputs[0][inputs['input_ids'].shape[-1]:],
skip_special_tokens=True
).strip()
return output_text, generation_time
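
# Note (a hedged aside, not part of the original Space): the spaces package
# also lets a GPU task request a longer slice than the default window, which
# may help if very large max_length values are allowed, e.g.:
#
#   @spaces.GPU(duration=120)  # ask for up to ~120s of GPU time per call
#   def generate_long_response(model, tokenizer, prompt):
#       ...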
# Model configurations
MODELS = {
"Athena-R3X 8B": "Spestly/Athena-R3X-8B",
"Athena-R3X 4B": "Spestly/Athena-R3X-4B",
"Athena-R3 7B": "Spestly/Athena-R3-7B",
"Athena-3 3B": "Spestly/Athena-3-3B",
"Athena-3 7B": "Spestly/Athena-3-7B",
"Athena-3 14B": "Spestly/Athena-3-14B",
"Athena-2 1.5B": "Spestly/Athena-2-1.5B",
"Athena-1 3B": "Spestly/Athena-1-3B",
"Athena-1 7B": "Spestly/Athena-1-7B"
}
def chatbot(conversation, user_message, model_name, max_length=512, temperature=0.7):
    if not user_message.strip():
        return conversation, "", "Please enter a message"

    if conversation is None:
        conversation = []

    # Get model ID
    model_id = MODELS.get(model_name, MODELS["Athena-R3X 8B"])

    try:
        # Load model and tokenizer using ZeroGPU
        model, tokenizer = load_model_gpu(model_id)

        # Append user message to conversation
        conversation.append([user_message, ""])

        # Build prompt from conversation history
        prompt = ""
        for user_msg, assistant_msg in conversation[:-1]:  # Exclude the current message
            prompt += f"User: {user_msg}\nAthena: {assistant_msg}\n"
        prompt += f"User: {user_message}\nAthena:"

        # Generate response using ZeroGPU
        output_text, generation_time = generate_response(
            model, tokenizer, prompt, max_length, temperature
        )

        # Update the last conversation entry with the response
        conversation[-1][1] = output_text
        stats = f"⚡ Generated in {generation_time:.2f}s | Model: {model_name} | Temp: {temperature}"
        return conversation, "", stats

    except Exception as e:
        error_msg = f"Error: {str(e)}"
        if conversation:
            conversation[-1][1] = error_msg
        else:
            conversation = [[user_message, error_msg]]
        return conversation, "", f"❌ Error occurred: {str(e)}"
def clear_chat():
return [], "", ""
# CSS for better styling
css = """
#chatbot {
    height: 600px;
}
.message {
    padding: 10px;
    margin: 5px;
    border-radius: 10px;
}
"""
# Create Gradio interface
with gr.Blocks(title="Athena Playground Chat", css=css) as demo:
gr.Markdown("# π Athena Playground Chat")
gr.Markdown("*Powered by HuggingFace ZeroGPU*")
with gr.Row():
with gr.Column(scale=1):
model_choice = gr.Dropdown(
label="π± Model",
choices=list(MODELS.keys()),
value="Athena-R3X 8B",
info="Select which Athena model to use"
)
max_length = gr.Slider(
32, 2048, value=512,
label="π Max Tokens",
info="Maximum number of tokens to generate"
)
temperature = gr.Slider(
0.1, 2.0, value=0.7,
label="π¨ Creativity",
info="Higher values = more creative responses"
)
clear_btn = gr.Button("ποΈ Clear Chat", variant="secondary")
with gr.Column(scale=3):
chat_history = gr.Chatbot(
elem_id="chatbot",
show_label=False,
avatar_images=["π€", "π€"]
)
user_input = gr.Textbox(
placeholder="Ask Athena anything...",
label="Your message",
lines=2,
max_lines=10
)
with gr.Row():
submit_btn = gr.Button("π€ Send", variant="primary")
stats_output = gr.Textbox(
label="Stats",
interactive=False,
show_label=False,
placeholder="Stats will appear here..."
)
    # Event handlers
    submit_btn.click(
        chatbot,
        inputs=[chat_history, user_input, model_choice, max_length, temperature],
        outputs=[chat_history, user_input, stats_output]
    )
    user_input.submit(
        chatbot,
        inputs=[chat_history, user_input, model_choice, max_length, temperature],
        outputs=[chat_history, user_input, stats_output]
    )
    clear_btn.click(
        clear_chat,
        inputs=[],
        outputs=[chat_history, user_input, stats_output]
    )
if __name__ == "__main__":
    demo.launch()
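
# The imports above imply a requirements.txt along these lines (a sketch; exact
# entries and pins are an assumption and not part of this file):
#
#   gradio
#   spaces
#   torch
#   transformers
#   accelerate  # needed by device_map="auto"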