orionweller committed on
Commit
565fb95
·
verified ·
1 Parent(s): ea15511

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -8
app.py CHANGED
@@ -70,9 +70,9 @@ def prepare_mlm_sample(text, mask_ratio=0.15):
70
  print(f"Maskable indices count: {len(maskable_indices)}")
71
  print(f"Mask ratio: {mask_ratio}")
72
 
73
- # Calculate how many tokens to mask, but ensure at least 1 and at most 8
74
- # Use the maskable_indices length with the ratio
75
- num_to_mask = max(1, min(8, int(len(maskable_indices) * mask_ratio)))
76
  print(f"Number of tokens to mask: {num_to_mask}")
77
 
78
  # Randomly select indices to mask
@@ -256,15 +256,26 @@ def prepare_next_token_prediction():
256
  full_hidden = original_text[len(masked_text):].strip()
257
 
258
  # Tokenize the hidden part
259
- ntp_state["tokens"] = tokenizer.tokenize(full_hidden)
 
 
 
 
 
 
 
 
 
 
260
  ntp_state["full_text"] = full_hidden
261
  ntp_state["revealed_text"] = ""
262
  ntp_state["next_token_idx"] = 0
263
 
264
  # Make sure we have tokens to predict
265
  if not ntp_state["tokens"]:
266
- # If we don't have tokens, get a new sample
267
- new_text = get_new_sample("ntp", 0.3)
 
268
  prepare_next_token_prediction()
269
 
270
  def check_ntp_answer(user_continuation):
@@ -275,6 +286,12 @@ def check_ntp_answer(user_continuation):
275
  if not ntp_state["tokens"]:
276
  prepare_next_token_prediction()
277
 
 
 
 
 
 
 
278
  # No more tokens to predict
279
  if ntp_state["next_token_idx"] >= len(ntp_state["tokens"]):
280
  # Reset for next round
@@ -282,6 +299,7 @@ def check_ntp_answer(user_continuation):
282
 
283
  # Get the next token to predict
284
  next_token = ntp_state["tokens"][ntp_state["next_token_idx"]]
 
285
 
286
  # Get user's prediction
287
  user_text = user_continuation.strip()
@@ -289,6 +307,7 @@ def check_ntp_answer(user_continuation):
289
  # Tokenize user's prediction to get their first token
290
  user_tokens = tokenizer.tokenize(user_text)
291
  user_token = user_tokens[0].lower() if user_tokens else ""
 
292
 
293
  # Clean up tokens for comparison
294
  next_token_clean = next_token.lower()
@@ -300,6 +319,7 @@ def check_ntp_answer(user_continuation):
300
 
301
  # Check if correct
302
  is_correct = (user_token == next_token_clean)
 
303
 
304
  # Update stats
305
  if is_correct:
@@ -307,7 +327,7 @@ def check_ntp_answer(user_continuation):
307
  user_stats["ntp"]["total"] += 1
308
 
309
  # Reveal this token and prepare for next
310
- ntp_state["revealed_text"] += " " + tokenizer.convert_tokens_to_string([next_token])
311
  ntp_state["next_token_idx"] += 1
312
 
313
  # Calculate overall accuracy
@@ -320,7 +340,7 @@ def check_ntp_answer(user_continuation):
320
  feedback.append(f"✗ Not quite. The actual next token was '{next_token_clean}'")
321
 
322
  # Show progress
323
- feedback.append(f"\nRevealed so far: {masked_text}{ntp_state['revealed_text']}")
324
 
325
  # If there are more tokens, prompt for next
326
  if ntp_state["next_token_idx"] < len(ntp_state["tokens"]):
 
70
  print(f"Maskable indices count: {len(maskable_indices)}")
71
  print(f"Mask ratio: {mask_ratio}")
72
 
73
+ # Calculate how many tokens to mask based on the mask ratio
74
+ # No arbitrary cap - use the actual percentage
75
+ num_to_mask = max(1, int(len(maskable_indices) * mask_ratio))
76
  print(f"Number of tokens to mask: {num_to_mask}")
77
 
78
  # Randomly select indices to mask
 
256
  full_hidden = original_text[len(masked_text):].strip()
257
 
258
  # Tokenize the hidden part
259
+ hidden_tokens = tokenizer.tokenize(full_hidden)
260
+
261
+ # Print debug info
262
+ print(f"NTP State setup:")
263
+ print(f" Full text: '{original_text}'")
264
+ print(f" Visible text: '{masked_text}'")
265
+ print(f" Hidden text: '{full_hidden}'")
266
+ print(f" Hidden tokens: {hidden_tokens}")
267
+
268
+ # Set up the NTP state
269
+ ntp_state["tokens"] = hidden_tokens
270
  ntp_state["full_text"] = full_hidden
271
  ntp_state["revealed_text"] = ""
272
  ntp_state["next_token_idx"] = 0
273
 
274
  # Make sure we have tokens to predict
275
  if not ntp_state["tokens"]:
276
+ print("Warning: No tokens to predict, will try another sample")
277
+ # If we don't have tokens, get a new sample with a higher cut ratio
278
+ new_text = get_new_sample("ntp", 0.4) # Use higher cut ratio
279
  prepare_next_token_prediction()
280
 
281
  def check_ntp_answer(user_continuation):
 
286
  if not ntp_state["tokens"]:
287
  prepare_next_token_prediction()
288
 
289
+ # Print debug info
290
+ print(f"Current NTP state:")
291
+ print(f" Next token index: {ntp_state['next_token_idx']}")
292
+ print(f" Total tokens: {len(ntp_state['tokens'])}")
293
+ print(f" User input: '{user_continuation}'")
294
+
295
  # No more tokens to predict
296
  if ntp_state["next_token_idx"] >= len(ntp_state["tokens"]):
297
  # Reset for next round
 
299
 
300
  # Get the next token to predict
301
  next_token = ntp_state["tokens"][ntp_state["next_token_idx"]]
302
+ print(f" Expected next token: '{next_token}'")
303
 
304
  # Get user's prediction
305
  user_text = user_continuation.strip()
 
307
  # Tokenize user's prediction to get their first token
308
  user_tokens = tokenizer.tokenize(user_text)
309
  user_token = user_tokens[0].lower() if user_tokens else ""
310
+ print(f" User's tokenized input: {user_tokens}")
311
 
312
  # Clean up tokens for comparison
313
  next_token_clean = next_token.lower()
 
319
 
320
  # Check if correct
321
  is_correct = (user_token == next_token_clean)
322
+ print(f" Comparison: '{user_token}' vs '{next_token_clean}' -> {'Correct' if is_correct else 'Incorrect'}")
323
 
324
  # Update stats
325
  if is_correct:
 
327
  user_stats["ntp"]["total"] += 1
328
 
329
  # Reveal this token and prepare for next
330
+ ntp_state["revealed_text"] += tokenizer.convert_tokens_to_string([next_token])
331
  ntp_state["next_token_idx"] += 1
332
 
333
  # Calculate overall accuracy
 
340
  feedback.append(f"✗ Not quite. The actual next token was '{next_token_clean}'")
341
 
342
  # Show progress
343
+ feedback.append(f"\nText so far: {masked_text}{ntp_state['revealed_text']}")
344
 
345
  # If there are more tokens, prompt for next
346
  if ntp_state["next_token_idx"] < len(ntp_state["tokens"]):