orionweller committed on
Commit
cbbc299
·
verified ·
1 Parent(s): 59fd051

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -11
app.py CHANGED
@@ -6,7 +6,7 @@ from datasets import load_dataset
6
  from transformers import AutoTokenizer
7
 
8
  # Load tokenizer
9
- tokenizer = AutoTokenizer.from_pretrained("answerdotai/modernbert-base")
10
 
11
  # Initialize variables to track stats
12
  user_stats = {
@@ -14,19 +14,27 @@ user_stats = {
14
  "ntp": {"correct": 0, "total": 0}
15
  }
16
 
17
- # Function to load and sample from cc_news dataset
18
  def load_sample_data(sample_size=100):
19
- dataset = load_dataset("vblagoje/cc_news", streaming=True)
 
 
 
 
 
 
 
 
20
 
21
  # Sample from the dataset
22
  samples = []
23
  for i, example in enumerate(dataset["train"]):
24
  if i >= sample_size:
25
  break
26
- # Only use text field
27
- if "text" in example and example["text"]:
28
  # Clean text by removing extra whitespaces
29
- text = re.sub(r'\s+', ' ', example["text"]).strip()
30
  # Only include longer texts to make the task meaningful
31
  if len(text.split()) > 20:
32
  # Truncate to two sentences
@@ -142,12 +150,18 @@ def check_mlm_answer(user_answers):
142
  """Check user MLM answers against the masked tokens."""
143
  global user_stats
144
 
145
- # Split user answers by spaces or commas
146
- user_tokens = [token.strip().lower() for token in re.split(r'[,\s]+', user_answers)]
 
 
 
 
 
 
147
 
148
  # Ensure we have the same number of answers as masks
149
  if len(user_tokens) != len(masked_tokens):
150
- return f"Please provide {len(masked_tokens)} answers. You provided {len(user_tokens)}."
151
 
152
  # Compare each answer
153
  correct = 0
@@ -338,10 +352,11 @@ with gr.Blocks(title="MLM and NTP Testing") as demo:
338
 
339
  with gr.Group() as mlm_group:
340
  mlm_answer = gr.Textbox(
341
- label="Your MLM answers (separated by spaces or commas)",
342
- placeholder="Type your guesses for the masked words",
343
  lines=1
344
  )
 
345
 
346
  with gr.Group(visible=False) as ntp_group:
347
  ntp_answer = gr.Textbox(
 
6
  from transformers import AutoTokenizer
7
 
8
  # Load tokenizer
9
+ tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
10
 
11
  # Initialize variables to track stats
12
  user_stats = {
 
14
  "ntp": {"correct": 0, "total": 0}
15
  }
16
 
17
+ # Function to load and sample from the requested dataset
18
  def load_sample_data(sample_size=100):
19
+ try:
20
+ # Try to load the requested dataset
21
+ dataset = load_dataset("mlfoundations/dclm-baseline-1.0-parquet", streaming=True)
22
+ dataset_field = "text" # Assuming the field name is "text"
23
+ except Exception as e:
24
+ print(f"Error loading requested dataset: {e}")
25
+ # Fallback to cc_news if there's an issue
26
+ dataset = load_dataset("vblagoje/cc_news", streaming=True)
27
+ dataset_field = "text"
28
 
29
  # Sample from the dataset
30
  samples = []
31
  for i, example in enumerate(dataset["train"]):
32
  if i >= sample_size:
33
  break
34
+ # Get text from the appropriate field
35
+ if dataset_field in example and example[dataset_field]:
36
  # Clean text by removing extra whitespaces
37
+ text = re.sub(r'\s+', ' ', example[dataset_field]).strip()
38
  # Only include longer texts to make the task meaningful
39
  if len(text.split()) > 20:
40
  # Truncate to two sentences
 
150
  """Check user MLM answers against the masked tokens."""
151
  global user_stats
152
 
153
+ # Improved parsing of user answers to better handle different formats
154
+ # First replace any whitespace around commas with just commas
155
+ cleaned_answers = re.sub(r'\s*,\s*', ',', user_answers.strip())
156
+ # Then split by comma or whitespace
157
+ user_tokens = []
158
+ for token in re.split(r',|\s+', cleaned_answers):
159
+ if token: # Only add non-empty tokens
160
+ user_tokens.append(token.strip().lower())
161
 
162
  # Ensure we have the same number of answers as masks
163
  if len(user_tokens) != len(masked_tokens):
164
+ return f"Please provide {len(masked_tokens)} answers. You provided {len(user_tokens)}.\nFormat: word1, word2, word3"
165
 
166
  # Compare each answer
167
  correct = 0
 
352
 
353
  with gr.Group() as mlm_group:
354
  mlm_answer = gr.Textbox(
355
+ label="Your MLM answers (separated by commas)",
356
+ placeholder="word1, word2, word3, etc.",
357
  lines=1
358
  )
359
+ gr.Markdown("**Example input format:** finding, its, phishing, in, links, 49, and, it")
360
 
361
  with gr.Group(visible=False) as ntp_group:
362
  ntp_answer = gr.Textbox(