orionweller committed (verified)
Commit 59fd051 · Parent: d1414a2

Update app.py

Files changed (1):
  1. app.py +107 -41
app.py CHANGED
@@ -6,7 +6,7 @@ from datasets import load_dataset
 from transformers import AutoTokenizer
 
 # Load tokenizer
-tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+tokenizer = AutoTokenizer.from_pretrained("answerdotai/modernbert-base")
 
 # Initialize variables to track stats
 user_stats = {
@@ -28,8 +28,13 @@ def load_sample_data(sample_size=100):
         # Clean text by removing extra whitespaces
         text = re.sub(r'\s+', ' ', example["text"]).strip()
         # Only include longer texts to make the task meaningful
-        if len(text.split()) > 50:
-            samples.append(text)
+        if len(text.split()) > 20:
+            # Truncate to two sentences
+            sentences = re.split(r'(?<=[.!?])\s+', text)
+            if len(sentences) >= 2:
+                # Take only the first two sentences
+                two_sentence_text = ' '.join(sentences[:2])
+                samples.append(two_sentence_text)
 
     return samples
 
@@ -81,8 +86,14 @@ def prepare_ntp_sample(text, cut_ratio=0.3):
     # Tokenize text to ensure reasonable cutting
     tokens = tokenizer.tokenize(text)
 
+    # Ensure we have enough tokens
+    if len(tokens) < 5:
+        return text, ""  # Return original if too short
+
     # Calculate cutoff point (70% of tokens if cut_ratio is 0.3)
-    cutoff = int(len(tokens) * (1 - cut_ratio))
+    # But make sure we have at least 3 tokens visible and 1 token hidden
+    cutoff = max(3, int(len(tokens) * (1 - cut_ratio)))
+    cutoff = min(cutoff, len(tokens) - 1)  # Ensure there's at least 1 token to predict
 
     # Get the visible part
     visible_tokens = tokens[:cutoff]
@@ -98,7 +109,7 @@ def prepare_ntp_sample(text, cut_ratio=0.3):
 
 def get_new_sample(task, mask_ratio=0.15):
     """Get a new text sample based on the task."""
-    global current_sample, masked_text, masked_indices, masked_tokens, original_text
+    global current_sample, masked_text, masked_indices, masked_tokens, original_text, ntp_state
 
     # Select a random sample
     current_sample = random.choice(data_samples)
@@ -113,6 +124,18 @@ def get_new_sample(task, mask_ratio=0.15):
     # Store original and visible for comparison
     original_text = current_sample
     masked_text = visible_text
+
+    # Reset NTP state for new iteration
+    ntp_state = {
+        "full_text": "",
+        "revealed_text": "",
+        "next_token_idx": 0,
+        "tokens": []
+    }
+
+    # Prepare for token-by-token prediction
+    prepare_next_token_prediction()
+
     return visible_text
 
 def check_mlm_answer(user_answers):
@@ -159,52 +182,95 @@ def check_mlm_answer(user_answers):
 
     return "\n".join(feedback)
 
+# Variable to store NTP state
+ntp_state = {
+    "full_text": "",
+    "revealed_text": "",
+    "next_token_idx": 0,
+    "tokens": []
+}
+
+def prepare_next_token_prediction():
+    """Prepare for the next token prediction."""
+    global ntp_state, masked_text, original_text
+
+    # Get the hidden part
+    full_hidden = original_text[len(masked_text):].strip()
+
+    # Tokenize the hidden part
+    ntp_state["tokens"] = tokenizer.tokenize(full_hidden)
+    ntp_state["full_text"] = full_hidden
+    ntp_state["revealed_text"] = ""
+    ntp_state["next_token_idx"] = 0
+
+    # Make sure we have tokens to predict
+    if not ntp_state["tokens"]:
+        # If we don't have tokens, get a new sample
+        new_text = get_new_sample("ntp", 0.3)
+        prepare_next_token_prediction()
+
 def check_ntp_answer(user_continuation):
-    """Check user NTP answer against the original text."""
-    global user_stats, original_text, masked_text
+    """Check user NTP answer for the next token only."""
+    global user_stats, ntp_state, masked_text
+
+    # If we haven't set up NTP state yet, do it now
+    if not ntp_state["tokens"]:
+        prepare_next_token_prediction()
+
+    # No more tokens to predict
+    if ntp_state["next_token_idx"] >= len(ntp_state["tokens"]):
+        # Reset for next round
+        return "You've completed this prediction! Click 'New Sample' for another."
 
-    # Get the hidden part of the original text
-    hidden_text = original_text[len(masked_text):].strip()
+    # Get the next token to predict
+    next_token = ntp_state["tokens"][ntp_state["next_token_idx"]]
+
+    # Get user's prediction
     user_text = user_continuation.strip()
 
-    # Tokenize for better comparison
-    hidden_tokens = tokenizer.tokenize(hidden_text)
+    # Tokenize user's prediction to get their first token
     user_tokens = tokenizer.tokenize(user_text)
+    user_token = user_tokens[0].lower() if user_tokens else ""
 
-    # Calculate overlap using first few tokens (more lenient)
-    max_compare = min(10, len(hidden_tokens), len(user_tokens))
-    if max_compare == 0:
-        return "Error: No hidden tokens to compare with."
-
-    correct = 0
-    for i in range(max_compare):
-        hidden_token = hidden_tokens[i].lower()
-        user_token = user_tokens[i].lower() if i < len(user_tokens) else ""
-
-        # Remove ## from subword tokens
-        if hidden_token.startswith("##"):
-            hidden_token = hidden_token[2:]
-        if user_token.startswith("##"):
-            user_token = user_token[2:]
-
-        if user_token == hidden_token:
-            correct += 1
+    # Clean up tokens for comparison
+    next_token_clean = next_token.lower()
+    if next_token_clean.startswith("##"):
+        next_token_clean = next_token_clean[2:]
+
+    if user_token.startswith("##"):
+        user_token = user_token[2:]
+
+    # Check if correct
+    is_correct = (user_token == next_token_clean)
 
     # Update stats
-    user_stats["ntp"]["correct"] += correct
-    user_stats["ntp"]["total"] += max_compare
+    if is_correct:
+        user_stats["ntp"]["correct"] += 1
+    user_stats["ntp"]["total"] += 1
 
-    # Calculate accuracy
-    accuracy = correct / max_compare
-    accuracy_percentage = accuracy * 100
+    # Reveal this token and prepare for next
+    ntp_state["revealed_text"] += " " + tokenizer.convert_tokens_to_string([next_token])
+    ntp_state["next_token_idx"] += 1
 
-    feedback = [f"Your prediction accuracy: {correct}/{max_compare} ({accuracy_percentage:.1f}%)"]
+    # Calculate overall accuracy
+    overall_accuracy = user_stats["ntp"]["correct"] / user_stats["ntp"]["total"] if user_stats["ntp"]["total"] > 0 else 0
 
-    # Show original continuation
-    feedback.append(f"\nActual continuation:\n{hidden_text}")
+    feedback = []
+    if is_correct:
+        feedback.append(f"✓ Correct! The next token was indeed '{next_token_clean}'")
+    else:
+        feedback.append(f"✗ Not quite. The actual next token was '{next_token_clean}'")
 
-    # Calculate overall stats
-    overall_accuracy = user_stats["ntp"]["correct"] / user_stats["ntp"]["total"] if user_stats["ntp"]["total"] > 0 else 0
+    # Show progress
+    feedback.append(f"\nRevealed so far: {masked_text}{ntp_state['revealed_text']}")
+
+    # If there are more tokens, prompt for next
+    if ntp_state["next_token_idx"] < len(ntp_state["tokens"]):
+        feedback.append(f"\nPredict the next token...")
+    else:
+        feedback.append(f"\nPrediction complete! Full text was:\n{original_text}")
+
+    # Show overall stats
     feedback.append(f"\nOverall NTP Accuracy: {user_stats['ntp']['correct']}/{user_stats['ntp']['total']} ({overall_accuracy*100:.1f}%)")
 
     return "\n".join(feedback)
@@ -279,9 +345,9 @@ with gr.Blocks(title="MLM and NTP Testing") as demo:
 
     with gr.Group(visible=False) as ntp_group:
         ntp_answer = gr.Textbox(
-            label="Your NTP continuation",
-            placeholder="Predict how the text continues...",
-            lines=3
+            label="Your Next Token Prediction",
+            placeholder="Predict the next token/word...",
+            lines=1
         )
 
     with gr.Row():
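A note on the tokenizer swap: bert-base-uncased uses WordPiece, which marks subword continuations with "##", while answerdotai/modernbert-base ships a byte-level BPE tokenizer that marks word starts with "Ġ". The "##"-stripping kept in check_ntp_answer therefore only covers the old convention. A minimal sketch of marker-agnostic normalization; normalize_token is a hypothetical helper, not part of app.py:

    # Hypothetical helper: compare subword tokens regardless of whether the
    # tokenizer uses WordPiece ("##bel") or byte-level BPE ("Ġworld").
    def normalize_token(token: str) -> str:
        if token.startswith("##"):       # WordPiece continuation marker
            token = token[2:]
        token = token.lstrip("\u0120")   # "Ġ" (U+0120), the BPE leading-space marker
        return token.lower().strip()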
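The two-sentence truncation in load_sample_data relies on a lookbehind split. A quick check of its behavior; note that the pattern is a heuristic and also fires after abbreviations:

    import re

    text = "Dr. Smith arrived. The meeting began! Everyone sat down."
    sentences = re.split(r'(?<=[.!?])\s+', text)
    # ['Dr.', 'Smith arrived.', 'The meeting began!', 'Everyone sat down.']
    print(' '.join(sentences[:2]))  # 'Dr. Smith arrived.'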
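Worked through with cut_ratio=0.3, the new clamping in prepare_ntp_sample keeps at least 3 tokens visible and at least 1 hidden; texts under 5 tokens are returned unchanged by the early exit. A small self-check mirroring that arithmetic:

    def clamped_cutoff(n_tokens: int, cut_ratio: float = 0.3) -> int:
        # Mirrors the clamping in prepare_ntp_sample
        cutoff = max(3, int(n_tokens * (1 - cut_ratio)))
        return min(cutoff, n_tokens - 1)

    assert clamped_cutoff(100) == 70  # plain 70/30 split
    assert clamped_cutoff(20) == 14   # 14 visible, 6 hidden
    assert clamped_cutoff(5) == 3     # the 3-token visible floor kicks in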