Update app.py
Browse files
app.py
CHANGED
@@ -70,9 +70,9 @@ def prepare_mlm_sample(text, mask_ratio=0.15):
|
|
70 |
print(f"Maskable indices count: {len(maskable_indices)}")
|
71 |
print(f"Mask ratio: {mask_ratio}")
|
72 |
|
73 |
-
# Calculate how many tokens to mask
|
74 |
-
#
|
75 |
-
num_to_mask = max(1,
|
76 |
print(f"Number of tokens to mask: {num_to_mask}")
|
77 |
|
78 |
# Randomly select indices to mask
|
@@ -256,15 +256,26 @@ def prepare_next_token_prediction():
|
|
256 |
full_hidden = original_text[len(masked_text):].strip()
|
257 |
|
258 |
# Tokenize the hidden part
|
259 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
260 |
ntp_state["full_text"] = full_hidden
|
261 |
ntp_state["revealed_text"] = ""
|
262 |
ntp_state["next_token_idx"] = 0
|
263 |
|
264 |
# Make sure we have tokens to predict
|
265 |
if not ntp_state["tokens"]:
|
266 |
-
|
267 |
-
|
|
|
268 |
prepare_next_token_prediction()
|
269 |
|
270 |
def check_ntp_answer(user_continuation):
|
@@ -275,6 +286,12 @@ def check_ntp_answer(user_continuation):
|
|
275 |
if not ntp_state["tokens"]:
|
276 |
prepare_next_token_prediction()
|
277 |
|
|
|
|
|
|
|
|
|
|
|
|
|
278 |
# No more tokens to predict
|
279 |
if ntp_state["next_token_idx"] >= len(ntp_state["tokens"]):
|
280 |
# Reset for next round
|
@@ -282,6 +299,7 @@ def check_ntp_answer(user_continuation):
|
|
282 |
|
283 |
# Get the next token to predict
|
284 |
next_token = ntp_state["tokens"][ntp_state["next_token_idx"]]
|
|
|
285 |
|
286 |
# Get user's prediction
|
287 |
user_text = user_continuation.strip()
|
@@ -289,6 +307,7 @@ def check_ntp_answer(user_continuation):
|
|
289 |
# Tokenize user's prediction to get their first token
|
290 |
user_tokens = tokenizer.tokenize(user_text)
|
291 |
user_token = user_tokens[0].lower() if user_tokens else ""
|
|
|
292 |
|
293 |
# Clean up tokens for comparison
|
294 |
next_token_clean = next_token.lower()
|
@@ -300,6 +319,7 @@ def check_ntp_answer(user_continuation):
|
|
300 |
|
301 |
# Check if correct
|
302 |
is_correct = (user_token == next_token_clean)
|
|
|
303 |
|
304 |
# Update stats
|
305 |
if is_correct:
|
@@ -307,7 +327,7 @@ def check_ntp_answer(user_continuation):
|
|
307 |
user_stats["ntp"]["total"] += 1
|
308 |
|
309 |
# Reveal this token and prepare for next
|
310 |
-
ntp_state["revealed_text"] +=
|
311 |
ntp_state["next_token_idx"] += 1
|
312 |
|
313 |
# Calculate overall accuracy
|
@@ -320,7 +340,7 @@ def check_ntp_answer(user_continuation):
|
|
320 |
feedback.append(f"✗ Not quite. The actual next token was '{next_token_clean}'")
|
321 |
|
322 |
# Show progress
|
323 |
-
feedback.append(f"\
|
324 |
|
325 |
# If there are more tokens, prompt for next
|
326 |
if ntp_state["next_token_idx"] < len(ntp_state["tokens"]):
|
|
|
70 |
print(f"Maskable indices count: {len(maskable_indices)}")
|
71 |
print(f"Mask ratio: {mask_ratio}")
|
72 |
|
73 |
+
# Calculate how many tokens to mask based on the mask ratio
|
74 |
+
# No arbitrary cap - use the actual percentage
|
75 |
+
num_to_mask = max(1, int(len(maskable_indices) * mask_ratio))
|
76 |
print(f"Number of tokens to mask: {num_to_mask}")
|
77 |
|
78 |
# Randomly select indices to mask
|
|
|
256 |
full_hidden = original_text[len(masked_text):].strip()
|
257 |
|
258 |
# Tokenize the hidden part
|
259 |
+
hidden_tokens = tokenizer.tokenize(full_hidden)
|
260 |
+
|
261 |
+
# Print debug info
|
262 |
+
print(f"NTP State setup:")
|
263 |
+
print(f" Full text: '{original_text}'")
|
264 |
+
print(f" Visible text: '{masked_text}'")
|
265 |
+
print(f" Hidden text: '{full_hidden}'")
|
266 |
+
print(f" Hidden tokens: {hidden_tokens}")
|
267 |
+
|
268 |
+
# Set up the NTP state
|
269 |
+
ntp_state["tokens"] = hidden_tokens
|
270 |
ntp_state["full_text"] = full_hidden
|
271 |
ntp_state["revealed_text"] = ""
|
272 |
ntp_state["next_token_idx"] = 0
|
273 |
|
274 |
# Make sure we have tokens to predict
|
275 |
if not ntp_state["tokens"]:
|
276 |
+
print("Warning: No tokens to predict, will try another sample")
|
277 |
+
# If we don't have tokens, get a new sample with a higher cut ratio
|
278 |
+
new_text = get_new_sample("ntp", 0.4) # Use higher cut ratio
|
279 |
prepare_next_token_prediction()
|
280 |
|
281 |
def check_ntp_answer(user_continuation):
|
|
|
286 |
if not ntp_state["tokens"]:
|
287 |
prepare_next_token_prediction()
|
288 |
|
289 |
+
# Print debug info
|
290 |
+
print(f"Current NTP state:")
|
291 |
+
print(f" Next token index: {ntp_state['next_token_idx']}")
|
292 |
+
print(f" Total tokens: {len(ntp_state['tokens'])}")
|
293 |
+
print(f" User input: '{user_continuation}'")
|
294 |
+
|
295 |
# No more tokens to predict
|
296 |
if ntp_state["next_token_idx"] >= len(ntp_state["tokens"]):
|
297 |
# Reset for next round
|
|
|
299 |
|
300 |
# Get the next token to predict
|
301 |
next_token = ntp_state["tokens"][ntp_state["next_token_idx"]]
|
302 |
+
print(f" Expected next token: '{next_token}'")
|
303 |
|
304 |
# Get user's prediction
|
305 |
user_text = user_continuation.strip()
|
|
|
307 |
# Tokenize user's prediction to get their first token
|
308 |
user_tokens = tokenizer.tokenize(user_text)
|
309 |
user_token = user_tokens[0].lower() if user_tokens else ""
|
310 |
+
print(f" User's tokenized input: {user_tokens}")
|
311 |
|
312 |
# Clean up tokens for comparison
|
313 |
next_token_clean = next_token.lower()
|
|
|
319 |
|
320 |
# Check if correct
|
321 |
is_correct = (user_token == next_token_clean)
|
322 |
+
print(f" Comparison: '{user_token}' vs '{next_token_clean}' -> {'Correct' if is_correct else 'Incorrect'}")
|
323 |
|
324 |
# Update stats
|
325 |
if is_correct:
|
|
|
327 |
user_stats["ntp"]["total"] += 1
|
328 |
|
329 |
# Reveal this token and prepare for next
|
330 |
+
ntp_state["revealed_text"] += tokenizer.convert_tokens_to_string([next_token])
|
331 |
ntp_state["next_token_idx"] += 1
|
332 |
|
333 |
# Calculate overall accuracy
|
|
|
340 |
feedback.append(f"✗ Not quite. The actual next token was '{next_token_clean}'")
|
341 |
|
342 |
# Show progress
|
343 |
+
feedback.append(f"\nText so far: {masked_text}{ntp_state['revealed_text']}")
|
344 |
|
345 |
# If there are more tokens, prompt for next
|
346 |
if ntp_state["next_token_idx"] < len(ntp_state["tokens"]):
|