Update app.py
app.py
CHANGED
```diff
@@ -1,10 +1,10 @@
 import gradio as gr
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
-import spaces
 import time
 
-…
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
 MODELS = {
     "Athena-R3X 8B": "Spestly/Athena-R3X-8B",
     "Athena-R3X 4B": "Spestly/Athena-R3X-4B",
@@ -17,65 +17,75 @@ MODELS = {
     "Athena-1 7B": "Spestly/Athena-1-7B"
 }
 
-…
+loaded_models = {}
+loaded_tokenizers = {}
 
-# GPU-accelerated function
-@spaces.GPU
 def load_model(model_name):
-…
+    if model_name in loaded_models:
+        return loaded_models[model_name], loaded_tokenizers[model_name]
+
+    model_id = MODELS.get(model_name, MODELS["Athena-R3X 8B"])
+    print(f"🚀 Loading {model_id} on {device}...")
     start_time = time.time()
-…
+
     tokenizer = AutoTokenizer.from_pretrained(model_id)
-…
     model = AutoModelForCausalLM.from_pretrained(
         model_id,
         torch_dtype=torch.bfloat16,
-        device_map=…
-        low_cpu_mem_usage=True
+        device_map=None
     )
-…
+    model.to(device)
+    model.eval()
+
     load_time = time.time() - start_time
-    print(f"✅ Model loaded in {load_time:.2f}…
-…
+    print(f"✅ Model loaded in {load_time:.2f}s, GPU mem: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
+
+    loaded_models[model_name] = model
+    loaded_tokenizers[model_name] = tokenizer
     return model, tokenizer
 
-…
+def chatbot(conversation, user_message, model_name, max_length=512, temperature=0.7):
+    if conversation is None:
+        conversation = []
+    model, tokenizer = load_model(model_name)
+
+    # Append user message to conversation
+    conversation.append(("User", user_message))
+
+    # Build prompt from conversation history (simple concatenation)
+    prompt = ""
+    for speaker, text in conversation:
+        if speaker == "User":
+            prompt += f"User: {text}\n"
+        else:
+            prompt += f"Athena: {text}\n"
+    prompt += "Athena:"
+
+    inputs = tokenizer(prompt, return_tensors="pt").to(device)
+
+    start_time = time.time()
+    with torch.no_grad():
+        outputs = model.generate(
+            **inputs,
+            max_new_tokens=max_length,
+            temperature=temperature,
+            do_sample=True,
+            top_p=0.9,
+            pad_token_id=tokenizer.eos_token_id
+        )
+    generation_time = time.time() - start_time
+
+    output_text = tokenizer.decode(outputs[0][inputs['input_ids'].shape[-1]:], skip_special_tokens=True).strip()
+
+    conversation.append(("Athena", output_text))
+
+    stats = f"⚡ Generated in {generation_time:.2f}s | GPU mem: {torch.cuda.memory_allocated()/1e9:.2f} GB | Temp: {temperature}"
+
+    return conversation, "", stats
+
+with gr.Blocks(title="Athena Playground Chat") as demo:
+    gr.Markdown("# 🚀 Athena Playground Chat")
+
     with gr.Row():
         with gr.Column(scale=1):
             model_choice = gr.Dropdown(
@@ -85,28 +95,32 @@ with gr.Blocks(title="Athena Playground") as demo:
             )
             max_length = gr.Slider(32, 4096, value=512, label="Max Tokens")
             temperature = gr.Slider(0.1, 2.0, value=0.7, label="Creativity")
-            gr.…
-…
+            clear_btn = gr.Button("Clear Chat")
+
         with gr.Column(scale=3):
-…
+            chat_history = gr.Chatbot(elem_id="chatbot").style(height=600)
+            user_input = gr.Textbox(
+                placeholder="Ask Athena anything...",
+                label="Your message",
+                lines=2
+            )
+            submit_btn = gr.Button("Send")
+
+    def clear_chat():
+        return [], "", ""
+
     submit_btn.click(
-…
-        inputs=[…
-        outputs=[…
+        chatbot,
+        inputs=[chat_history, user_input, model_choice, max_length, temperature],
+        outputs=[chat_history, user_input, gr.Textbox(label="Stats")],
+        queue=True
     )
-…
-            ["Generate Python code for a convolutional neural network"]
-        ],
-        inputs=prompt
+
+    clear_btn.click(
+        clear_chat,
+        inputs=[],
+        outputs=[chat_history, user_input, gr.Textbox(label="Stats")]
     )
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
```
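A few notes on the new version. The diff drops `import spaces` and the `@spaces.GPU` decorator in favor of an explicit module-level `device`, which assumes hardware where CUDA is visible at import time (a dedicated GPU Space, or CPU). On ZeroGPU hardware, a GPU is only attached while a `@spaces.GPU`-decorated function is running, roughly as in the sketch below — a hypothetical reconstruction, since the removed function body was not captured in this diff.

```python
# Hypothetical reconstruction of the removed ZeroGPU pattern; the actual
# removed lines were not captured. On ZeroGPU, a GPU is attached only while
# a @spaces.GPU-decorated function executes.
import spaces

@spaces.GPU  # request a GPU for the duration of this call
def load_model(model_name):
    ...  # load the tokenizer/model onto the attached GPU here
```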
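The new `chatbot` builds its prompt by plain `User:`/`Athena:` string concatenation. Instruct-tuned checkpoints usually respond better when prompted through their own chat template, which `transformers` exposes as `tokenizer.apply_chat_template`. A minimal sketch, assuming the Athena tokenizers ship a chat template (most recent instruct models do):

```python
# Sketch: format the running conversation with the model's own chat template
# instead of hand-built "User:"/"Athena:" strings.
messages = [
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi! How can I help?"},
    {"role": "user", "content": "Write a haiku about GPUs."},
]
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,  # end with the assistant header so the model answers
)
# The rendered template already contains any special tokens, so skip re-adding them.
inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(device)
```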
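One thing to watch: `gr.Chatbot` renders each history item as a `(user_message, bot_message)` pair, while `chatbot` accumulates `("User", text)` and `("Athena", text)` speaker-labeled tuples, so the literal strings "User" and "Athena" would display as the user-side bubbles and every message body as a bot reply. A sketch of the pair-style bookkeeping the component expects:

```python
# Sketch: keep history in the (user, assistant) pair format gr.Chatbot expects,
# replacing the ("User", ...)/("Athena", ...) speaker tuples.
conversation.append((user_message, None))    # user turn, reply pending
# ... run generation to produce output_text ...
conversation[-1] = (user_message, output_text)
```

The prompt-building loop would then iterate pairs, e.g. `for user_text, bot_text in conversation:`, emitting the assistant line only when `bot_text` is not `None`.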
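Both the load-time print and the stats line call `torch.cuda.memory_allocated()` unconditionally. On a CPU-only run that reports 0 at best, and it can raise on builds without CUDA support. A small guard keeps the stats line honest on either device:

```python
# Sketch: only query CUDA memory when a GPU is actually in use.
if device.type == "cuda":
    gpu_mem = f"{torch.cuda.memory_allocated() / 1e9:.2f} GB"
else:
    gpu_mem = "n/a"
stats = f"⚡ Generated in {generation_time:.2f}s | GPU mem: {gpu_mem} | Temp: {temperature}"
```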
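Two smaller Gradio details. `gr.Chatbot(elem_id="chatbot").style(height=600)` uses the Gradio 3.x styling API; Gradio 4+ removed `.style()` in favor of constructor arguments (`gr.Chatbot(elem_id="chatbot", height=600)`). Also, each `.click()` call here instantiates its own `gr.Textbox(label="Stats")`, which creates two separate components, so the clear button would blank a different textbox than the one the submit button writes to. Creating the stats box once and referencing it in both events avoids that:

```python
# Sketch: one shared stats component instead of a fresh gr.Textbox per event.
stats_box = gr.Textbox(label="Stats")  # create once, inside the Blocks layout

submit_btn.click(
    chatbot,
    inputs=[chat_history, user_input, model_choice, max_length, temperature],
    outputs=[chat_history, user_input, stats_box],
)
clear_btn.click(clear_chat, inputs=[], outputs=[chat_history, user_input, stats_box])
```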