import time

import gradio as gr
import torch
import spaces  # Hugging Face Spaces helper providing the @spaces.GPU decorator
from transformers import AutoModelForCausalLM, AutoTokenizer

MODELS = {
    "Athena-R3X 8B": "Spestly/Athena-R3X-8B",
    "Athena-R3X 4B": "Spestly/Athena-R3X-4B",
    "Athena-R3 7B": "Spestly/Athena-R3-7B",
    "Athena-3 3B": "Spestly/Athena-3-3B",
    "Athena-3 7B": "Spestly/Athena-3-7B",
    "Athena-3 14B": "Spestly/Athena-3-14B",
    "Athena-2 1.5B": "Spestly/Athena-2-1.5B",
    "Athena-1 3B": "Spestly/Athena-1-3B",
    "Athena-1 7B": "Spestly/Athena-1-7B",
}

DEFAULT_MODEL = "Spestly/Athena-R3X-8B"

@spaces.GPU
def load_model(model_name):
    # Resolve the display name to a Hugging Face model id, falling back to the default.
    model_id = MODELS.get(model_name, DEFAULT_MODEL)

    print(f"🚀 Loading {model_id} on H200 GPU...")
    start_time = time.time()

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        low_cpu_mem_usage=True,
    )

    load_time = time.time() - start_time
    print(f"✅ Model loaded in {load_time:.2f} seconds")
    print(f"GPU Memory Allocated: {torch.cuda.memory_allocated() / 1e9:.2f}GB")

    return model, tokenizer
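
# load_model is invoked on every generation request. from_pretrained reuses the
# on-disk HF cache, so nothing is re-downloaded, but the model is still rebuilt on
# the GPU each time. A minimal in-process memoization sketch (an assumption, not
# part of the original app) could sit between the UI and load_model:
_MODEL_CACHE = {}

def load_model_cached(model_name):
    """Return a cached (model, tokenizer) pair, loading it on first use."""
    if model_name not in _MODEL_CACHE:
        _MODEL_CACHE[model_name] = load_model(model_name)
    return _MODEL_CACHE[model_name]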

@spaces.GPU
def generate_text(prompt, model_name, max_length=512, temperature=0.7):
    try:
        model, tokenizer = load_model(model_name)

        # Place inputs on whatever device the model was dispatched to.
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        start_time = time.time()
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_length,
                temperature=temperature,
                do_sample=True,
                top_p=0.9,
            )
        generation_time = time.time() - start_time

        output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

        stats = f"""
⚡ Generation completed in {generation_time:.2f}s
💾 GPU Memory: {torch.cuda.memory_allocated() / 1e9:.2f}GB allocated
🌡️ Temperature: {temperature}
"""

        return output_text, stats

    except Exception as e:
        return f"❌ Error: {str(e)}", ""
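
# The playground feeds the raw textbox contents straight into the model. For
# instruction-tuned checkpoints like the Athena series, wrapping the prompt in the
# tokenizer's chat template usually yields better responses. A hedged sketch
# (assumes the tokenizer ships a chat template, as recent instruct models do):
def build_chat_prompt(tokenizer, user_message):
    """Format a single-turn user message with the model's chat template."""
    messages = [{"role": "user", "content": user_message}]
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )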

with gr.Blocks(title="Athena Playground") as demo:
    gr.Markdown("""# 🚀 Athena Playground""")

    with gr.Row():
        with gr.Column(scale=1):
            model_choice = gr.Dropdown(
                label="Model",
                choices=list(MODELS.keys()),
                value="Athena-R3X 8B",
            )
            max_length = gr.Slider(32, 4096, value=512, label="Max Tokens")
            temperature = gr.Slider(0.1, 2.0, value=0.7, label="Creativity")
            gr.Markdown("**Note:** First load may take 1-2 minutes")
            submit_btn = gr.Button("Generate", variant="primary")

        with gr.Column(scale=3):
            prompt = gr.Textbox(label="Your Prompt", lines=8, placeholder="Type your prompt here...")
            output = gr.Textbox(label="Model Output", lines=12)
            stats = gr.Textbox(label="Performance Stats", lines=3)

    submit_btn.click(
        generate_text,
        inputs=[prompt, model_choice, max_length, temperature],
        outputs=[output, stats],
    )

    gr.Examples(
        examples=[
            ["Explain the transformer architecture like I'm five"],
            ["Write a poem about AI in the style of Shakespeare"],
            ["Generate Python code for a convolutional neural network"],
        ],
        inputs=prompt,
    )
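
# Gradio runs events without a request queue by default; on a shared Space,
# enabling the queue (optional, not in the original) keeps long generations from
# blocking other users, e.g. `demo.queue(max_size=16)` before launching.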

if __name__ == "__main__":
    demo.launch()