Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -69,6 +69,8 @@ def prepare_mlm_sample(text, mask_ratio=0.15):
|
|
69 |
num_to_mask = max(1, min(8, int(len(maskable_indices) * mask_ratio)))
|
70 |
# Randomly select indices to mask
|
71 |
indices_to_mask = random.sample(maskable_indices, min(num_to_mask, len(maskable_indices)))
|
|
|
|
|
72 |
|
73 |
# Create a copy of tokens to mask
|
74 |
masked_tokens_list = tokens.copy()
|
@@ -323,6 +325,7 @@ def generate_new_sample(mask_ratio):
|
|
323 |
|
324 |
def check_answer(user_input, task):
|
325 |
"""Check user answer based on current task."""
|
|
|
326 |
if task == "mlm":
|
327 |
return check_mlm_answer(user_input)
|
328 |
else: # NTP
|
@@ -373,7 +376,11 @@ with gr.Blocks(title="MLM and NTP Testing") as demo:
|
|
373 |
new_button = gr.Button("New Sample")
|
374 |
reset_button = gr.Button("Reset Stats")
|
375 |
|
376 |
-
|
|
|
|
|
|
|
|
|
377 |
mlm_instructions = gr.Markdown("""
|
378 |
### MLM Instructions
|
379 |
1. For each [MASK] token, provide your guess for the original word.
|
@@ -381,28 +388,50 @@ with gr.Blocks(title="MLM and NTP Testing") as demo:
|
|
381 |
3. Make sure you provide exactly the same number of answers as [MASK] tokens.
|
382 |
|
383 |
**Example format:** `word1, word2, word3` or `word1,word2,word3`
|
384 |
-
""")
|
385 |
|
386 |
-
|
387 |
-
|
388 |
-
|
389 |
-
|
390 |
-
)
|
391 |
-
|
392 |
-
|
393 |
-
|
394 |
-
label="Your
|
395 |
-
placeholder="
|
396 |
lines=1
|
397 |
)
|
398 |
|
399 |
with gr.Row():
|
400 |
-
check_button = gr.Button("Check Answer")
|
401 |
|
402 |
result = gr.Textbox(label="Result", lines=6)
|
403 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
404 |
# Set up event handlers
|
405 |
-
task_radio.change(
|
|
|
|
|
|
|
|
|
406 |
|
407 |
# Update the sample text and also update the mask count
|
408 |
def new_sample_with_count(mask_ratio_pct, task):
|
@@ -426,16 +455,23 @@ with gr.Blocks(title="MLM and NTP Testing") as demo:
|
|
426 |
|
427 |
reset_button.click(reset_stats, inputs=None, outputs=[result])
|
428 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
429 |
check_button.click(
|
430 |
-
|
431 |
-
inputs=[
|
432 |
-
gr.Textbox(value=lambda: mlm_answer.value if current_task == "mlm" else ntp_answer.value),
|
433 |
-
task_radio
|
434 |
-
],
|
435 |
outputs=[result]
|
436 |
)
|
437 |
|
438 |
-
|
439 |
-
|
|
|
|
|
|
|
440 |
|
441 |
demo.launch()
|
|
|
69 |
num_to_mask = max(1, min(8, int(len(maskable_indices) * mask_ratio)))
|
70 |
# Randomly select indices to mask
|
71 |
indices_to_mask = random.sample(maskable_indices, min(num_to_mask, len(maskable_indices)))
|
72 |
+
# Sort indices to ensure they're in order
|
73 |
+
indices_to_mask.sort()
|
74 |
|
75 |
# Create a copy of tokens to mask
|
76 |
masked_tokens_list = tokens.copy()
|
|
|
325 |
|
326 |
def check_answer(user_input, task):
|
327 |
"""Check user answer based on current task."""
|
328 |
+
# Make the current task visible in UI and more prominent
|
329 |
if task == "mlm":
|
330 |
return check_mlm_answer(user_input)
|
331 |
else: # NTP
|
|
|
376 |
new_button = gr.Button("New Sample")
|
377 |
reset_button = gr.Button("Reset Stats")
|
378 |
|
379 |
+
# Consolidated input area - only one visible at a time
|
380 |
+
input_area = gr.Group()
|
381 |
+
|
382 |
+
with input_area:
|
383 |
+
# Task-specific input instructions
|
384 |
mlm_instructions = gr.Markdown("""
|
385 |
### MLM Instructions
|
386 |
1. For each [MASK] token, provide your guess for the original word.
|
|
|
388 |
3. Make sure you provide exactly the same number of answers as [MASK] tokens.
|
389 |
|
390 |
**Example format:** `word1, word2, word3` or `word1,word2,word3`
|
391 |
+
""", visible=True)
|
392 |
|
393 |
+
ntp_instructions = gr.Markdown("""
|
394 |
+
### NTP Instructions
|
395 |
+
Predict the next word or token that would follow the text.
|
396 |
+
Type a single word or token for each prediction.
|
397 |
+
""", visible=False)
|
398 |
+
|
399 |
+
# Unified input box
|
400 |
+
answer_input = gr.Textbox(
|
401 |
+
label="Your answer",
|
402 |
+
placeholder="For MLM: word1, word2, word3 | For NTP: single word",
|
403 |
lines=1
|
404 |
)
|
405 |
|
406 |
with gr.Row():
|
407 |
+
check_button = gr.Button("Check Answer", variant="primary")
|
408 |
|
409 |
result = gr.Textbox(label="Result", lines=6)
|
410 |
|
411 |
+
# Function to switch task type
|
412 |
+
def switch_task_unified(task):
    """Build the UI updates for the selected task.

    Returns a 4-tuple, in the order of the outputs wired to
    ``task_radio.change``: an update for the MLM instructions, an update
    for the NTP instructions, an update for the shared answer box, and the
    markdown text for the mask-count display.

    Any value of ``task`` other than ``"mlm"`` is treated as NTP.
    """
    showing_mlm = task == "mlm"

    if showing_mlm:
        # masked_tokens comes from outside this function (presumably the
        # current sample's state - confirm); it is only read in MLM mode.
        status_text = f"**Number of [MASK] tokens to guess: {len(masked_tokens)}**"
        hint = "comma-separated answers (e.g., word1, word2, word3)"
    else:
        status_text = "**Next Token Prediction mode - guess one token at a time**"
        hint = "Type the next word/token you predict"

    # Exactly one of the two instruction panels is visible at a time.
    return (
        gr.update(visible=showing_mlm),      # mlm_instructions
        gr.update(visible=not showing_mlm),  # ntp_instructions
        gr.update(placeholder=hint),         # answer_input
        status_text,
    )
|
428 |
+
|
429 |
# Set up event handlers
|
430 |
+
task_radio.change(
|
431 |
+
switch_task_unified,
|
432 |
+
inputs=[task_radio],
|
433 |
+
outputs=[mlm_instructions, ntp_instructions, answer_input, mask_count]
|
434 |
+
)
|
435 |
|
436 |
# Update the sample text and also update the mask count
|
437 |
def new_sample_with_count(mask_ratio_pct, task):
|
|
|
455 |
|
456 |
reset_button.click(reset_stats, inputs=None, outputs=[result])
|
457 |
|
458 |
+
# Unified check answer function
|
459 |
+
def unified_check_answer(user_input, task):
    """Route the submitted answer to the checker for the active task.

    ``"mlm"`` goes to ``check_mlm_answer``; any other task value is treated
    as NTP and goes to ``check_ntp_answer``. The checker's result is
    returned unchanged.
    """
    checker = check_mlm_answer if task == "mlm" else check_ntp_answer
    return checker(user_input)
|
464 |
+
|
465 |
check_button.click(
|
466 |
+
unified_check_answer,
|
467 |
+
inputs=[answer_input, task_radio],
|
|
|
|
|
|
|
468 |
outputs=[result]
|
469 |
)
|
470 |
|
471 |
+
answer_input.submit(
|
472 |
+
unified_check_answer,
|
473 |
+
inputs=[answer_input, task_radio],
|
474 |
+
outputs=[result]
|
475 |
+
)
|
476 |
|
477 |
demo.launch()
|