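"""Gradio demo: test yourself at Masked Language Modeling (MLM) and Next
Token Prediction (NTP) on short text samples drawn from a streaming dataset.
"""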
import gradio as gr
import random
import re
from datasets import load_dataset
from transformers import AutoTokenizer

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Initialize variables to track stats
user_stats = {
    "mlm": {"correct": 0, "total": 0},
    "ntp": {"correct": 0, "total": 0}
}

# Function to load and sample from the requested dataset
def load_sample_data(sample_size=100):
    try:
        # Try to load the requested dataset
        dataset = load_dataset("mlfoundations/dclm-baseline-1.0-parquet", streaming=True)
        dataset_field = "text"  # Assuming the field name is "text"
    except Exception as e:
        print(f"Error loading requested dataset: {e}")
        # Fallback to cc_news if there's an issue
        dataset = load_dataset("vblagoje/cc_news", streaming=True)
        dataset_field = "text"
    
    # Sample from the dataset, stopping once we have enough usable texts
    samples = []
    for example in dataset["train"]:
        if len(samples) >= sample_size:
            break
        # Get text from the appropriate field
        if dataset_field in example and example[dataset_field]:
            # Clean text by removing extra whitespaces
            text = re.sub(r'\s+', ' ', example[dataset_field]).strip()
            # Only include longer texts to make the task meaningful
            if len(text.split()) > 20:
                # Truncate to two sentences
                sentences = re.split(r'(?<=[.!?])\s+', text)
                if len(sentences) >= 2:
                    # Take only the first two sentences
                    two_sentence_text = ' '.join(sentences[:2])
                    samples.append(two_sentence_text)
    
    return samples

# Load data at startup
data_samples = load_sample_data(100)
current_sample = None
masked_text = ""
original_text = ""
masked_indices = []
masked_tokens = []
current_task = "mlm"

def prepare_mlm_sample(text, mask_ratio=0.15):
    """Prepare a text sample for MLM by masking random tokens."""
    global masked_indices, masked_tokens, original_text
    
    tokens = tokenizer.tokenize(text)
    # Only mask whole words, not special tokens or punctuation
    maskable_indices = [i for i, token in enumerate(tokens) 
                        if not token.startswith("##") and not token.startswith("[") and not token.endswith("]") 
                        and token not in [".", ",", "!", "?", ";", ":", "'", "\"", "-"]]
    
    # Calculate how many tokens to mask
    num_to_mask = max(1, int(len(maskable_indices) * mask_ratio))
    # Randomly select indices to mask, sorted left-to-right so the stored
    # answers line up with the order of the [MASK] tokens the user sees
    indices_to_mask = sorted(random.sample(maskable_indices, min(num_to_mask, len(maskable_indices))))
    
    # Create a copy of tokens to mask
    masked_tokens_list = tokens.copy()
    original_tokens = []
    
    # Replace selected tokens with [MASK]
    for idx in indices_to_mask:
        original_tokens.append(masked_tokens_list[idx])
        masked_tokens_list[idx] = "[MASK]"
    
    # Save info for evaluation
    masked_indices = indices_to_mask
    masked_tokens = original_tokens
    original_text = text
    
    # Convert back to text with masks
    masked_text = tokenizer.convert_tokens_to_string(masked_tokens_list)
    
    return masked_text, indices_to_mask, original_tokens
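
# Illustrative only (masking is random, and bert-base-uncased lowercases):
#   prepare_mlm_sample("The quick brown fox jumps over the lazy dog.")
#   might return roughly ("the quick [MASK] fox jumps over the lazy dog .", [2], ["brown"])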

def prepare_ntp_sample(text, cut_ratio=0.3):
    """Prepare a text sample for NTP by cutting off the end."""
    # Tokenize text to ensure reasonable cutting
    tokens = tokenizer.tokenize(text)
    
    # Ensure we have enough tokens
    if len(tokens) < 5:
        return text, ""  # Return original if too short
    
    # Calculate cutoff point (70% of tokens if cut_ratio is 0.3)
    # But make sure we have at least 3 tokens visible and 1 token hidden
    cutoff = max(3, int(len(tokens) * (1 - cut_ratio)))
    cutoff = min(cutoff, len(tokens) - 1)  # Ensure there's at least 1 token to predict
    
    # Get the visible part
    visible_tokens = tokens[:cutoff]
    
    # Get the hidden part (to be predicted)
    hidden_tokens = tokens[cutoff:]
    
    # Convert back to text
    visible_text = tokenizer.convert_tokens_to_string(visible_tokens)
    hidden_text = tokenizer.convert_tokens_to_string(hidden_tokens)
    
    return visible_text, hidden_text
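
# Illustrative only: with cut_ratio=0.3, roughly the last 30% of tokens are
# hidden, e.g.
#   prepare_ntp_sample("The cat sat on the mat today.")
#   -> roughly ("the cat sat on the", "mat today .")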

def get_new_sample(task, mask_ratio=0.15):
    """Get a new text sample based on the task."""
    global current_sample, masked_text, masked_indices, masked_tokens, original_text, ntp_state
    
    # Select a random sample
    current_sample = random.choice(data_samples)
    
    if task == "mlm":
        # Prepare MLM sample
        masked_text, masked_indices, masked_tokens = prepare_mlm_sample(current_sample, mask_ratio)
        return masked_text
    else:  # NTP
        # Prepare NTP sample
        visible_text, hidden_text = prepare_ntp_sample(current_sample, mask_ratio)
        # Store original and visible for comparison
        original_text = current_sample
        masked_text = visible_text
        
        # Reset NTP state for the new sample; store the hidden continuation
        # directly, rather than re-deriving it from the detokenized prefix
        # (which can differ from the original text in casing and spacing)
        ntp_state = {
            "full_text": hidden_text,
            "revealed_text": "",
            "next_token_idx": 0,
            "tokens": []
        }
        
        # Prepare for token-by-token prediction
        prepare_next_token_prediction()
        
        return visible_text

def check_mlm_answer(user_answers):
    """Check user MLM answers against the masked tokens."""
    global user_stats
    
    # Improved parsing of user answers to better handle different formats
    # First replace any whitespace around commas with just commas
    cleaned_answers = re.sub(r'\s*,\s*', ',', user_answers.strip())
    # Then split by comma or whitespace
    user_tokens = []
    for token in re.split(r',|\s+', cleaned_answers):
        if token:  # Only add non-empty tokens
            user_tokens.append(token.strip().lower())
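
    # e.g. "brown, fox", "brown fox", and "brown,fox" all parse to ["brown", "fox"]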
    
    # Ensure we have the same number of answers as masks
    if len(user_tokens) != len(masked_tokens):
        return f"Please provide {len(masked_tokens)} answers. You provided {len(user_tokens)}.\nFormat: word1, word2, word3"
    
    # Compare each answer
    correct = 0
    feedback = []
    
    for i, (user_token, orig_token) in enumerate(zip(user_tokens, masked_tokens)):
        orig_token = orig_token.lower()
        # Remove ## from subword tokens for comparison
        if orig_token.startswith("##"):
            orig_token = orig_token[2:]
        
        if user_token == orig_token:
            correct += 1
            feedback.append(f"✓ Token {i+1}: '{user_token}' is correct!")
        else:
            feedback.append(f"✗ Token {i+1}: '{user_token}' should be '{orig_token}'")
    
    # Update stats
    user_stats["mlm"]["correct"] += correct
    user_stats["mlm"]["total"] += len(masked_tokens)
    
    # Calculate accuracy
    accuracy = correct / len(masked_tokens) if masked_tokens else 0
    accuracy_percentage = accuracy * 100
    
    # Add overall accuracy to feedback
    feedback.insert(0, f"Your accuracy: {correct}/{len(masked_tokens)} ({accuracy_percentage:.1f}%)")
    
    # Calculate overall stats
    overall_accuracy = user_stats["mlm"]["correct"] / user_stats["mlm"]["total"] if user_stats["mlm"]["total"] > 0 else 0
    feedback.append(f"\nOverall MLM Accuracy: {user_stats['mlm']['correct']}/{user_stats['mlm']['total']} ({overall_accuracy*100:.1f}%)")
    
    return "\n".join(feedback)

# Variable to store NTP state
ntp_state = {
    "full_text": "",
    "revealed_text": "",
    "next_token_idx": 0,
    "tokens": []
}

def prepare_next_token_prediction():
    """Prepare for the next token prediction."""
    global ntp_state
    
    # Tokenize the hidden continuation stored by get_new_sample
    ntp_state["tokens"] = tokenizer.tokenize(ntp_state["full_text"])
    ntp_state["revealed_text"] = ""
    ntp_state["next_token_idx"] = 0
    
    # Make sure we have tokens to predict; if not, draw a fresh sample
    # (get_new_sample calls this function again itself, so no extra call here)
    if not ntp_state["tokens"]:
        get_new_sample("ntp", 0.3)

def check_ntp_answer(user_continuation):
    """Check user NTP answer for the next token only."""
    global user_stats, ntp_state, masked_text
    
    # If we haven't set up NTP state yet, do it now
    if not ntp_state["tokens"]:
        prepare_next_token_prediction()
    
    # No more tokens to predict
    if ntp_state["next_token_idx"] >= len(ntp_state["tokens"]):
        # Reset for next round
        return "You've completed this prediction! Click 'New Sample' for another."
    
    # Get the next token to predict
    next_token = ntp_state["tokens"][ntp_state["next_token_idx"]]
    
    # Get user's prediction
    user_text = user_continuation.strip()
    
    # Tokenize user's prediction to get their first token
    user_tokens = tokenizer.tokenize(user_text)
    user_token = user_tokens[0].lower() if user_tokens else ""
    
    # Clean up tokens for comparison
    next_token_clean = next_token.lower()
    if next_token_clean.startswith("##"):
        next_token_clean = next_token_clean[2:]
    
    if user_token.startswith("##"):
        user_token = user_token[2:]
    
    # Check if correct
    is_correct = (user_token == next_token_clean)
    
    # Update stats
    if is_correct:
        user_stats["ntp"]["correct"] += 1
    user_stats["ntp"]["total"] += 1
    
    # Reveal this token and prepare for the next one; join "##" subword
    # pieces onto the previous word so the revealed text reads naturally
    if next_token.startswith("##"):
        ntp_state["revealed_text"] += next_token[2:]
    else:
        ntp_state["revealed_text"] += " " + next_token
    ntp_state["next_token_idx"] += 1
    
    # Calculate overall accuracy
    overall_accuracy = user_stats["ntp"]["correct"] / user_stats["ntp"]["total"] if user_stats["ntp"]["total"] > 0 else 0
    
    feedback = []
    if is_correct:
        feedback.append(f"✓ Correct! The next token was indeed '{next_token_clean}'")
    else:
        feedback.append(f"✗ Not quite. The actual next token was '{next_token_clean}'")
    
    # Show progress
    feedback.append(f"\nRevealed so far: {masked_text}{ntp_state['revealed_text']}")
    
    # If there are more tokens, prompt for next
    if ntp_state["next_token_idx"] < len(ntp_state["tokens"]):
        feedback.append(f"\nPredict the next token...")
    else:
        feedback.append(f"\nPrediction complete! Full text was:\n{original_text}")
    
    # Show overall stats
    feedback.append(f"\nOverall NTP Accuracy: {user_stats['ntp']['correct']}/{user_stats['ntp']['total']} ({overall_accuracy*100:.1f}%)")
    
    return "\n".join(feedback)

def switch_task(task):
    """Switch between MLM and NTP tasks."""
    global current_task
    current_task = task
    return gr.update(visible=(task == "mlm")), gr.update(visible=(task == "ntp"))

def generate_new_sample(mask_ratio):
    """Generate a new sample based on current task."""
    ratio = float(mask_ratio) / 100.0  # Convert percentage to ratio
    sample = get_new_sample(current_task, ratio)
    return sample, ""

def check_answer(mlm_input, ntp_input, task):
    """Check the user's answer for whichever task is currently active."""
    if task == "mlm":
        return check_mlm_answer(mlm_input)
    else:  # NTP
        return check_ntp_answer(ntp_input)

def reset_stats():
    """Reset user statistics."""
    global user_stats
    user_stats = {
        "mlm": {"correct": 0, "total": 0},
        "ntp": {"correct": 0, "total": 0}
    }
    return "Statistics have been reset."

# Set up Gradio interface
with gr.Blocks(title="MLM and NTP Testing") as demo:
    gr.Markdown("# Language Model Testing: MLM vs NTP")
    gr.Markdown("Test your skills at Masked Language Modeling (MLM) and Next Token Prediction (NTP)")
    
    with gr.Row():
        task_radio = gr.Radio(
            ["mlm", "ntp"], 
            label="Task Type", 
            value="mlm",
            info="MLM: Guess the masked words | NTP: Predict what comes next"
        )
        mask_ratio = gr.Slider(
            minimum=5, 
            maximum=50, 
            value=15, 
            step=5, 
            label="Mask/Cut Ratio (%)",
            info="Percentage of tokens to mask (MLM) or text to hide (NTP)"
        )
    
    sample_text = gr.Textbox(
        label="Text Sample", 
        placeholder="Click 'New Sample' to get started",
        value=get_new_sample("mlm", 0.15),
        lines=10,
        interactive=False
    )
    
    with gr.Row():
        new_button = gr.Button("New Sample")
        reset_button = gr.Button("Reset Stats")
    
    with gr.Group() as mlm_group:
        mlm_answer = gr.Textbox(
            label="Your MLM answers (separated by commas)", 
            placeholder="word1, word2, word3, etc.",
            lines=1
        )
        gr.Markdown("**Example input format:** finding, its, phishing, in, links, 49, and, it")
    
    with gr.Group(visible=False) as ntp_group:
        ntp_answer = gr.Textbox(
            label="Your Next Token Prediction", 
            placeholder="Predict the next token/word...",
            lines=1
        )
    
    with gr.Row():
        check_button = gr.Button("Check Answer")
    
    result = gr.Textbox(label="Result", lines=6)
    
    # Set up event handlers
    task_radio.change(switch_task, inputs=[task_radio], outputs=[mlm_group, ntp_group])
    new_button.click(generate_new_sample, inputs=[mask_ratio], outputs=[sample_text, result])
    reset_button.click(reset_stats, inputs=None, outputs=[result])
    
    # Read both answer boxes and let check_answer pick the one matching the
    # active task (a gr.Textbox constructed inline in `inputs` would not
    # track the real answer components)
    check_button.click(
        check_answer,
        inputs=[mlm_answer, ntp_answer, task_radio],
        outputs=[result]
    )
    
    mlm_answer.submit(check_mlm_answer, inputs=[mlm_answer], outputs=[result])
    ntp_answer.submit(check_ntp_answer, inputs=[ntp_answer], outputs=[result])

demo.launch()
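
# Optional: Gradio's share flag creates a temporary public link if needed,
# e.g. demo.launch(share=True)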