orionweller committed
Commit 30f5f00 · verified · 1 Parent(s): 5d376ee

Update app.py

Files changed (1)
  1. app.py +309 -0
app.py CHANGED
@@ -0,0 +1,309 @@
+ import gradio as gr
+ import random
+ import re
+ import numpy as np
+ from datasets import load_dataset
+ from transformers import AutoTokenizer
+
+ # Load tokenizer
+ tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
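+ # Note: the uncased tokenizer lowercases text during tokenization, so the answer
+ # comparisons below are effectively case-insensitive.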
+
+ # Initialize variables to track stats
+ user_stats = {
+     "mlm": {"correct": 0, "total": 0},
+     "ntp": {"correct": 0, "total": 0}
+ }
+
+ # Function to load and sample from cc_news dataset
+ def load_sample_data(sample_size=100):
+     dataset = load_dataset("vblagoje/cc_news", streaming=True)
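+     # streaming=True iterates over records without downloading the full dataset,
+     # so startup only touches the first `sample_size` records.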
+
+     # Sample from the dataset
+     samples = []
+     for i, example in enumerate(dataset["train"]):
+         if i >= sample_size:
+             break
+         # Only use text field
+         if "text" in example and example["text"]:
+             # Clean text by removing extra whitespaces
+             text = re.sub(r'\s+', ' ', example["text"]).strip()
+             # Only include longer texts to make the task meaningful
+             if len(text.split()) > 50:
+                 samples.append(text)
+
+     return samples
+
+ # Load data at startup
+ data_samples = load_sample_data(100)
+ current_sample = None
+ masked_text = ""
+ original_text = ""
+ masked_indices = []
+ masked_tokens = []
+ current_task = "mlm"
+
+ def prepare_mlm_sample(text, mask_ratio=0.15):
+     """Prepare a text sample for MLM by masking random tokens."""
+     global masked_indices, masked_tokens, original_text
+
+     tokens = tokenizer.tokenize(text)
+     # Only mask whole words, not special tokens or punctuation
+     maskable_indices = [i for i, token in enumerate(tokens)
+                         if not token.startswith("##") and not token.startswith("[") and not token.endswith("]")
+                         and token not in [".", ",", "!", "?", ";", ":", "'", "\"", "-"]]
+
+     # Calculate how many tokens to mask
+     num_to_mask = max(1, int(len(maskable_indices) * mask_ratio))
+     # Randomly select indices to mask
+     indices_to_mask = random.sample(maskable_indices, min(num_to_mask, len(maskable_indices)))
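+     # (Unlike BERT pre-training's 80/10/10 masking scheme, every selected token here is
+     # simply replaced with [MASK].)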
+
+     # Create a copy of tokens to mask
+     masked_tokens_list = tokens.copy()
+     original_tokens = []
+
+     # Replace selected tokens with [MASK]
+     for idx in indices_to_mask:
+         original_tokens.append(masked_tokens_list[idx])
+         masked_tokens_list[idx] = "[MASK]"
+
+     # Save info for evaluation
+     masked_indices = indices_to_mask
+     masked_tokens = original_tokens
+     original_text = text
+
+     # Convert back to text with masks
+     masked_text = tokenizer.convert_tokens_to_string(masked_tokens_list)
+
+     return masked_text, indices_to_mask, original_tokens
+
+ def prepare_ntp_sample(text, cut_ratio=0.3):
+     """Prepare a text sample for NTP by cutting off the end."""
+     # Tokenize text to ensure reasonable cutting
+     tokens = tokenizer.tokenize(text)
+
+     # Calculate cutoff point (70% of tokens if cut_ratio is 0.3)
+     cutoff = int(len(tokens) * (1 - cut_ratio))
+
+     # Get the visible part
+     visible_tokens = tokens[:cutoff]
+
+     # Get the hidden part (to be predicted)
+     hidden_tokens = tokens[cutoff:]
+
+     # Convert back to text
+     visible_text = tokenizer.convert_tokens_to_string(visible_tokens)
+     hidden_text = tokenizer.convert_tokens_to_string(hidden_tokens)
+
+     return visible_text, hidden_text
+
+ def get_new_sample(task, mask_ratio=0.15):
+     """Get a new text sample based on the task."""
+     global current_sample, masked_text, masked_indices, masked_tokens, original_text
+
+     # Select a random sample
+     current_sample = random.choice(data_samples)
+
+     if task == "mlm":
+         # Prepare MLM sample
+         masked_text, masked_indices, masked_tokens = prepare_mlm_sample(current_sample, mask_ratio)
+         return masked_text
+     else:  # NTP
+         # Prepare NTP sample
+         visible_text, hidden_text = prepare_ntp_sample(current_sample, mask_ratio)
+         # Store original and visible for comparison
+         original_text = current_sample
+         masked_text = visible_text
+         return visible_text
+
+ def check_mlm_answer(user_answers):
+     """Check user MLM answers against the masked tokens."""
+     global user_stats
+
+     # Split user answers by spaces or commas
+     user_tokens = [token.strip().lower() for token in re.split(r'[,\s]+', user_answers)]
+
+     # Ensure we have the same number of answers as masks
+     if len(user_tokens) != len(masked_tokens):
+         return f"Please provide {len(masked_tokens)} answers. You provided {len(user_tokens)}."
+
+     # Compare each answer
+     correct = 0
+     feedback = []
+
+     for i, (user_token, orig_token) in enumerate(zip(user_tokens, masked_tokens)):
+         orig_token = orig_token.lower()
+         # Remove ## from subword tokens for comparison
+         if orig_token.startswith("##"):
+             orig_token = orig_token[2:]
+
+         if user_token == orig_token:
+             correct += 1
+             feedback.append(f"✓ Token {i+1}: '{user_token}' is correct!")
+         else:
+             feedback.append(f"✗ Token {i+1}: '{user_token}' should be '{orig_token}'")
+
+     # Update stats
+     user_stats["mlm"]["correct"] += correct
+     user_stats["mlm"]["total"] += len(masked_tokens)
+
+     # Calculate accuracy
+     accuracy = correct / len(masked_tokens) if masked_tokens else 0
+     accuracy_percentage = accuracy * 100
+
+     # Add overall accuracy to feedback
+     feedback.insert(0, f"Your accuracy: {correct}/{len(masked_tokens)} ({accuracy_percentage:.1f}%)")
+
+     # Calculate overall stats
+     overall_accuracy = user_stats["mlm"]["correct"] / user_stats["mlm"]["total"] if user_stats["mlm"]["total"] > 0 else 0
+     feedback.append(f"\nOverall MLM Accuracy: {user_stats['mlm']['correct']}/{user_stats['mlm']['total']} ({overall_accuracy*100:.1f}%)")
+
+     return "\n".join(feedback)
+
+ def check_ntp_answer(user_continuation):
+     """Check user NTP answer against the original text."""
+     global user_stats, original_text, masked_text
+
+     # Recover the hidden part of the original text. masked_text is detokenized
+     # (and lowercased by the uncased tokenizer), so character offsets into
+     # original_text are unreliable; slice in token space instead.
+     visible_len = len(tokenizer.tokenize(masked_text))
+     hidden_text = tokenizer.convert_tokens_to_string(tokenizer.tokenize(original_text)[visible_len:]).strip()
+     user_text = user_continuation.strip()
+
+     # Tokenize for better comparison
+     hidden_tokens = tokenizer.tokenize(hidden_text)
+     user_tokens = tokenizer.tokenize(user_text)
+
+     # Calculate overlap using first few tokens (more lenient)
+     max_compare = min(10, len(hidden_tokens), len(user_tokens))
+     if max_compare == 0:
+         return "Error: No hidden tokens to compare with."
+
+     correct = 0
+     for i in range(max_compare):
+         hidden_token = hidden_tokens[i].lower()
+         user_token = user_tokens[i].lower() if i < len(user_tokens) else ""
+
+         # Remove ## from subword tokens
+         if hidden_token.startswith("##"):
+             hidden_token = hidden_token[2:]
+         if user_token.startswith("##"):
+             user_token = user_token[2:]
+
+         if user_token == hidden_token:
+             correct += 1
+
+     # Update stats
+     user_stats["ntp"]["correct"] += correct
+     user_stats["ntp"]["total"] += max_compare
+
+     # Calculate accuracy
+     accuracy = correct / max_compare
+     accuracy_percentage = accuracy * 100
+
+     feedback = [f"Your prediction accuracy: {correct}/{max_compare} ({accuracy_percentage:.1f}%)"]
+
+     # Show original continuation
+     feedback.append(f"\nActual continuation:\n{hidden_text}")
+
+     # Calculate overall stats
+     overall_accuracy = user_stats["ntp"]["correct"] / user_stats["ntp"]["total"] if user_stats["ntp"]["total"] > 0 else 0
+     feedback.append(f"\nOverall NTP Accuracy: {user_stats['ntp']['correct']}/{user_stats['ntp']['total']} ({overall_accuracy*100:.1f}%)")
+
+     return "\n".join(feedback)
+
+ def switch_task(task):
+     """Switch between MLM and NTP tasks."""
+     global current_task
+     current_task = task
+     return gr.update(visible=(task == "mlm")), gr.update(visible=(task == "ntp"))
+
+ def generate_new_sample(mask_ratio):
+     """Generate a new sample based on current task."""
+     ratio = float(mask_ratio) / 100.0  # Convert percentage to ratio
+     sample = get_new_sample(current_task, ratio)
+     return sample, ""
+
+ def check_answer(mlm_input, ntp_input, task):
+     """Check the user's answer for the currently selected task."""
+     if task == "mlm":
+         return check_mlm_answer(mlm_input)
+     else:  # NTP
+         return check_ntp_answer(ntp_input)
+
+ def reset_stats():
+     """Reset user statistics."""
+     global user_stats
+     user_stats = {
+         "mlm": {"correct": 0, "total": 0},
+         "ntp": {"correct": 0, "total": 0}
+     }
+     return "Statistics have been reset."
+
+ # Set up Gradio interface
+ with gr.Blocks(title="MLM and NTP Testing") as demo:
+     gr.Markdown("# Language Model Testing: MLM vs NTP")
+     gr.Markdown("Test your skills at Masked Language Modeling (MLM) and Next Token Prediction (NTP)")
+
+     with gr.Row():
+         task_radio = gr.Radio(
+             ["mlm", "ntp"],
+             label="Task Type",
+             value="mlm",
+             info="MLM: Guess the masked words | NTP: Predict what comes next"
+         )
+         mask_ratio = gr.Slider(
+             minimum=5,
+             maximum=50,
+             value=15,
+             step=5,
+             label="Mask/Cut Ratio (%)",
+             info="Percentage of tokens to mask (MLM) or text to hide (NTP)"
+         )
+
+     sample_text = gr.Textbox(
+         label="Text Sample",
+         placeholder="Click 'New Sample' to get started",
+         value=get_new_sample("mlm", 0.15),
+         lines=10,
+         interactive=False
+     )
+
+     with gr.Row():
+         new_button = gr.Button("New Sample")
+         reset_button = gr.Button("Reset Stats")
+
+     with gr.Group() as mlm_group:
+         mlm_answer = gr.Textbox(
+             label="Your MLM answers (separated by spaces or commas)",
+             placeholder="Type your guesses for the masked words",
+             lines=1
+         )
+
+     with gr.Group(visible=False) as ntp_group:
+         ntp_answer = gr.Textbox(
+             label="Your NTP continuation",
+             placeholder="Predict how the text continues...",
+             lines=3
+         )
+
+     with gr.Row():
+         check_button = gr.Button("Check Answer")
+
+     result = gr.Textbox(label="Result", lines=6)
+
+     # Set up event handlers
+     task_radio.change(switch_task, inputs=[task_radio], outputs=[mlm_group, ntp_group])
+     new_button.click(generate_new_sample, inputs=[mask_ratio], outputs=[sample_text, result])
+     reset_button.click(reset_stats, inputs=None, outputs=[result])
+
+     # Pass both answer boxes; check_answer() uses whichever matches the selected task.
+     check_button.click(
+         check_answer,
+         inputs=[mlm_answer, ntp_answer, task_radio],
+         outputs=[result]
+     )
+
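+     # Each answer box's submit event routes directly to its task-specific checker.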
+     mlm_answer.submit(check_mlm_answer, inputs=[mlm_answer], outputs=[result])
+     ntp_answer.submit(check_ntp_answer, inputs=[ntp_answer], outputs=[result])
+
+ demo.launch()