"""Athena Playground: a Gradio chat app for the Spestly Athena models, running on Hugging Face ZeroGPU."""

import logging
import re
import time

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Display name -> Hugging Face model ID
MODELS = {
    "Athena-R3X 8B": "Spestly/Athena-R3X-8B",
    "Athena-R3X 4B": "Spestly/Athena-R3X-4B",
    "Athena-R3 7B": "Spestly/Athena-R3-7B",
    "Athena-3 3B": "Spestly/Athena-3-3B",
    "Athena-3 7B": "Spestly/Athena-3-7B",
    "Athena-3 14B": "Spestly/Athena-3-14B",
    "Athena-2 1.5B": "Spestly/Athena-2-1.5B",
    "Athena-1 3B": "Spestly/Athena-1-3B",
    "Athena-1 7B": "Spestly/Athena-1-7B",
}

# Models whose chat template understands the `enable_thinking` flag
THINKING_ENABLED_MODELS = ["Spestly/Athena-R3X-4B"]

# Cache of (model, tokenizer, load_time) tuples, keyed by model ID
loaded_models = {}

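# NOTE: functions decorated with @spaces.GPU are given a GPU only for the
# duration of the call. Depending on how ZeroGPU isolates those calls, the
# module-level `loaded_models` cache may or may not survive between them, but
# it always avoids a second load within one call (generate_response -> load_model).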
@spaces.GPU
def load_model(model_id):
    """Load model and tokenizer once and cache them"""
    try:
        if model_id not in loaded_models:
            logger.info(f"🚀 Loading {model_id}...")
            start_time = time.time()

            tokenizer = AutoTokenizer.from_pretrained(model_id)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True,
            )

            load_time = time.time() - start_time
            logger.info(f"✅ Model loaded in {load_time:.2f}s")
            loaded_models[model_id] = (model, tokenizer, load_time)

        return loaded_models[model_id]
    except Exception as e:
        logger.error(f"Error loading model {model_id}: {str(e)}")
        raise gr.Error(f"Failed to load model {model_id}. Please try another model.")

@spaces.GPU
def generate_response(model_id, conversation, user_message, max_length=512, temperature=0.7):
    """Generate a response using the specified model"""
    try:
        model, tokenizer, _ = load_model(model_id)

        messages = []
        system_prompt = (
            "You are Athena, a helpful, harmless, and honest AI assistant. "
            "You provide clear, accurate, and concise responses to user questions. "
            "You are knowledgeable across many domains and always aim to be respectful and helpful. "
            "You were fine-tuned by Aayan Mishra."
        )
        messages.append({"role": "system", "content": system_prompt})

        # Replay the earlier turns, then append the new user message
        for msg in conversation:
            messages.append(msg)
        messages.append({"role": "user", "content": user_message})

        # R3X models expose their reasoning when the template is rendered
        # with enable_thinking=True
        if model_id in THINKING_ENABLED_MODELS:
            prompt = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=True,
            )
        else:
            prompt = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )

        inputs = tokenizer(prompt, return_tensors="pt")
        device = next(model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        generation_start = time.time()
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_length,
                temperature=temperature,
                do_sample=True,
                top_p=0.9,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        generation_time = time.time() - generation_start
        # Decode only the newly generated tokens, not the echoed prompt
        response = tokenizer.decode(
            outputs[0][inputs['input_ids'].shape[-1]:],
            skip_special_tokens=True,
        ).strip()

        logger.info(f"Generation time: {generation_time:.2f}s")
        return response, generation_time

    except Exception as e:
        logger.error(f"Error in generate_response: {str(e)}")
        raise gr.Error(f"Error generating response: {str(e)}")

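# A hypothetical R3X response, to illustrate what the formatter below expects:
#   "<think>The user wants a summary, so I should...</think>Here is a summary: ..."
# The <think>...</think> span becomes a collapsible block; the rest is shown
# as the normal reply.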
def format_response_with_thinking(response):
    """Format response to handle <think></think> tags"""
    if '<think>' in response and '</think>' in response:
        pattern = r'(.*?)(<think>(.*?)</think>)(.*)'
        match = re.search(pattern, response, re.DOTALL)

        if match:
            before_thinking = match.group(1).strip()
            thinking_content = match.group(3).strip()
            after_thinking = match.group(4).strip()

            html = f"{before_thinking}\n"
            html += '<div class="thinking-container">'
            html += ('<button class="thinking-toggle"><div class="thinking-icon"></div>'
                     ' Thinking completed <span class="dropdown-arrow">▼</span></button>')
            html += f'<div class="thinking-content hidden">{thinking_content}</div>'
            html += '</div>\n'
            html += after_thinking

            return html

    return response

def validate_input(message):
    """Validate user input"""
    if not message or not message.strip():
        raise gr.Error("Message cannot be empty")
    if len(message) > 2000:
        raise gr.Error("Message too long (max 2000 characters)")
    return message

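# chat_submit is a generator: Gradio renders each yielded state, so the first
# yield paints a placeholder bubble and shows the progress indicator right
# away, and the second yield swaps in the finished response.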
def chat_submit(message, history, conversation_state, model_name, max_length, temperature):
    """Process a new message and update the chat history"""
    try:
        message = validate_input(message)
        model_id = MODELS.get(model_name, MODELS["Athena-R3X 4B"])

        # Show a placeholder row while the model generates
        yield "", history + [(message, "Generating response...")], conversation_state, gr.update(visible=True)

        response, generation_time = generate_response(
            model_id, conversation_state, message, max_length, temperature
        )

        conversation_state.append({"role": "user", "content": message})
        conversation_state.append({"role": "assistant", "content": response})

        # Keep only the last 20 messages (10 exchanges) to bound the prompt size
        if len(conversation_state) > 20:
            conversation_state = conversation_state[-20:]

        formatted_response = format_response_with_thinking(response)

        # `history` was never mutated (the placeholder above was appended to a
        # copy), so append to it rather than slicing off its last entry
        updated_history = history + [(message, formatted_response)]

        yield "", updated_history, conversation_state, gr.update(visible=False)

    except Exception as e:
        logger.error(f"Error in chat_submit: {str(e)}")
        error_message = f"Error: {str(e)}"
        yield error_message, history, conversation_state, gr.update(visible=False)

def clear_conversation():
    """Clear the conversation history"""
    return [], [], gr.update(visible=False)

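# Styling for the chat bubbles, the collapsible "thinking" block, and the
# progress indicator; these class names are emitted by
# format_response_with_thinking() and targeted by the custom JS below.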
css = """
.message {
    padding: 10px;
    margin: 5px;
    border-radius: 10px;
}
.thinking-container {
    margin: 10px 0;
}
.thinking-toggle {
    background-color: rgba(30, 30, 40, 0.8);
    border: none;
    border-radius: 25px;
    padding: 8px 15px;
    cursor: pointer;
    font-size: 0.95em;
    margin-bottom: 8px;
    color: white;
    display: flex;
    align-items: center;
    gap: 8px;
    box-shadow: 0 2px 5px rgba(0,0,0,0.2);
    transition: background-color 0.2s;
    width: auto;
    max-width: 280px;
}
.thinking-toggle:hover {
    background-color: rgba(40, 40, 50, 0.9);
}
.thinking-icon {
    width: 16px;
    height: 16px;
    border-radius: 50%;
    background-color: #6366f1;
    position: relative;
    overflow: hidden;
}
.thinking-icon::after {
    content: "";
    position: absolute;
    top: 50%;
    left: 50%;
    width: 60%;
    height: 60%;
    background-color: #a5b4fc;
    transform: translate(-50%, -50%);
    border-radius: 50%;
}
.dropdown-arrow {
    font-size: 0.7em;
    margin-left: auto;
    transition: transform 0.3s;
}
.thinking-content {
    background-color: rgba(30, 30, 40, 0.8);
    border-left: 2px solid #6366f1;
    padding: 15px;
    margin-top: 5px;
    margin-bottom: 15px;
    font-size: 0.95em;
    color: #e2e8f0;
    font-family: monospace;
    white-space: pre-wrap;
    overflow-x: auto;
    border-radius: 5px;
    line-height: 1.5;
}
.hidden {
    display: none;
}
.progress-container {
    text-align: center;
    margin: 10px 0;
    color: #6366f1;
}
"""

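# Client-side toggle for the thinking sections. Gradio's `js` parameter expects
# a single function and calls it once the page has loaded, so the setup runs
# directly instead of waiting on a DOMContentLoaded event that may already
# have fired by then.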
js = """
function setupAthenaPlayground() {
    function setupThinkingToggle() {
        document.querySelectorAll('.thinking-toggle').forEach(button => {
            if (!button.dataset.listenerAdded) {
                button.addEventListener('click', function() {
                    const content = this.nextElementSibling;
                    content.classList.toggle('hidden');
                    const arrow = this.querySelector('.dropdown-arrow');
                    arrow.textContent = content.classList.contains('hidden') ? '▼' : '▲';
                });
                button.dataset.listenerAdded = 'true';
            }
        });
    }

    setupThinkingToggle();

    // Chatbot messages are re-rendered on every update, so watch the DOM and
    // re-bind the toggle buttons as new ones appear
    const observer = new MutationObserver(() => {
        setupThinkingToggle();
    });

    observer.observe(document.body, {
        childList: true,
        subtree: true
    });
}
"""

with gr.Blocks(title="Athena Playground Chat", css=css, js=js) as demo:
    gr.Markdown("# 🚀 Athena Playground Chat")
    gr.Markdown("*Powered by HuggingFace ZeroGPU*")

    # Running list of {"role": ..., "content": ...} dicts fed back to the model
    conversation_state = gr.State([])

    progress = gr.HTML(
        """<div class="progress-container">Generating response...</div>""",
        visible=False
    )

    chatbot = gr.Chatbot(
        height=500,
        label="Athena",
        render_markdown=True,
        elem_classes=["chatbot"]
    )

    with gr.Row():
        user_input = gr.Textbox(
            label="Your message",
            scale=8,
            autofocus=True,
            placeholder="Type your message here...",
            lines=2
        )
        send_btn = gr.Button(
            value="Send",
            scale=1,
            variant="primary"
        )

    clear_btn = gr.Button("Clear Conversation")

    gr.Markdown("### ⚙️ Model & Generation Settings")
    with gr.Row():
        model_choice = gr.Dropdown(
            label="📱 Model",
            choices=list(MODELS.keys()),
            value="Athena-R3X 4B",
            info="Select which Athena model to use"
        )
        max_length = gr.Slider(
            32, 8192, value=512,
            label="📝 Max Tokens",
            info="Maximum number of new tokens to generate"
        )
        temperature = gr.Slider(
            0.1, 2.0, value=0.7,
            label="🎨 Creativity",
            info="Higher values = more creative responses"
        )

    # Pressing Enter and clicking Send both run the same generator
    submit_event = user_input.submit(
        fn=chat_submit,
        inputs=[user_input, chatbot, conversation_state, model_choice, max_length, temperature],
        outputs=[user_input, chatbot, conversation_state, progress]
    )

    send_click = send_btn.click(
        fn=chat_submit,
        inputs=[user_input, chatbot, conversation_state, model_choice, max_length, temperature],
        outputs=[user_input, chatbot, conversation_state, progress]
    )

    clear_btn.click(
        fn=clear_conversation,
        outputs=[chatbot, conversation_state, progress]
    )

    gr.Examples(
        examples=[
            "What is artificial intelligence?",
            "Can you explain quantum computing?",
            "Write a short poem about technology",
            "What are some ethical concerns about AI?"
        ],
        inputs=user_input
    )

    gr.Markdown("""
    ### About the Thinking Tags
    Some Athena models (particularly the R3X series) include their reasoning in `<think></think>` tags.
    Click "Thinking completed" to view the model's thought process behind its answers.
    """)

if __name__ == "__main__":
    demo.queue()
    demo.launch(debug=True)