orionweller commited on
Commit
9b671c4
·
verified ·
1 Parent(s): bbed6df

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -21
app.py CHANGED
@@ -69,6 +69,8 @@ def prepare_mlm_sample(text, mask_ratio=0.15):
69
  num_to_mask = max(1, min(8, int(len(maskable_indices) * mask_ratio)))
70
  # Randomly select indices to mask
71
  indices_to_mask = random.sample(maskable_indices, min(num_to_mask, len(maskable_indices)))
 
 
72
 
73
  # Create a copy of tokens to mask
74
  masked_tokens_list = tokens.copy()
@@ -323,6 +325,7 @@ def generate_new_sample(mask_ratio):
323
 
324
  def check_answer(user_input, task):
325
  """Check user answer based on current task."""
 
326
  if task == "mlm":
327
  return check_mlm_answer(user_input)
328
  else: # NTP
@@ -373,7 +376,11 @@ with gr.Blocks(title="MLM and NTP Testing") as demo:
373
  new_button = gr.Button("New Sample")
374
  reset_button = gr.Button("Reset Stats")
375
 
376
- with gr.Group() as mlm_group:
 
 
 
 
377
  mlm_instructions = gr.Markdown("""
378
  ### MLM Instructions
379
  1. For each [MASK] token, provide your guess for the original word.
@@ -381,28 +388,50 @@ with gr.Blocks(title="MLM and NTP Testing") as demo:
381
  3. Make sure you provide exactly the same number of answers as [MASK] tokens.
382
 
383
  **Example format:** `word1, word2, word3` or `word1,word2,word3`
384
- """)
385
 
386
- mlm_answer = gr.Textbox(
387
- label="Your answers (comma-separated)",
388
- placeholder="word1, word2, word3",
389
- lines=1
390
- )
391
-
392
- with gr.Group(visible=False) as ntp_group:
393
- ntp_answer = gr.Textbox(
394
- label="Your Next Token Prediction",
395
- placeholder="Predict the next token/word...",
396
  lines=1
397
  )
398
 
399
  with gr.Row():
400
- check_button = gr.Button("Check Answer")
401
 
402
  result = gr.Textbox(label="Result", lines=6)
403
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
404
  # Set up event handlers
405
- task_radio.change(switch_task, inputs=[task_radio], outputs=[mlm_group, ntp_group])
 
 
 
 
406
 
407
  # Update the sample text and also update the mask count
408
  def new_sample_with_count(mask_ratio_pct, task):
@@ -426,16 +455,23 @@ with gr.Blocks(title="MLM and NTP Testing") as demo:
426
 
427
  reset_button.click(reset_stats, inputs=None, outputs=[result])
428
 
 
 
 
 
 
 
 
429
  check_button.click(
430
- check_answer,
431
- inputs=[
432
- gr.Textbox(value=lambda: mlm_answer.value if current_task == "mlm" else ntp_answer.value),
433
- task_radio
434
- ],
435
  outputs=[result]
436
  )
437
 
438
- mlm_answer.submit(check_mlm_answer, inputs=[mlm_answer], outputs=[result])
439
- ntp_answer.submit(check_ntp_answer, inputs=[ntp_answer], outputs=[result])
 
 
 
440
 
441
  demo.launch()
 
69
  num_to_mask = max(1, min(8, int(len(maskable_indices) * mask_ratio)))
70
  # Randomly select indices to mask
71
  indices_to_mask = random.sample(maskable_indices, min(num_to_mask, len(maskable_indices)))
72
+ # Sort indices to ensure they're in order
73
+ indices_to_mask.sort()
74
 
75
  # Create a copy of tokens to mask
76
  masked_tokens_list = tokens.copy()
 
325
 
326
  def check_answer(user_input, task):
327
  """Check user answer based on current task."""
328
+ # Make the current task visible in UI and more prominent
329
  if task == "mlm":
330
  return check_mlm_answer(user_input)
331
  else: # NTP
 
376
  new_button = gr.Button("New Sample")
377
  reset_button = gr.Button("Reset Stats")
378
 
379
+ # Consolidated input area - only one visible at a time
380
+ input_area = gr.Group()
381
+
382
+ with input_area:
383
+ # Task-specific input instructions
384
  mlm_instructions = gr.Markdown("""
385
  ### MLM Instructions
386
  1. For each [MASK] token, provide your guess for the original word.
 
388
  3. Make sure you provide exactly the same number of answers as [MASK] tokens.
389
 
390
  **Example format:** `word1, word2, word3` or `word1,word2,word3`
391
+ """, visible=True)
392
 
393
+ ntp_instructions = gr.Markdown("""
394
+ ### NTP Instructions
395
+ Predict the next word or token that would follow the text.
396
+ Type a single word or token for each prediction.
397
+ """, visible=False)
398
+
399
+ # Unified input box
400
+ answer_input = gr.Textbox(
401
+ label="Your answer",
402
+ placeholder="For MLM: word1, word2, word3 | For NTP: single word",
403
  lines=1
404
  )
405
 
406
  with gr.Row():
407
+ check_button = gr.Button("Check Answer", variant="primary")
408
 
409
  result = gr.Textbox(label="Result", lines=6)
410
 
411
+ # Function to switch task type
412
+ def switch_task_unified(task):
413
+ if task == "mlm":
414
+ mask_text = f"**Number of [MASK] tokens to guess: {len(masked_tokens)}**"
415
+ return (
416
+ gr.update(visible=True), # mlm_instructions
417
+ gr.update(visible=False), # ntp_instructions
418
+ gr.update(placeholder="comma-separated answers (e.g., word1, word2, word3)"),
419
+ mask_text
420
+ )
421
+ else: # ntp
422
+ return (
423
+ gr.update(visible=False), # mlm_instructions
424
+ gr.update(visible=True), # ntp_instructions
425
+ gr.update(placeholder="Type the next word/token you predict"),
426
+ "**Next Token Prediction mode - guess one token at a time**"
427
+ )
428
+
429
  # Set up event handlers
430
+ task_radio.change(
431
+ switch_task_unified,
432
+ inputs=[task_radio],
433
+ outputs=[mlm_instructions, ntp_instructions, answer_input, mask_count]
434
+ )
435
 
436
  # Update the sample text and also update the mask count
437
  def new_sample_with_count(mask_ratio_pct, task):
 
455
 
456
  reset_button.click(reset_stats, inputs=None, outputs=[result])
457
 
458
+ # Unified check answer function
459
+ def unified_check_answer(user_input, task):
460
+ if task == "mlm":
461
+ return check_mlm_answer(user_input)
462
+ else: # ntp
463
+ return check_ntp_answer(user_input)
464
+
465
  check_button.click(
466
+ unified_check_answer,
467
+ inputs=[answer_input, task_radio],
 
 
 
468
  outputs=[result]
469
  )
470
 
471
+ answer_input.submit(
472
+ unified_check_answer,
473
+ inputs=[answer_input, task_radio],
474
+ outputs=[result]
475
+ )
476
 
477
  demo.launch()