Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -6,7 +6,7 @@ from datasets import load_dataset
|
|
6 |
from transformers import AutoTokenizer
|
7 |
|
8 |
# Load tokenizer
|
9 |
-
tokenizer = AutoTokenizer.from_pretrained("
|
10 |
|
11 |
# Initialize variables to track stats
|
12 |
user_stats = {
|
@@ -28,8 +28,13 @@ def load_sample_data(sample_size=100):
|
|
28 |
# Clean text by removing extra whitespaces
|
29 |
text = re.sub(r'\s+', ' ', example["text"]).strip()
|
30 |
# Only include longer texts to make the task meaningful
|
31 |
-
if len(text.split()) >
|
32 |
-
|
|
|
|
|
|
|
|
|
|
|
33 |
|
34 |
return samples
|
35 |
|
@@ -81,8 +86,14 @@ def prepare_ntp_sample(text, cut_ratio=0.3):
|
|
81 |
# Tokenize text to ensure reasonable cutting
|
82 |
tokens = tokenizer.tokenize(text)
|
83 |
|
|
|
|
|
|
|
|
|
84 |
# Calculate cutoff point (70% of tokens if cut_ratio is 0.3)
|
85 |
-
|
|
|
|
|
86 |
|
87 |
# Get the visible part
|
88 |
visible_tokens = tokens[:cutoff]
|
@@ -98,7 +109,7 @@ def prepare_ntp_sample(text, cut_ratio=0.3):
|
|
98 |
|
99 |
def get_new_sample(task, mask_ratio=0.15):
|
100 |
"""Get a new text sample based on the task."""
|
101 |
-
global current_sample, masked_text, masked_indices, masked_tokens, original_text
|
102 |
|
103 |
# Select a random sample
|
104 |
current_sample = random.choice(data_samples)
|
@@ -113,6 +124,18 @@ def get_new_sample(task, mask_ratio=0.15):
|
|
113 |
# Store original and visible for comparison
|
114 |
original_text = current_sample
|
115 |
masked_text = visible_text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
return visible_text
|
117 |
|
118 |
def check_mlm_answer(user_answers):
|
@@ -159,52 +182,95 @@ def check_mlm_answer(user_answers):
|
|
159 |
|
160 |
return "\n".join(feedback)
|
161 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
162 |
def check_ntp_answer(user_continuation):
|
163 |
-
"""Check user NTP answer
|
164 |
-
global user_stats,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
|
166 |
-
# Get the
|
167 |
-
|
|
|
|
|
168 |
user_text = user_continuation.strip()
|
169 |
|
170 |
-
# Tokenize
|
171 |
-
hidden_tokens = tokenizer.tokenize(hidden_text)
|
172 |
user_tokens = tokenizer.tokenize(user_text)
|
|
|
173 |
|
174 |
-
#
|
175 |
-
|
176 |
-
if
|
177 |
-
|
178 |
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
# Remove ## from subword tokens
|
185 |
-
if hidden_token.startswith("##"):
|
186 |
-
hidden_token = hidden_token[2:]
|
187 |
-
if user_token.startswith("##"):
|
188 |
-
user_token = user_token[2:]
|
189 |
-
|
190 |
-
if user_token == hidden_token:
|
191 |
-
correct += 1
|
192 |
|
193 |
# Update stats
|
194 |
-
|
195 |
-
|
|
|
196 |
|
197 |
-
#
|
198 |
-
|
199 |
-
|
200 |
|
201 |
-
|
|
|
202 |
|
203 |
-
|
204 |
-
|
|
|
|
|
|
|
205 |
|
206 |
-
#
|
207 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
208 |
feedback.append(f"\nOverall NTP Accuracy: {user_stats['ntp']['correct']}/{user_stats['ntp']['total']} ({overall_accuracy*100:.1f}%)")
|
209 |
|
210 |
return "\n".join(feedback)
|
@@ -279,9 +345,9 @@ with gr.Blocks(title="MLM and NTP Testing") as demo:
|
|
279 |
|
280 |
with gr.Group(visible=False) as ntp_group:
|
281 |
ntp_answer = gr.Textbox(
|
282 |
-
label="Your
|
283 |
-
placeholder="Predict
|
284 |
-
lines=
|
285 |
)
|
286 |
|
287 |
with gr.Row():
|
|
|
6 |
from transformers import AutoTokenizer
|
7 |
|
8 |
# Load tokenizer
|
9 |
+
tokenizer = AutoTokenizer.from_pretrained("answerdotai/modernbert-base")
|
10 |
|
11 |
# Initialize variables to track stats
|
12 |
user_stats = {
|
|
|
28 |
# Clean text by removing extra whitespaces
|
29 |
text = re.sub(r'\s+', ' ', example["text"]).strip()
|
30 |
# Only include longer texts to make the task meaningful
|
31 |
+
if len(text.split()) > 20:
|
32 |
+
# Truncate to two sentences
|
33 |
+
sentences = re.split(r'(?<=[.!?])\s+', text)
|
34 |
+
if len(sentences) >= 2:
|
35 |
+
# Take only the first two sentences
|
36 |
+
two_sentence_text = ' '.join(sentences[:2])
|
37 |
+
samples.append(two_sentence_text)
|
38 |
|
39 |
return samples
|
40 |
|
|
|
86 |
# Tokenize text to ensure reasonable cutting
|
87 |
tokens = tokenizer.tokenize(text)
|
88 |
|
89 |
+
# Ensure we have enough tokens
|
90 |
+
if len(tokens) < 5:
|
91 |
+
return text, "" # Return original if too short
|
92 |
+
|
93 |
# Calculate cutoff point (70% of tokens if cut_ratio is 0.3)
|
94 |
+
# But make sure we have at least 3 tokens visible and 1 token hidden
|
95 |
+
cutoff = max(3, int(len(tokens) * (1 - cut_ratio)))
|
96 |
+
cutoff = min(cutoff, len(tokens) - 1) # Ensure there's at least 1 token to predict
|
97 |
|
98 |
# Get the visible part
|
99 |
visible_tokens = tokens[:cutoff]
|
|
|
109 |
|
110 |
def get_new_sample(task, mask_ratio=0.15):
|
111 |
"""Get a new text sample based on the task."""
|
112 |
+
global current_sample, masked_text, masked_indices, masked_tokens, original_text, ntp_state
|
113 |
|
114 |
# Select a random sample
|
115 |
current_sample = random.choice(data_samples)
|
|
|
124 |
# Store original and visible for comparison
|
125 |
original_text = current_sample
|
126 |
masked_text = visible_text
|
127 |
+
|
128 |
+
# Reset NTP state for new iteration
|
129 |
+
ntp_state = {
|
130 |
+
"full_text": "",
|
131 |
+
"revealed_text": "",
|
132 |
+
"next_token_idx": 0,
|
133 |
+
"tokens": []
|
134 |
+
}
|
135 |
+
|
136 |
+
# Prepare for token-by-token prediction
|
137 |
+
prepare_next_token_prediction()
|
138 |
+
|
139 |
return visible_text
|
140 |
|
141 |
def check_mlm_answer(user_answers):
|
|
|
182 |
|
183 |
return "\n".join(feedback)
|
184 |
|
185 |
+
# Variable to store NTP state
|
186 |
+
ntp_state = {
|
187 |
+
"full_text": "",
|
188 |
+
"revealed_text": "",
|
189 |
+
"next_token_idx": 0,
|
190 |
+
"tokens": []
|
191 |
+
}
|
192 |
+
|
193 |
+
def prepare_next_token_prediction():
|
194 |
+
"""Prepare for the next token prediction."""
|
195 |
+
global ntp_state, masked_text, original_text
|
196 |
+
|
197 |
+
# Get the hidden part
|
198 |
+
full_hidden = original_text[len(masked_text):].strip()
|
199 |
+
|
200 |
+
# Tokenize the hidden part
|
201 |
+
ntp_state["tokens"] = tokenizer.tokenize(full_hidden)
|
202 |
+
ntp_state["full_text"] = full_hidden
|
203 |
+
ntp_state["revealed_text"] = ""
|
204 |
+
ntp_state["next_token_idx"] = 0
|
205 |
+
|
206 |
+
# Make sure we have tokens to predict
|
207 |
+
if not ntp_state["tokens"]:
|
208 |
+
# If we don't have tokens, get a new sample
|
209 |
+
new_text = get_new_sample("ntp", 0.3)
|
210 |
+
prepare_next_token_prediction()
|
211 |
+
|
212 |
def check_ntp_answer(user_continuation):
|
213 |
+
"""Check user NTP answer for the next token only."""
|
214 |
+
global user_stats, ntp_state, masked_text
|
215 |
+
|
216 |
+
# If we haven't set up NTP state yet, do it now
|
217 |
+
if not ntp_state["tokens"]:
|
218 |
+
prepare_next_token_prediction()
|
219 |
+
|
220 |
+
# No more tokens to predict
|
221 |
+
if ntp_state["next_token_idx"] >= len(ntp_state["tokens"]):
|
222 |
+
# Reset for next round
|
223 |
+
return "You've completed this prediction! Click 'New Sample' for another."
|
224 |
|
225 |
+
# Get the next token to predict
|
226 |
+
next_token = ntp_state["tokens"][ntp_state["next_token_idx"]]
|
227 |
+
|
228 |
+
# Get user's prediction
|
229 |
user_text = user_continuation.strip()
|
230 |
|
231 |
+
# Tokenize user's prediction to get their first token
|
|
|
232 |
user_tokens = tokenizer.tokenize(user_text)
|
233 |
+
user_token = user_tokens[0].lower() if user_tokens else ""
|
234 |
|
235 |
+
# Clean up tokens for comparison
|
236 |
+
next_token_clean = next_token.lower()
|
237 |
+
if next_token_clean.startswith("##"):
|
238 |
+
next_token_clean = next_token_clean[2:]
|
239 |
|
240 |
+
if user_token.startswith("##"):
|
241 |
+
user_token = user_token[2:]
|
242 |
+
|
243 |
+
# Check if correct
|
244 |
+
is_correct = (user_token == next_token_clean)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
245 |
|
246 |
# Update stats
|
247 |
+
if is_correct:
|
248 |
+
user_stats["ntp"]["correct"] += 1
|
249 |
+
user_stats["ntp"]["total"] += 1
|
250 |
|
251 |
+
# Reveal this token and prepare for next
|
252 |
+
ntp_state["revealed_text"] += " " + tokenizer.convert_tokens_to_string([next_token])
|
253 |
+
ntp_state["next_token_idx"] += 1
|
254 |
|
255 |
+
# Calculate overall accuracy
|
256 |
+
overall_accuracy = user_stats["ntp"]["correct"] / user_stats["ntp"]["total"] if user_stats["ntp"]["total"] > 0 else 0
|
257 |
|
258 |
+
feedback = []
|
259 |
+
if is_correct:
|
260 |
+
feedback.append(f"✓ Correct! The next token was indeed '{next_token_clean}'")
|
261 |
+
else:
|
262 |
+
feedback.append(f"✗ Not quite. The actual next token was '{next_token_clean}'")
|
263 |
|
264 |
+
# Show progress
|
265 |
+
feedback.append(f"\nRevealed so far: {masked_text}{ntp_state['revealed_text']}")
|
266 |
+
|
267 |
+
# If there are more tokens, prompt for next
|
268 |
+
if ntp_state["next_token_idx"] < len(ntp_state["tokens"]):
|
269 |
+
feedback.append(f"\nPredict the next token...")
|
270 |
+
else:
|
271 |
+
feedback.append(f"\nPrediction complete! Full text was:\n{original_text}")
|
272 |
+
|
273 |
+
# Show overall stats
|
274 |
feedback.append(f"\nOverall NTP Accuracy: {user_stats['ntp']['correct']}/{user_stats['ntp']['total']} ({overall_accuracy*100:.1f}%)")
|
275 |
|
276 |
return "\n".join(feedback)
|
|
|
345 |
|
346 |
with gr.Group(visible=False) as ntp_group:
|
347 |
ntp_answer = gr.Textbox(
|
348 |
+
label="Your Next Token Prediction",
|
349 |
+
placeholder="Predict the next token/word...",
|
350 |
+
lines=1
|
351 |
)
|
352 |
|
353 |
with gr.Row():
|