Update app.py

app.py

import gradio as gr
import random
import re
from datasets import load_dataset
from transformers import AutoTokenizer

# Load the tokenizer used for masking, cutting, and answer comparison
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Running statistics for each task, accumulated across rounds
user_stats = {
    "mlm": {"correct": 0, "total": 0},
    "ntp": {"correct": 0, "total": 0}
}
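
# Note: bert-base-uncased uses WordPiece tokenization, so rarer words are split
# into subword pieces marked with a "##" prefix; the answer checking below
# strips that prefix before comparing. Illustrative only (exact splits depend
# on the vocabulary):
#
#   tokenizer.tokenize("a quick test")   ->  ['a', 'quick', 'test']
#   tokenizer.tokenize("tokenization")   ->  ['token', '##ization']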

# Function to load and sample from the cc_news dataset
def load_sample_data(sample_size=100):
    dataset = load_dataset("vblagoje/cc_news", streaming=True)

    # Sample from the dataset
    samples = []
    for i, example in enumerate(dataset["train"]):
        if i >= sample_size:
            break
        # Only use the text field
        if "text" in example and example["text"]:
            # Clean the text by collapsing extra whitespace
            text = re.sub(r'\s+', ' ', example["text"]).strip()
            # Only include longer texts to make the task meaningful
            if len(text.split()) > 50:
                samples.append(text)

    return samples
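
# Illustrative usage (actual contents depend on the stream):
#
#   samples = load_sample_data(10)
#   len(samples)        # <= 10, since short articles are filtered out
#   samples[0][:60]     # first 60 characters of the first article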

# Load data at startup
data_samples = load_sample_data(100)

# Module-level state shared by the handlers below
current_sample = None
masked_text = ""
original_text = ""
masked_indices = []
masked_tokens = []
ntp_hidden_text = ""  # hidden continuation of the current NTP sample
current_task = "mlm"

def prepare_mlm_sample(text, mask_ratio=0.15):
    """Prepare a text sample for MLM by masking random tokens."""
    global masked_indices, masked_tokens, original_text

    tokens = tokenizer.tokenize(text)
    # Only mask whole words, not subword pieces, special tokens, or punctuation
    maskable_indices = [i for i, token in enumerate(tokens)
                        if not token.startswith("##")
                        and not token.startswith("[") and not token.endswith("]")
                        and token not in [".", ",", "!", "?", ";", ":", "'", "\"", "-"]]

    # Calculate how many tokens to mask (at least one)
    num_to_mask = max(1, int(len(maskable_indices) * mask_ratio))
    # Randomly select positions to mask, sorted so the saved originals follow
    # the order in which the masks appear in the text
    indices_to_mask = sorted(random.sample(maskable_indices, min(num_to_mask, len(maskable_indices))))

    # Create a copy of the tokens to mask
    masked_tokens_list = tokens.copy()
    original_tokens = []

    # Replace the selected tokens with [MASK]
    for idx in indices_to_mask:
        original_tokens.append(masked_tokens_list[idx])
        masked_tokens_list[idx] = "[MASK]"

    # Save info for evaluation
    masked_indices = indices_to_mask
    masked_tokens = original_tokens
    original_text = text

    # Convert back to text with the masks in place
    masked_text = tokenizer.convert_tokens_to_string(masked_tokens_list)

    return masked_text, indices_to_mask, original_tokens
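
# Illustrative (the actual masked positions are random; note the tokenizer
# lower-cases the text and re-spaces punctuation):
#
#   masked, idxs, orig = prepare_mlm_sample("The cat sat on the mat today.", 0.15)
#   masked  ->  "the cat sat on the [MASK] today ."
#   orig    ->  ['mat']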

def prepare_ntp_sample(text, cut_ratio=0.3):
    """Prepare a text sample for NTP by cutting off the end."""
    # Tokenize so the cut falls on a token boundary
    tokens = tokenizer.tokenize(text)

    # Calculate the cutoff point (70% of tokens visible if cut_ratio is 0.3)
    cutoff = int(len(tokens) * (1 - cut_ratio))

    # Get the visible part
    visible_tokens = tokens[:cutoff]

    # Get the hidden part (to be predicted)
    hidden_tokens = tokens[cutoff:]

    # Convert both parts back to text
    visible_text = tokenizer.convert_tokens_to_string(visible_tokens)
    hidden_text = tokenizer.convert_tokens_to_string(hidden_tokens)

    return visible_text, hidden_text
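
# Illustrative split (tokenizer output is lower-cased with spaced punctuation):
#
#   visible, hidden = prepare_ntp_sample("The storm hit the coast early on Monday morning.", 0.3)
#   visible  ->  "the storm hit the coast early on"
#   hidden   ->  "monday morning ."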

def get_new_sample(task, mask_ratio=0.15):
    """Get a new text sample based on the task."""
    global current_sample, masked_text, masked_indices, masked_tokens, original_text, ntp_hidden_text

    # Select a random sample
    current_sample = random.choice(data_samples)

    if task == "mlm":
        # Prepare an MLM sample
        masked_text, masked_indices, masked_tokens = prepare_mlm_sample(current_sample, mask_ratio)
        return masked_text
    else:  # NTP
        # Prepare an NTP sample (mask_ratio doubles as the cut ratio)
        visible_text, hidden_text = prepare_ntp_sample(current_sample, mask_ratio)
        # Store the visible and hidden parts for later comparison
        original_text = current_sample
        masked_text = visible_text
        ntp_hidden_text = hidden_text
        return visible_text

def check_mlm_answer(user_answers):
    """Check the user's MLM answers against the masked tokens."""
    global user_stats

    # Split the user's answers on spaces and/or commas, dropping empty entries
    user_tokens = [token.strip().lower() for token in re.split(r'[,\s]+', user_answers.strip()) if token]

    # Require exactly one answer per mask
    if len(user_tokens) != len(masked_tokens):
        return f"Please provide {len(masked_tokens)} answers. You provided {len(user_tokens)}."

    # Compare each answer
    correct = 0
    feedback = []

    for i, (user_token, orig_token) in enumerate(zip(user_tokens, masked_tokens)):
        orig_token = orig_token.lower()
        # Strip the ## prefix from subword tokens before comparing
        if orig_token.startswith("##"):
            orig_token = orig_token[2:]

        if user_token == orig_token:
            correct += 1
            feedback.append(f"✓ Token {i+1}: '{user_token}' is correct!")
        else:
            feedback.append(f"✗ Token {i+1}: '{user_token}' should be '{orig_token}'")

    # Update running stats
    user_stats["mlm"]["correct"] += correct
    user_stats["mlm"]["total"] += len(masked_tokens)

    # Per-round accuracy
    accuracy = correct / len(masked_tokens) if masked_tokens else 0
    feedback.insert(0, f"Your accuracy: {correct}/{len(masked_tokens)} ({accuracy * 100:.1f}%)")

    # Overall accuracy across rounds
    overall_accuracy = user_stats["mlm"]["correct"] / user_stats["mlm"]["total"] if user_stats["mlm"]["total"] > 0 else 0
    feedback.append(f"\nOverall MLM Accuracy: {user_stats['mlm']['correct']}/{user_stats['mlm']['total']} ({overall_accuracy*100:.1f}%)")

    return "\n".join(feedback)
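
# Illustrative feedback for masked_tokens == ['mat'] and the guess "rug":
#
#   Your accuracy: 0/1 (0.0%)
#   ✗ Token 1: 'rug' should be 'mat'
#
#   Overall MLM Accuracy: 0/1 (0.0%)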

def check_ntp_answer(user_continuation):
    """Check the user's NTP continuation against the hidden text."""
    global user_stats

    # Use the hidden continuation stored when the sample was generated;
    # slicing original_text by len(masked_text) is unreliable because the
    # tokenizer lower-cases and re-spaces the text.
    hidden_text = ntp_hidden_text
    user_text = user_continuation.strip()

    # Tokenize both sides for a fairer comparison
    hidden_tokens = tokenizer.tokenize(hidden_text)
    user_tokens = tokenizer.tokenize(user_text)

    # Only compare the first few tokens (more lenient)
    max_compare = min(10, len(hidden_tokens), len(user_tokens))
    if max_compare == 0:
        return "Error: No hidden tokens to compare with."

    correct = 0
    for i in range(max_compare):
        hidden_token = hidden_tokens[i].lower()
        user_token = user_tokens[i].lower() if i < len(user_tokens) else ""

        # Strip the ## prefix from subword tokens
        if hidden_token.startswith("##"):
            hidden_token = hidden_token[2:]
        if user_token.startswith("##"):
            user_token = user_token[2:]

        if user_token == hidden_token:
            correct += 1

    # Update running stats
    user_stats["ntp"]["correct"] += correct
    user_stats["ntp"]["total"] += max_compare

    # Per-round accuracy
    accuracy = correct / max_compare
    feedback = [f"Your prediction accuracy: {correct}/{max_compare} ({accuracy * 100:.1f}%)"]

    # Show the actual continuation
    feedback.append(f"\nActual continuation:\n{hidden_text}")

    # Overall accuracy across rounds
    overall_accuracy = user_stats["ntp"]["correct"] / user_stats["ntp"]["total"] if user_stats["ntp"]["total"] > 0 else 0
    feedback.append(f"\nOverall NTP Accuracy: {user_stats['ntp']['correct']}/{user_stats['ntp']['total']} ({overall_accuracy*100:.1f}%)")

    return "\n".join(feedback)

def switch_task(task):
    """Switch between the MLM and NTP tasks."""
    global current_task
    current_task = task
    return gr.update(visible=(task == "mlm")), gr.update(visible=(task == "ntp"))

def generate_new_sample(mask_ratio):
    """Generate a new sample for the current task."""
    ratio = float(mask_ratio) / 100.0  # Convert the percentage to a ratio
    sample = get_new_sample(current_task, ratio)
    return sample, ""

def check_answer(mlm_input, ntp_input, task):
    """Check the answer box that matches the currently selected task."""
    if task == "mlm":
        return check_mlm_answer(mlm_input)
    else:  # NTP
        return check_ntp_answer(ntp_input)

def reset_stats():
    """Reset user statistics."""
    global user_stats
    user_stats = {
        "mlm": {"correct": 0, "total": 0},
        "ntp": {"correct": 0, "total": 0}
    }
    return "Statistics have been reset."
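
# Note: switch_task returns one gr.update(...) per output component; Gradio
# applies them in order, so changing the radio shows one input group and
# hides the other without rebuilding the interface. Illustrative:
#
#   switch_task("ntp")  ->  (gr.update(visible=False), gr.update(visible=True))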

# Set up the Gradio interface
with gr.Blocks(title="MLM and NTP Testing") as demo:
    gr.Markdown("# Language Model Testing: MLM vs NTP")
    gr.Markdown("Test your skills at Masked Language Modeling (MLM) and Next Token Prediction (NTP)")

    with gr.Row():
        task_radio = gr.Radio(
            ["mlm", "ntp"],
            label="Task Type",
            value="mlm",
            info="MLM: Guess the masked words | NTP: Predict what comes next"
        )
        mask_ratio = gr.Slider(
            minimum=5,
            maximum=50,
            value=15,
            step=5,
            label="Mask/Cut Ratio (%)",
            info="Percentage of tokens to mask (MLM) or text to hide (NTP)"
        )

    sample_text = gr.Textbox(
        label="Text Sample",
        placeholder="Click 'New Sample' to get started",
        value=get_new_sample("mlm", 0.15),
        lines=10,
        interactive=False
    )

    with gr.Row():
        new_button = gr.Button("New Sample")
        reset_button = gr.Button("Reset Stats")

    with gr.Group() as mlm_group:
        mlm_answer = gr.Textbox(
            label="Your MLM answers (separated by spaces or commas)",
            placeholder="Type your guesses for the masked words",
            lines=1
        )

    with gr.Group(visible=False) as ntp_group:
        ntp_answer = gr.Textbox(
            label="Your NTP continuation",
            placeholder="Predict how the text continues...",
            lines=3
        )

    with gr.Row():
        check_button = gr.Button("Check Answer")

    result = gr.Textbox(label="Result", lines=6)

    # Set up event handlers
    task_radio.change(switch_task, inputs=[task_radio], outputs=[mlm_group, ntp_group])
    new_button.click(generate_new_sample, inputs=[mask_ratio], outputs=[sample_text, result])
    reset_button.click(reset_stats, inputs=None, outputs=[result])

    # Pass both answer boxes plus the selected task; check_answer reads
    # whichever box matches the task. (A component must already exist in the
    # layout to be used as an event input.)
    check_button.click(
        check_answer,
        inputs=[mlm_answer, ntp_answer, task_radio],
        outputs=[result]
    )

    mlm_answer.submit(check_mlm_answer, inputs=[mlm_answer], outputs=[result])
    ntp_answer.submit(check_ntp_answer, inputs=[ntp_answer], outputs=[result])

demo.launch()
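
# If run outside Spaces, a temporary public link can be requested with
# demo.launch(share=True) in place of the bare launch above.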