AbstractPhil committed
Commit bbb5633 · verified · 1 Parent(s): 6a080c2

Update app.py

Files changed (1)
  1. app.py +155 -154
app.py CHANGED
@@ -1,10 +1,18 @@
 import gradio as gr
 import torch
-from beeper_model import BeeperRoseGPT, generate
 from tokenizers import Tokenizer
 from huggingface_hub import hf_hub_download
 from safetensors.torch import load_file as load_safetensors

 # ----------------------------
 # 🔧 Model versions configuration
 # ----------------------------
@@ -31,8 +39,8 @@ MODEL_VERSIONS = {
     },
 }

-# Base configuration
-config = {
     "context": 512,
     "vocab_size": 8192,
     "dim": 512,
@@ -42,171 +50,169 @@ config = {
     "temperature": 0.9,
     "top_k": 40,
     "top_p": 0.9,
-    "repetition_penalty": 1.1,
     "presence_penalty": 0.6,
     "frequency_penalty": 0.0,
     "resid_dropout": 0.1,
     "dropout": 0.0,
     "grad_checkpoint": False,
-    "tokenizer_path": "beeper.tokenizer.json"
 }

 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

-# Global model and tokenizer variables
-infer = None
-tok = None
-current_version = None

-def load_model_version(version_name):
-    """Load the selected model version"""
     global infer, tok, current_version
-
-    if current_version == version_name and infer is not None:
         return f"Already loaded: {version_name}"
-
     version_info = MODEL_VERSIONS[version_name]
-
     try:
-        # Download model and tokenizer files
         model_file = hf_hub_download(
-            repo_id=version_info["repo_id"],
             filename=version_info["model_file"]
         )
         tokenizer_file = hf_hub_download(
-            repo_id=version_info["repo_id"],
             filename="tokenizer.json"
         )
-
-        # Initialize model
-        infer = BeeperRoseGPT(config).to(device)
-
-        # Load safetensors
-        state_dict = load_safetensors(model_file, device=str(device))
-        infer.load_state_dict(state_dict)
-        infer.eval()
-
-        # Load tokenizer
-        tok = Tokenizer.from_file(tokenizer_file)
-
         current_version = version_name
-        return f"Successfully loaded: {version_name}"
-
     except Exception as e:
         return f"Error loading {version_name}: {str(e)}"

-# Load default model on startup - try v4 first, fallback to v3
 try:
     load_status = load_model_version("Beeper v4 (Advanced)")
     if "Error" in load_status:
         print(f"v4 not ready yet: {load_status}")
         load_status = load_model_version("Beeper v3 (Multi-Concept)")
-except:
     load_status = load_model_version("Beeper v3 (Multi-Concept)")
-
 print(load_status)

 # ----------------------------
-# 💬 Gradio Chat Wrapper
 # ----------------------------
-def beeper_reply(message, history, model_version, temperature=None, top_k=None, top_p=None, max_new_tokens=80):
     global infer, tok, current_version
-
-    # Load model if version changed
     if model_version != current_version:
         status = load_model_version(model_version)
         if "Error" in status:
             return f"⚠️ {status}"
-
-    # Check if model is loaded
     if infer is None or tok is None:
         return "⚠️ Model not loaded. Please select a version and try again."
-
-    # Use defaults if not provided
-    if temperature is None:
-        temperature = 0.9
-    if top_k is None:
-        top_k = 40
-    if top_p is None:
-        top_p = 0.9
-
-    # Try Q&A format since she has some in corpus
-    if "?" in message:
-        prompt = f"Q: {message}\nA:"
-    elif message.lower().strip() in ["hi", "hello", "hey"]:
-        prompt = "The little robot said hello. She said, \""
-    elif "story" in message.lower():
         prompt = "Once upon a time, there was a robot. "
     else:
-        # Simple continuation
-        prompt = message + ". "
-
-    # Generate response with lower temperature for less repetition
-    response = generate(
         model=infer,
         tok=tok,
-        cfg=config,
         prompt=prompt,
-        max_new_tokens=max_new_tokens,  # Shorter to avoid rambling
-        temperature=float(temperature),  # Slightly lower temp
-        top_k=int(top_k),
-        top_p=float(top_p),
-        repetition_penalty=1.1,  # Higher penalty for repetition
-        presence_penalty=0.8,  # Higher presence penalty
-        frequency_penalty=0.1,  # Add frequency penalty
         device=device,
-        detokenize=True
     )
-
-    # Aggressive cleanup
-    # Remove the prompt completely
-    if response.startswith(prompt):
-        response = response[len(prompt):]
-
-    # Remove Q&A format artifacts
-    response = response.replace("Q:", "").replace("A:", "")
-
-    # Split on newlines and take first non-empty line
-    lines = response.split('\n')
-    for line in lines:
-        clean_line = line.strip()
-        if clean_line and not clean_line.startswith(message[:10]):
-            response = clean_line
-            break
-
-    # If response still contains the user message, try to extract after it
-    if message.lower()[:20] in response.lower()[:50]:
-        # Find where the echo ends
-        words_in_message = message.split()
-        for i in range(min(5, len(words_in_message)), 0, -1):
-            pattern = ' '.join(words_in_message[:i])
-            if pattern.lower() in response.lower():
-                idx = response.lower().find(pattern.lower()) + len(pattern)
-                response = response[idx:].strip()
-                break
-
-    # Remove any remaining "User" or "Beeper" artifacts
-    for artifact in ["User:", "Beeper:", "U ser:", "Beep er:", "User ", "Beeper "]:
-        response = response.replace(artifact, "")
-
-    # Ensure we have something
-    if not response or len(response) < 3:
-        responses = [
-            "I like robots and stories!",
-            "That's interesting!",
-            "I want to play in the park.",
-            "The robot was happy.",
-            "Yes, I think so too!"
-        ]
-        import random
-        response = random.choice(responses)
-
-    # Clean ending
-    response = response.strip()
-    if response and response[-1] not in '.!?"':
-        response = response.rsplit('.', 1)[0] + '.' if '.' in response else response + '.'
-
-    return response[:200]  # Cap length

 # ----------------------------
 # 🖼️ Interface
@@ -214,37 +220,34 @@ def beeper_reply(message, history, model_version, temperature=None, top_k=None,
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
     gr.Markdown(
         """
-        # 🤖 Beeper - A Rose-based Tiny Language Model
-        Hello! I'm Beeper, a small language model trained with love and care. Please be patient with me - I'm still learning! 💕
         """
     )
-
     with gr.Row():
         with gr.Column(scale=3):
             model_dropdown = gr.Dropdown(
                 choices=list(MODEL_VERSIONS.keys()),
-                value="Beeper v3 (Multi-Concept)",  # Default to v3 since v4 might not be ready
                 label="Select Beeper Version",
-                info="Choose which version of Beeper to chat with"
             )
         with gr.Column(scale=7):
-            version_info = gr.Markdown("**Current:** Beeper v3 with 30+ epochs including reasoning, math, and ethics")
-
-    # Update version info when dropdown changes
-    def update_version_info(version_name):
-        info = MODEL_VERSIONS[version_name]["description"]
-        return f"**Current:** {info}"
-
     model_dropdown.change(
         fn=update_version_info,
         inputs=[model_dropdown],
-        outputs=[version_info]
     )
-
-    # Chat interface
-    chatbot = gr.Chatbot(label="Chat with Beeper", type="tuples", height=400)
     msg = gr.Textbox(label="Message", placeholder="Type your message here...")
-
     with gr.Row():
         with gr.Column(scale=2):
             temperature_slider = gr.Slider(0.1, 1.5, value=0.9, step=0.1, label="Temperature")
@@ -253,13 +256,12 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
         with gr.Column(scale=2):
             top_p_slider = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
         with gr.Column(scale=2):
-            max_new_tokens_slider = gr.Slider(20, 512, value=128, step=1, label="Max-new-tokens")
-
     with gr.Row():
         submit = gr.Button("Send", variant="primary")
         clear = gr.Button("Clear")
-
-    # Examples
     gr.Examples(
         examples=[
             ["Hello Beeper! How are you today?"],
@@ -268,28 +270,27 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
             ["What makes you happy?"],
             ["Tell me about your dreams"],
         ],
-        inputs=msg
     )
-
-    # Handle chat
     def respond(message, chat_history, model_version, temperature, top_k, top_p, max_new_tokens):
-        if not chat_history:
             chat_history = []
         response = beeper_reply(message, chat_history, model_version, temperature, top_k, top_p, max_new_tokens)
-        chat_history.append([message, response])
         return "", chat_history
-
     msg.submit(
-        respond,
-        [msg, chatbot, model_dropdown, temperature_slider, top_k_slider, top_p_slider, max_new_tokens_slider],
-        [msg, chatbot]
     )
     submit.click(
-        respond,
-        [msg, chatbot, model_dropdown, temperature_slider, top_k_slider, top_p_slider, max_new_tokens_slider],
-        [msg, chatbot]
     )
     clear.click(lambda: None, None, chatbot, queue=False)

 if __name__ == "__main__":
-    demo.launch()
 
+# app.py
+# --------------------------------------------------------------------------------------------------
+# Gradio app for Beeper
+# - Loads released safetensors + tokenizer from Hugging Face
+# - Auto-sizes pentachora banks to match checkpoints (across Beeper v1..v4)
+# - Generation uses same knobs & penalties as training script
+# --------------------------------------------------------------------------------------------------
 import gradio as gr
 import torch
 from tokenizers import Tokenizer
 from huggingface_hub import hf_hub_download
 from safetensors.torch import load_file as load_safetensors

+from beeper import BeeperRoseGPT, generate, prepare_model_for_state_dict
+
 # ----------------------------
 # 🔧 Model versions configuration
 # ----------------------------
 
     },
 }

+# Base configuration (matches training defaults)
+CONFIG = {
     "context": 512,
     "vocab_size": 8192,
     "dim": 512,
 
     "temperature": 0.9,
     "top_k": 40,
     "top_p": 0.9,
+    "repetition_penalty": 1.10,
     "presence_penalty": 0.6,
     "frequency_penalty": 0.0,
     "resid_dropout": 0.1,
     "dropout": 0.0,
     "grad_checkpoint": False,
+    # tokenizer_path not needed here; we load tokenizer.json from the HF repo
 }

 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

+# Globals (kept simple for a single process Gradio app)
+infer: BeeperRoseGPT | None = None
+tok: Tokenizer | None = None
+current_version: str | None = None
+

+def load_model_version(version_name: str) -> str:
+    """
+    Download the checkpoint and tokenizer, build model, ensure pentachora sizes match,
+    then strictly load weights. Robust to v1/v2 (no pentas) and v3/v4 (with pentas).
+    """
     global infer, tok, current_version
+
+    if current_version == version_name and infer is not None and tok is not None:
         return f"Already loaded: {version_name}"
+
     version_info = MODEL_VERSIONS[version_name]
+
     try:
+        # Download artifacts
         model_file = hf_hub_download(
+            repo_id=version_info["repo_id"],
             filename=version_info["model_file"]
         )
         tokenizer_file = hf_hub_download(
+            repo_id=version_info["repo_id"],
             filename="tokenizer.json"
         )
+
+        # Load state dict on CPU, inspect pentachora shapes if present
+        state_dict = load_safetensors(model_file, device="cpu")
+
+        # Build model & pre-create pentachora if needed
+        m = BeeperRoseGPT(CONFIG).to(device)
+        prepare_model_for_state_dict(m, state_dict, device=device)
+
+        # Try strict load first; if shapes drift (rare), fallback to non-strict
+        try:
+            missing, unexpected = m.load_state_dict(state_dict, strict=True)
+            # PyTorch returns NamedTuple; report counts
+            _msg = f"strict load ok | missing={len(missing)} unexpected={len(unexpected)}"
+        except Exception as e:
+            _msg = f"strict load failed ({e}); trying non-strict…"
+            # Non-strict load for very old snapshots
+            m.load_state_dict(state_dict, strict=False)
+
+        m.eval()
+
+        # Tokenizer
+        t = Tokenizer.from_file(tokenizer_file)
+
+        # Swap globals
+        infer, tok = m, t
         current_version = version_name
+        return f"Successfully loaded: {version_name} ({_msg})"
+
     except Exception as e:
+        infer = None
+        tok = None
+        current_version = None
         return f"Error loading {version_name}: {str(e)}"

+
+# Load default on startup — prefer v4, fallback to v3
 try:
     load_status = load_model_version("Beeper v4 (Advanced)")
     if "Error" in load_status:
         print(f"v4 not ready yet: {load_status}")
         load_status = load_model_version("Beeper v3 (Multi-Concept)")
+except Exception as _:
     load_status = load_model_version("Beeper v3 (Multi-Concept)")
 print(load_status)

+
 # ----------------------------
+# 💬 Chat wrapper
 # ----------------------------
+def beeper_reply(
+    message: str,
+    history: list[tuple[str, str]] | None,
+    model_version: str,
+    temperature: float | None,
+    top_k: int | None,
+    top_p: float | None,
+    max_new_tokens: int = 80
+) -> str:
     global infer, tok, current_version
+
+    # Hot-swap versions if the dropdown changed
     if model_version != current_version:
         status = load_model_version(model_version)
         if "Error" in status:
             return f"⚠️ {status}"
+
     if infer is None or tok is None:
         return "⚠️ Model not loaded. Please select a version and try again."
+
+    # Light prompting heuristics (consistent with your example)
+    m = message.strip()
+    if "?" in m:
+        prompt = f"Q: {m}\nA:"
+    elif m.lower() in {"hi", "hello", "hey"}:
+        prompt = 'The little robot said hello. She said, "'
+    elif "story" in m.lower():
         prompt = "Once upon a time, there was a robot. "
     else:
+        prompt = m + ". "
+
+    # Generate
+    text = generate(
         model=infer,
         tok=tok,
+        cfg=CONFIG,
         prompt=prompt,
+        max_new_tokens=int(max_new_tokens),
+        temperature=float(temperature) if temperature is not None else None,
+        top_k=int(top_k) if top_k is not None else None,
+        top_p=float(top_p) if top_p is not None else None,
+        repetition_penalty=1.10,
+        presence_penalty=0.8,
+        frequency_penalty=0.1,
         device=device,
+        detokenize=True,
     )
+
+    # Strip prompt echoes & artifacts
+    if text.startswith(prompt):
+        text = text[len(prompt):]
+    text = text.replace("Q:", "").replace("A:", "")
+
+    lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
+    if lines:
+        text = lines[0]
+
+    # If user message echoed at head, trim after first occurrence
+    head = m[:20].lower()
+    if text.lower().startswith(head):
+        idx = text.lower().find(head)
+        text = text[idx + len(head):].strip() or text
+
+    for artifact in ("User:", "Beeper:", "U ser:", "Beep er:", "User ", "Beeper "):
+        text = text.replace(artifact, "")
+
+    text = text.strip()
+    if not text or len(text) < 3:
+        text = "I like robots and stories!"
+
+    if text[-1:] not in ".!?”\"'":
+        text += "."
+
+    return text[:200]
+

 # ----------------------------
 # 🖼️ Interface
 
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
     gr.Markdown(
         """
+        # 🤖 Beeper - A Rose-based Tiny Language Model
+        Hello! I'm Beeper, a small language model trained with love and care. Please be patient with me - I'm still learning! 💕
         """
     )
+
     with gr.Row():
         with gr.Column(scale=3):
             model_dropdown = gr.Dropdown(
                 choices=list(MODEL_VERSIONS.keys()),
+                value="Beeper v3 (Multi-Concept)",  # safer default
                 label="Select Beeper Version",
+                info="Choose which version of Beeper to chat with",
             )
         with gr.Column(scale=7):
+            version_info = gr.Markdown("**Current:** " + MODEL_VERSIONS["Beeper v3 (Multi-Concept)"]["description"])
+
+    def update_version_info(version_name: str):
+        return f"**Current:** {MODEL_VERSIONS[version_name]['description']}"
+
     model_dropdown.change(
         fn=update_version_info,
         inputs=[model_dropdown],
+        outputs=[version_info],
     )
+
+    chatbot = gr.Chatbot(label="Chat with Beeper", height=400)
     msg = gr.Textbox(label="Message", placeholder="Type your message here...")
+
     with gr.Row():
         with gr.Column(scale=2):
             temperature_slider = gr.Slider(0.1, 1.5, value=0.9, step=0.1, label="Temperature")
 
         with gr.Column(scale=2):
             top_p_slider = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
         with gr.Column(scale=2):
+            max_new_tokens_slider = gr.Slider(20, 512, value=128, step=1, label="Max new tokens")
+
     with gr.Row():
         submit = gr.Button("Send", variant="primary")
         clear = gr.Button("Clear")
+
     gr.Examples(
         examples=[
             ["Hello Beeper! How are you today?"],
 
             ["What makes you happy?"],
             ["Tell me about your dreams"],
         ],
+        inputs=msg,
     )
+
     def respond(message, chat_history, model_version, temperature, top_k, top_p, max_new_tokens):
+        if chat_history is None:
             chat_history = []
         response = beeper_reply(message, chat_history, model_version, temperature, top_k, top_p, max_new_tokens)
+        chat_history.append((message, response))
         return "", chat_history
+
     msg.submit(
+        respond,
+        [msg, chatbot, model_dropdown, temperature_slider, top_k_slider, top_p_slider, max_new_tokens_slider],
+        [msg, chatbot],
     )
     submit.click(
+        respond,
+        [msg, chatbot, model_dropdown, temperature_slider, top_k_slider, top_p_slider, max_new_tokens_slider],
+        [msg, chatbot],
     )
     clear.click(lambda: None, None, chatbot, queue=False)

 if __name__ == "__main__":
+    demo.launch()
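
The commit passes repetition, presence, and frequency penalties straight into `generate()` (repetition_penalty=1.10, presence_penalty=0.8, frequency_penalty=0.1). The actual sampling loop lives in `beeper.py` and is not part of this diff; the sketch below only illustrates the conventional way such penalties are applied to next-token logits, with `apply_penalties` being a hypothetical helper named here for illustration.

```python
# Minimal sketch, assuming the usual CTRL-style repetition penalty plus
# OpenAI-style additive presence/frequency penalties; not the beeper.generate() code.
from collections import Counter
import torch

def apply_penalties(logits: torch.Tensor, generated_ids: list[int],
                    repetition_penalty: float, presence_penalty: float,
                    frequency_penalty: float) -> torch.Tensor:
    counts = Counter(generated_ids)
    logits = logits.clone()
    for token_id, count in counts.items():
        # Repetition penalty: shrink positive logits, grow negative ones for seen tokens.
        if logits[token_id] > 0:
            logits[token_id] /= repetition_penalty
        else:
            logits[token_id] *= repetition_penalty
        # Presence penalty is a flat hit for having appeared at all;
        # frequency penalty scales with how often the token has appeared.
        logits[token_id] -= presence_penalty + frequency_penalty * count
    return logits
```

With the app's settings, a token that has already been emitted several times loses probability mass on every step, which is what keeps Beeper's short replies from looping.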