orionweller commited on
Commit
bbed6df
·
verified ·
1 Parent(s): 8f1d1e1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -15
app.py CHANGED
@@ -65,8 +65,8 @@ def prepare_mlm_sample(text, mask_ratio=0.15):
65
  if not token.startswith("##") and not token.startswith("[") and not token.endswith("]")
66
  and token not in [".", ",", "!", "?", ";", ":", "'", "\"", "-"]]
67
 
68
- # Calculate how many tokens to mask
69
- num_to_mask = max(1, int(len(maskable_indices) * mask_ratio))
70
  # Randomly select indices to mask
71
  indices_to_mask = random.sample(maskable_indices, min(num_to_mask, len(maskable_indices)))
72
 
@@ -87,6 +87,11 @@ def prepare_mlm_sample(text, mask_ratio=0.15):
87
  # Convert back to text with masks
88
  masked_text = tokenizer.convert_tokens_to_string(masked_tokens_list)
89
 
 
 
 
 
 
90
  return masked_text, indices_to_mask, original_tokens
91
 
92
  def prepare_ntp_sample(text, cut_ratio=0.3):
@@ -150,18 +155,33 @@ def check_mlm_answer(user_answers):
150
  """Check user MLM answers against the masked tokens."""
151
  global user_stats
152
 
153
- # Improved parsing of user answers to better handle different formats
154
- # First replace any whitespace around commas with just commas
155
- cleaned_answers = re.sub(r'\s*,\s*', ',', user_answers.strip())
156
- # Then split by comma or whitespace
157
- user_tokens = []
158
- for token in re.split(r',|\s+', cleaned_answers):
159
- if token: # Only add non-empty tokens
160
- user_tokens.append(token.strip().lower())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
 
162
  # Ensure we have the same number of answers as masks
163
  if len(user_tokens) != len(masked_tokens):
164
- return f"Please provide {len(masked_tokens)} answers. You provided {len(user_tokens)}.\nFormat: word1, word2, word3"
165
 
166
  # Compare each answer
167
  correct = 0
@@ -338,6 +358,9 @@ with gr.Blocks(title="MLM and NTP Testing") as demo:
338
  info="Percentage of tokens to mask (MLM) or text to hide (NTP)"
339
  )
340
 
 
 
 
341
  sample_text = gr.Textbox(
342
  label="Text Sample",
343
  placeholder="Click 'New Sample' to get started",
@@ -351,12 +374,20 @@ with gr.Blocks(title="MLM and NTP Testing") as demo:
351
  reset_button = gr.Button("Reset Stats")
352
 
353
  with gr.Group() as mlm_group:
 
 
 
 
 
 
 
 
 
354
  mlm_answer = gr.Textbox(
355
- label="Your MLM answers (separated by commas)",
356
- placeholder="word1, word2, word3, etc.",
357
  lines=1
358
  )
359
- gr.Markdown("**Example input format:** finding, its, phishing, in, links, 49, and, it")
360
 
361
  with gr.Group(visible=False) as ntp_group:
362
  ntp_answer = gr.Textbox(
@@ -372,7 +403,27 @@ with gr.Blocks(title="MLM and NTP Testing") as demo:
372
 
373
  # Set up event handlers
374
  task_radio.change(switch_task, inputs=[task_radio], outputs=[mlm_group, ntp_group])
375
- new_button.click(generate_new_sample, inputs=[mask_ratio], outputs=[sample_text, result])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
376
  reset_button.click(reset_stats, inputs=None, outputs=[result])
377
 
378
  check_button.click(
 
65
  if not token.startswith("##") and not token.startswith("[") and not token.endswith("]")
66
  and token not in [".", ",", "!", "?", ";", ":", "'", "\"", "-"]]
67
 
68
+ # Calculate how many tokens to mask, but ensure at least 1 and at most 8
69
+ num_to_mask = max(1, min(8, int(len(maskable_indices) * mask_ratio)))
70
  # Randomly select indices to mask
71
  indices_to_mask = random.sample(maskable_indices, min(num_to_mask, len(maskable_indices)))
72
 
 
87
  # Convert back to text with masks
88
  masked_text = tokenizer.convert_tokens_to_string(masked_tokens_list)
89
 
90
+ # Print debugging info
91
+ print(f"Original tokens: {original_tokens}")
92
+ print(f"Masked indices: {indices_to_mask}")
93
+ print(f"Number of masks: {len(original_tokens)}")
94
+
95
  return masked_text, indices_to_mask, original_tokens
96
 
97
  def prepare_ntp_sample(text, cut_ratio=0.3):
 
155
  """Check user MLM answers against the masked tokens."""
156
  global user_stats
157
 
158
+ # Print for debugging
159
+ print(f"Original user input: '{user_answers}'")
160
+
161
+ # Handle the case where input is empty
162
+ if not user_answers or user_answers.isspace():
163
+ return "Please provide your answers. No input was detected."
164
+
165
+ # Basic cleanup - trim and lowercase
166
+ user_answers = user_answers.strip().lower()
167
+ print(f"After basic cleanup: '{user_answers}'")
168
+
169
+ # Explicit comma-based splitting with protection for empty entries
170
+ if ',' in user_answers:
171
+ # Split by commas and strip each item
172
+ user_tokens = [token.strip() for token in user_answers.split(',')]
173
+ # Filter out empty tokens
174
+ user_tokens = [token for token in user_tokens if token]
175
+ else:
176
+ # If no commas, split by whitespace
177
+ user_tokens = [token for token in user_answers.split() if token]
178
+
179
+ print(f"Parsed tokens: {user_tokens}, count: {len(user_tokens)}")
180
+ print(f"Expected tokens: {masked_tokens}, count: {len(masked_tokens)}")
181
 
182
  # Ensure we have the same number of answers as masks
183
  if len(user_tokens) != len(masked_tokens):
184
+ return f"Please provide exactly {len(masked_tokens)} answers (one for each [MASK]). You provided {len(user_tokens)}.\n\nFormat example: word1, word2, word3"
185
 
186
  # Compare each answer
187
  correct = 0
 
358
  info="Percentage of tokens to mask (MLM) or text to hide (NTP)"
359
  )
360
 
361
+ # Count the visible [MASK] tokens for user reference
362
+ mask_count = gr.Markdown("**Number of [MASK] tokens to guess: 0**")
363
+
364
  sample_text = gr.Textbox(
365
  label="Text Sample",
366
  placeholder="Click 'New Sample' to get started",
 
374
  reset_button = gr.Button("Reset Stats")
375
 
376
  with gr.Group() as mlm_group:
377
+ mlm_instructions = gr.Markdown("""
378
+ ### MLM Instructions
379
+ 1. For each [MASK] token, provide your guess for the original word.
380
+ 2. Separate your answers with commas.
381
+ 3. Make sure you provide exactly the same number of answers as [MASK] tokens.
382
+
383
+ **Example format:** `word1, word2, word3` or `word1,word2,word3`
384
+ """)
385
+
386
  mlm_answer = gr.Textbox(
387
+ label="Your answers (comma-separated)",
388
+ placeholder="word1, word2, word3",
389
  lines=1
390
  )
 
391
 
392
  with gr.Group(visible=False) as ntp_group:
393
  ntp_answer = gr.Textbox(
 
403
 
404
  # Set up event handlers
405
  task_radio.change(switch_task, inputs=[task_radio], outputs=[mlm_group, ntp_group])
406
+
407
+ # Update the sample text and also update the mask count
408
+ def new_sample_with_count(mask_ratio_pct, task):
409
+ ratio = float(mask_ratio_pct) / 100.0
410
+ sample = get_new_sample(task, ratio)
411
+ mask_count_text = ""
412
+
413
+ if task == "mlm":
414
+ count = len(masked_tokens)
415
+ mask_count_text = f"**Number of [MASK] tokens to guess: {count}**"
416
+ else:
417
+ mask_count_text = "**Next Token Prediction mode - guess one token at a time**"
418
+
419
+ return sample, mask_count_text, ""
420
+
421
+ new_button.click(
422
+ new_sample_with_count,
423
+ inputs=[mask_ratio, task_radio],
424
+ outputs=[sample_text, mask_count, result]
425
+ )
426
+
427
  reset_button.click(reset_stats, inputs=None, outputs=[result])
428
 
429
  check_button.click(