Spestly committed on
Commit 77246c4 · verified · 1 Parent(s): 609287a

Update app.py

Files changed (1)
  1. app.py +83 -69
app.py CHANGED
@@ -1,10 +1,10 @@
 import gradio as gr
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
-import spaces
 import time
 
-# Full precision models for H200 70GB
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
 MODELS = {
     "Athena-R3X 8B": "Spestly/Athena-R3X-8B",
     "Athena-R3X 4B": "Spestly/Athena-R3X-4B",
@@ -17,65 +17,75 @@ MODELS = {
     "Athena-1 7B": "Spestly/Athena-1-7B"
 }
 
-DEFAULT_MODEL = "Spestly/Athena-R3X-8B"
+loaded_models = {}
+loaded_tokenizers = {}
 
-# GPU-accelerated function
-@spaces.GPU
 def load_model(model_name):
-    model_id = MODELS.get(model_name, DEFAULT_MODEL)
-
-    print(f"🚀 Loading {model_id} on H200 GPU...")
+    if model_name in loaded_models:
+        return loaded_models[model_name], loaded_tokenizers[model_name]
+
+    model_id = MODELS.get(model_name, MODELS["Athena-R3X 8B"])
+    print(f"🚀 Loading {model_id} on {device}...")
     start_time = time.time()
-
+
    tokenizer = AutoTokenizer.from_pretrained(model_id)
-
     model = AutoModelForCausalLM.from_pretrained(
         model_id,
         torch_dtype=torch.bfloat16,
-        device_map="auto",
-        low_cpu_mem_usage=True
+        device_map=None
     )
-
+    model.to(device)
+    model.eval()
+
     load_time = time.time() - start_time
-    print(f"✅ Model loaded in {load_time:.2f} seconds")
-    print(f"GPU Memory Allocated: {torch.cuda.memory_allocated()/1e9:.2f}GB")
-
+    print(f"✅ Model loaded in {load_time:.2f}s, GPU mem: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
+
+    loaded_models[model_name] = model
+    loaded_tokenizers[model_name] = tokenizer
     return model, tokenizer
 
-@spaces.GPU
-def generate_text(prompt, model_name, max_length=512, temperature=0.7):
-    try:
-        model, tokenizer = load_model(model_name)
-
-        inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
-
-        start_time = time.time()
-        with torch.no_grad():
-            outputs = model.generate(
-                **inputs,
-                max_new_tokens=max_length,
-                temperature=temperature,
-                do_sample=True,
-                top_p=0.9
-            )
-        generation_time = time.time() - start_time
-
-        output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-        stats = f"""
-        ⚡ Generation completed in {generation_time:.2f}s
-        💾 GPU Memory: {torch.cuda.memory_allocated()/1e9:.2f}GB allocated
-        🌡️ Temperature: {temperature}
-        """
-
-        return output_text, stats
-
-    except Exception as e:
-        return f"❌ Error: {str(e)}", ""
-
-with gr.Blocks(title="Athena Playground") as demo:
-    gr.Markdown("""# 🚀 Athena Playground""")
-
+def chatbot(conversation, user_message, model_name, max_length=512, temperature=0.7):
+    if conversation is None:
+        conversation = []
+    model, tokenizer = load_model(model_name)
+
+    # Append user message to conversation
+    conversation.append(("User", user_message))
+
+    # Build prompt from conversation history (simple concatenation)
+    prompt = ""
+    for speaker, text in conversation:
+        if speaker == "User":
+            prompt += f"User: {text}\n"
+        else:
+            prompt += f"Athena: {text}\n"
+    prompt += "Athena:"
+
+    inputs = tokenizer(prompt, return_tensors="pt").to(device)
+
+    start_time = time.time()
+    with torch.no_grad():
+        outputs = model.generate(
+            **inputs,
+            max_new_tokens=max_length,
+            temperature=temperature,
+            do_sample=True,
+            top_p=0.9,
+            pad_token_id=tokenizer.eos_token_id
+        )
+    generation_time = time.time() - start_time
+
+    output_text = tokenizer.decode(outputs[0][inputs['input_ids'].shape[-1]:], skip_special_tokens=True).strip()
+
+    conversation.append(("Athena", output_text))
+
+    stats = f"⚡ Generated in {generation_time:.2f}s | GPU mem: {torch.cuda.memory_allocated()/1e9:.2f} GB | Temp: {temperature}"
+
+    return conversation, "", stats
+
+with gr.Blocks(title="Athena Playground Chat") as demo:
+    gr.Markdown("# 🚀 Athena Playground Chat")
+
     with gr.Row():
         with gr.Column(scale=1):
             model_choice = gr.Dropdown(
@@ -85,28 +95,32 @@ with gr.Blocks(title="Athena Playground") as demo:
             )
             max_length = gr.Slider(32, 4096, value=512, label="Max Tokens")
             temperature = gr.Slider(0.1, 2.0, value=0.7, label="Creativity")
-            gr.Markdown("**Note:** First load may take 1-2 minutes")
-            submit_btn = gr.Button("Generate", variant="primary")
-
+            clear_btn = gr.Button("Clear Chat")
+
         with gr.Column(scale=3):
-            prompt = gr.Textbox(label="Your Prompt", lines=8, placeholder="Type your prompt here...")
-            output = gr.Textbox(label="Model Output", lines=12)
-            stats = gr.Textbox(label="Performance Stats", lines=3)
-
+            chat_history = gr.Chatbot(elem_id="chatbot").style(height=600)
+            user_input = gr.Textbox(
+                placeholder="Ask Athena anything...",
+                label="Your message",
+                lines=2
+            )
+            submit_btn = gr.Button("Send")
+
+    def clear_chat():
+        return [], "", ""
+
     submit_btn.click(
-        generate_text,
-        inputs=[prompt, model_choice, max_length, temperature],
-        outputs=[output, stats]
+        chatbot,
+        inputs=[chat_history, user_input, model_choice, max_length, temperature],
+        outputs=[chat_history, user_input, gr.Textbox(label="Stats")],
+        queue=True
     )
-
-    gr.Examples(
-        examples=[
-            ["Explain the transformer architecture like I'm five"],
-            ["Write a poem about AI in the style of Shakespeare"],
-            ["Generate Python code for a convolutional neural network"]
-        ],
-        inputs=prompt
+
+    clear_btn.click(
+        clear_chat,
+        inputs=[],
+        outputs=[chat_history, user_input, gr.Textbox(label="Stats")]
    )
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
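
A minimal usage sketch of the new conversation-based flow, exercised outside the Gradio UI. This assumes the updated app.py is importable as a module named `app` and that a suitable GPU or CPU environment is available; the module name and prompts below are illustrative, not part of the commit.

import app  # the updated app.py; importing builds the Blocks UI but does not launch it

history = []  # conversation state: a list of (speaker, text) tuples, as chatbot() expects

# First turn: load_model() downloads and caches the model, so later turns reuse it
history, _, stats = app.chatbot(history, "Explain the transformer architecture", "Athena-R3X 8B")
print(stats)        # e.g. "⚡ Generated in ...s | GPU mem: ... GB | Temp: 0.7"
print(history[-1])  # ("Athena", "<model reply>")

# Second turn reuses the cached model and the accumulated history
history, _, stats = app.chatbot(history, "Now summarise that in one sentence", "Athena-R3X 8B")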