Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -65,8 +65,8 @@ def prepare_mlm_sample(text, mask_ratio=0.15):
|
|
65 |
if not token.startswith("##") and not token.startswith("[") and not token.endswith("]")
|
66 |
and token not in [".", ",", "!", "?", ";", ":", "'", "\"", "-"]]
|
67 |
|
68 |
-
# Calculate how many tokens to mask
|
69 |
-
num_to_mask = max(1, int(len(maskable_indices) * mask_ratio))
|
70 |
# Randomly select indices to mask
|
71 |
indices_to_mask = random.sample(maskable_indices, min(num_to_mask, len(maskable_indices)))
|
72 |
|
@@ -87,6 +87,11 @@ def prepare_mlm_sample(text, mask_ratio=0.15):
|
|
87 |
# Convert back to text with masks
|
88 |
masked_text = tokenizer.convert_tokens_to_string(masked_tokens_list)
|
89 |
|
|
|
|
|
|
|
|
|
|
|
90 |
return masked_text, indices_to_mask, original_tokens
|
91 |
|
92 |
def prepare_ntp_sample(text, cut_ratio=0.3):
|
@@ -150,18 +155,33 @@ def check_mlm_answer(user_answers):
|
|
150 |
"""Check user MLM answers against the masked tokens."""
|
151 |
global user_stats
|
152 |
|
153 |
-
#
|
154 |
-
|
155 |
-
|
156 |
-
#
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
161 |
|
162 |
# Ensure we have the same number of answers as masks
|
163 |
if len(user_tokens) != len(masked_tokens):
|
164 |
-
return f"Please provide {len(masked_tokens)} answers. You provided {len(user_tokens)}.\nFormat: word1, word2, word3"
|
165 |
|
166 |
# Compare each answer
|
167 |
correct = 0
|
@@ -338,6 +358,9 @@ with gr.Blocks(title="MLM and NTP Testing") as demo:
|
|
338 |
info="Percentage of tokens to mask (MLM) or text to hide (NTP)"
|
339 |
)
|
340 |
|
|
|
|
|
|
|
341 |
sample_text = gr.Textbox(
|
342 |
label="Text Sample",
|
343 |
placeholder="Click 'New Sample' to get started",
|
@@ -351,12 +374,20 @@ with gr.Blocks(title="MLM and NTP Testing") as demo:
|
|
351 |
reset_button = gr.Button("Reset Stats")
|
352 |
|
353 |
with gr.Group() as mlm_group:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
354 |
mlm_answer = gr.Textbox(
|
355 |
-
label="Your
|
356 |
-
placeholder="word1, word2, word3
|
357 |
lines=1
|
358 |
)
|
359 |
-
gr.Markdown("**Example input format:** finding, its, phishing, in, links, 49, and, it")
|
360 |
|
361 |
with gr.Group(visible=False) as ntp_group:
|
362 |
ntp_answer = gr.Textbox(
|
@@ -372,7 +403,27 @@ with gr.Blocks(title="MLM and NTP Testing") as demo:
|
|
372 |
|
373 |
# Set up event handlers
|
374 |
task_radio.change(switch_task, inputs=[task_radio], outputs=[mlm_group, ntp_group])
|
375 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
376 |
reset_button.click(reset_stats, inputs=None, outputs=[result])
|
377 |
|
378 |
check_button.click(
|
|
|
65 |
if not token.startswith("##") and not token.startswith("[") and not token.endswith("]")
|
66 |
and token not in [".", ",", "!", "?", ";", ":", "'", "\"", "-"]]
|
67 |
|
68 |
+
# Calculate how many tokens to mask, but ensure at least 1 and at most 8
|
69 |
+
num_to_mask = max(1, min(8, int(len(maskable_indices) * mask_ratio)))
|
70 |
# Randomly select indices to mask
|
71 |
indices_to_mask = random.sample(maskable_indices, min(num_to_mask, len(maskable_indices)))
|
72 |
|
|
|
87 |
# Convert back to text with masks
|
88 |
masked_text = tokenizer.convert_tokens_to_string(masked_tokens_list)
|
89 |
|
90 |
+
# Print debugging info
|
91 |
+
print(f"Original tokens: {original_tokens}")
|
92 |
+
print(f"Masked indices: {indices_to_mask}")
|
93 |
+
print(f"Number of masks: {len(original_tokens)}")
|
94 |
+
|
95 |
return masked_text, indices_to_mask, original_tokens
|
96 |
|
97 |
def prepare_ntp_sample(text, cut_ratio=0.3):
|
|
|
155 |
"""Check user MLM answers against the masked tokens."""
|
156 |
global user_stats
|
157 |
|
158 |
+
# Print for debugging
|
159 |
+
print(f"Original user input: '{user_answers}'")
|
160 |
+
|
161 |
+
# Handle the case where input is empty
|
162 |
+
if not user_answers or user_answers.isspace():
|
163 |
+
return "Please provide your answers. No input was detected."
|
164 |
+
|
165 |
+
# Basic cleanup - trim and lowercase
|
166 |
+
user_answers = user_answers.strip().lower()
|
167 |
+
print(f"After basic cleanup: '{user_answers}'")
|
168 |
+
|
169 |
+
# Explicit comma-based splitting with protection for empty entries
|
170 |
+
if ',' in user_answers:
|
171 |
+
# Split by commas and strip each item
|
172 |
+
user_tokens = [token.strip() for token in user_answers.split(',')]
|
173 |
+
# Filter out empty tokens
|
174 |
+
user_tokens = [token for token in user_tokens if token]
|
175 |
+
else:
|
176 |
+
# If no commas, split by whitespace
|
177 |
+
user_tokens = [token for token in user_answers.split() if token]
|
178 |
+
|
179 |
+
print(f"Parsed tokens: {user_tokens}, count: {len(user_tokens)}")
|
180 |
+
print(f"Expected tokens: {masked_tokens}, count: {len(masked_tokens)}")
|
181 |
|
182 |
# Ensure we have the same number of answers as masks
|
183 |
if len(user_tokens) != len(masked_tokens):
|
184 |
+
return f"Please provide exactly {len(masked_tokens)} answers (one for each [MASK]). You provided {len(user_tokens)}.\n\nFormat example: word1, word2, word3"
|
185 |
|
186 |
# Compare each answer
|
187 |
correct = 0
|
|
|
358 |
info="Percentage of tokens to mask (MLM) or text to hide (NTP)"
|
359 |
)
|
360 |
|
361 |
+
# Count the visible [MASK] tokens for user reference
|
362 |
+
mask_count = gr.Markdown("**Number of [MASK] tokens to guess: 0**")
|
363 |
+
|
364 |
sample_text = gr.Textbox(
|
365 |
label="Text Sample",
|
366 |
placeholder="Click 'New Sample' to get started",
|
|
|
374 |
reset_button = gr.Button("Reset Stats")
|
375 |
|
376 |
with gr.Group() as mlm_group:
|
377 |
+
mlm_instructions = gr.Markdown("""
|
378 |
+
### MLM Instructions
|
379 |
+
1. For each [MASK] token, provide your guess for the original word.
|
380 |
+
2. Separate your answers with commas.
|
381 |
+
3. Make sure you provide exactly the same number of answers as [MASK] tokens.
|
382 |
+
|
383 |
+
**Example format:** `word1, word2, word3` or `word1,word2,word3`
|
384 |
+
""")
|
385 |
+
|
386 |
mlm_answer = gr.Textbox(
|
387 |
+
label="Your answers (comma-separated)",
|
388 |
+
placeholder="word1, word2, word3",
|
389 |
lines=1
|
390 |
)
|
|
|
391 |
|
392 |
with gr.Group(visible=False) as ntp_group:
|
393 |
ntp_answer = gr.Textbox(
|
|
|
403 |
|
404 |
# Set up event handlers
|
405 |
task_radio.change(switch_task, inputs=[task_radio], outputs=[mlm_group, ntp_group])
|
406 |
+
|
407 |
+
# Update the sample text and also update the mask count
|
408 |
+
def new_sample_with_count(mask_ratio_pct, task):
|
409 |
+
ratio = float(mask_ratio_pct) / 100.0
|
410 |
+
sample = get_new_sample(task, ratio)
|
411 |
+
mask_count_text = ""
|
412 |
+
|
413 |
+
if task == "mlm":
|
414 |
+
count = len(masked_tokens)
|
415 |
+
mask_count_text = f"**Number of [MASK] tokens to guess: {count}**"
|
416 |
+
else:
|
417 |
+
mask_count_text = "**Next Token Prediction mode - guess one token at a time**"
|
418 |
+
|
419 |
+
return sample, mask_count_text, ""
|
420 |
+
|
421 |
+
new_button.click(
|
422 |
+
new_sample_with_count,
|
423 |
+
inputs=[mask_ratio, task_radio],
|
424 |
+
outputs=[sample_text, mask_count, result]
|
425 |
+
)
|
426 |
+
|
427 |
reset_button.click(reset_stats, inputs=None, outputs=[result])
|
428 |
|
429 |
check_button.click(
|