"""Gradio app for practicing Masked Language Modeling (MLM) and Next Token
Prediction (NTP) on real text samples, with running accuracy statistics."""

import gradio as gr
import random
import re

from datasets import load_dataset
from transformers import AutoTokenizer

# BERT's WordPiece tokenizer supplies the [MASK] token used by the MLM task.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Running per-task tallies of the user's answers.
user_stats = {
    "mlm": {"correct": 0, "total": 0},
    "ntp": {"correct": 0, "total": 0}
}

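# Note: with streaming=True, load_dataset() reads data lazily, so most
# download errors surface while iterating rather than at load time; the
# try/except below mainly catches failures in resolving the dataset itself.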
def load_sample_data(sample_size=100):
    """Stream examples and keep two-sentence snippets of reasonable length."""
    try:
        dataset = load_dataset("mlfoundations/dclm-baseline-1.0-parquet", streaming=True)
        dataset_field = "text"
    except Exception as e:
        print(f"Error loading requested dataset: {e}")
        # Fall back to a smaller news corpus if the primary dataset fails.
        dataset = load_dataset("vblagoje/cc_news", streaming=True)
        dataset_field = "text"

    samples = []
    for i, example in enumerate(dataset["train"]):
        if i >= sample_size:
            break

        if dataset_field in example and example[dataset_field]:
            # Normalize whitespace before any length checks.
            text = re.sub(r'\s+', ' ', example[dataset_field]).strip()

            if len(text.split()) > 20:
                # Keep just the first two sentences so samples stay short.
                sentences = re.split(r'(?<=[.!?])\s+', text)
                if len(sentences) >= 2:
                    two_sentence_text = ' '.join(sentences[:2])
                    samples.append(two_sentence_text)

    # Because filtering happens after the cap, this can return fewer than
    # sample_size items.
    return samples

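# Module-level session state. Gradio handlers run in one process, so these
# globals are shared by all connected users; acceptable for a single-user
# demo, but a multi-user deployment would want gr.State instead.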
data_samples = load_sample_data(100)
current_sample = None
masked_text = ""      # text currently shown to the user
original_text = ""    # full, unmodified sample
masked_indices = []   # token positions replaced by [MASK]
masked_tokens = []    # original tokens at those positions
current_task = "mlm"

def prepare_mlm_sample(text, mask_ratio=0.15):
    """Prepare a text sample for MLM by masking random tokens."""
    global masked_indices, masked_tokens, original_text

    tokens = tokenizer.tokenize(text)

    # Only whole words are maskable: skip WordPiece continuations ("##..."),
    # bracketed special tokens, and bare punctuation.
    maskable_indices = [i for i, token in enumerate(tokens)
                        if not token.startswith("##")
                        and not token.startswith("[") and not token.endswith("]")
                        and token not in [".", ",", "!", "?", ";", ":", "'", "\"", "-"]]

    # Mask roughly mask_ratio of the maskable tokens, clamped to between 1 and 8.
    num_to_mask = max(1, min(8, int(len(maskable_indices) * mask_ratio)))
    indices_to_mask = random.sample(maskable_indices, min(num_to_mask, len(maskable_indices)))
    indices_to_mask.sort()

    masked_tokens_list = tokens.copy()
    original_tokens = []

    for idx in indices_to_mask:
        original_tokens.append(masked_tokens_list[idx])
        masked_tokens_list[idx] = "[MASK]"

    masked_indices = indices_to_mask
    masked_tokens = original_tokens
    original_text = text

    masked_text = tokenizer.convert_tokens_to_string(masked_tokens_list)

    # Debug output (server console only; reveals the answers).
    print(f"Original tokens: {original_tokens}")
    print(f"Masked indices: {indices_to_mask}")
    print(f"Number of masks: {len(original_tokens)}")

    return masked_text, indices_to_mask, original_tokens

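# For NTP the user plays the language model: the tail of the sample is hidden
# and guessed back one token at a time.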
def prepare_ntp_sample(text, cut_ratio=0.3):
    """Prepare a text sample for NTP by cutting off the end."""
    tokens = tokenizer.tokenize(text)

    # Too short to split; show everything and hide nothing.
    if len(tokens) < 5:
        return text, ""

    # Show the first (1 - cut_ratio) of the tokens, always keeping at least
    # three visible and hiding at least one.
    cutoff = max(3, int(len(tokens) * (1 - cut_ratio)))
    cutoff = min(cutoff, len(tokens) - 1)

    visible_tokens = tokens[:cutoff]
    hidden_tokens = tokens[cutoff:]

    visible_text = tokenizer.convert_tokens_to_string(visible_tokens)
    hidden_text = tokenizer.convert_tokens_to_string(hidden_tokens)

    return visible_text, hidden_text

def get_new_sample(task, mask_ratio=0.15):
    """Draw a random sample and prepare it for the given task."""
    global current_sample, masked_text, masked_indices, masked_tokens, original_text, ntp_state

    current_sample = random.choice(data_samples)

    if task == "mlm":
        masked_text, masked_indices, masked_tokens = prepare_mlm_sample(current_sample, mask_ratio)
        return masked_text
    else:
        visible_text, hidden_text = prepare_ntp_sample(current_sample, mask_ratio)
        original_text = current_sample
        masked_text = visible_text

        # Seed the NTP state with the hidden continuation directly; slicing
        # original_text by len(masked_text) is unreliable because the uncased
        # tokenizer does not round-trip the original string exactly.
        ntp_state = {
            "full_text": hidden_text,
            "revealed_text": "",
            "next_token_idx": 0,
            "tokens": []
        }

        prepare_next_token_prediction()

        return visible_text

def check_mlm_answer(user_answers):
    """Check user MLM answers against the masked tokens."""
    global user_stats

    print(f"Original user input: '{user_answers}'")

    if not user_answers or user_answers.isspace():
        return "Please provide your answers. No input was detected."

    user_answers = user_answers.strip().lower()
    print(f"After basic cleanup: '{user_answers}'")

    # Accept either comma-separated or whitespace-separated answers.
    if ',' in user_answers:
        user_tokens = [token.strip() for token in user_answers.split(',')]
        user_tokens = [token for token in user_tokens if token]
    else:
        user_tokens = [token for token in user_answers.split() if token]

    print(f"Parsed tokens: {user_tokens}, count: {len(user_tokens)}")
    print(f"Expected tokens: {masked_tokens}, count: {len(masked_tokens)}")

    if len(user_tokens) != len(masked_tokens):
        return (f"Please provide exactly {len(masked_tokens)} answers (one for each [MASK]). "
                f"You provided {len(user_tokens)}.\n\nFormat example: word1, word2, word3")

    correct = 0
    feedback = []

    for i, (user_token, orig_token) in enumerate(zip(user_tokens, masked_tokens)):
        orig_token = orig_token.lower()

        # Defensive: subword pieces are filtered out during masking, but strip
        # the WordPiece prefix just in case.
        if orig_token.startswith("##"):
            orig_token = orig_token[2:]

        if user_token == orig_token:
            correct += 1
            feedback.append(f"✓ Token {i+1}: '{user_token}' is correct!")
        else:
            feedback.append(f"✗ Token {i+1}: '{user_token}' should be '{orig_token}'")

    user_stats["mlm"]["correct"] += correct
    user_stats["mlm"]["total"] += len(masked_tokens)

    accuracy = correct / len(masked_tokens) if masked_tokens else 0
    feedback.insert(0, f"Your accuracy: {correct}/{len(masked_tokens)} ({accuracy * 100:.1f}%)")

    overall_accuracy = user_stats["mlm"]["correct"] / user_stats["mlm"]["total"] if user_stats["mlm"]["total"] > 0 else 0
    feedback.append(f"\nOverall MLM Accuracy: {user_stats['mlm']['correct']}/{user_stats['mlm']['total']} ({overall_accuracy*100:.1f}%)")

    return "\n".join(feedback)

# Mutable state for the incremental next-token game. get_new_sample()
# replaces this dict for every new NTP sample.
ntp_state = {
    "full_text": "",
    "revealed_text": "",
    "next_token_idx": 0,
    "tokens": []
}

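# In BERT's WordPiece vocabulary a leading "##" marks a continuation of the
# previous token (e.g. "##ing"); the prefix is stripped before comparisons so
# users never have to type it.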
def prepare_next_token_prediction():
    """Tokenize the hidden continuation and reset the reveal cursor."""
    global ntp_state, masked_text, original_text

    # Prefer the hidden text recorded by get_new_sample(); fall back to
    # slicing the original string if it is missing.
    full_hidden = ntp_state.get("full_text") or original_text[len(masked_text):].strip()

    ntp_state["tokens"] = tokenizer.tokenize(full_hidden)
    ntp_state["full_text"] = full_hidden
    ntp_state["revealed_text"] = ""
    ntp_state["next_token_idx"] = 0

    # Nothing was hidden: draw a fresh NTP sample (get_new_sample calls back
    # into this function).
    if not ntp_state["tokens"]:
        get_new_sample("ntp", 0.3)

def check_ntp_answer(user_continuation):
    """Check the user's guess against the next hidden token only."""
    global user_stats, ntp_state, masked_text

    if not ntp_state["tokens"]:
        prepare_next_token_prediction()

    if ntp_state["next_token_idx"] >= len(ntp_state["tokens"]):
        return "You've completed this prediction! Click 'New Sample' for another."

    next_token = ntp_state["tokens"][ntp_state["next_token_idx"]]

    # Compare only the user's first token, normalized the same way as the target.
    user_text = user_continuation.strip()
    user_tokens = tokenizer.tokenize(user_text)
    user_token = user_tokens[0].lower() if user_tokens else ""

    next_token_clean = next_token.lower()
    if next_token_clean.startswith("##"):
        next_token_clean = next_token_clean[2:]

    if user_token.startswith("##"):
        user_token = user_token[2:]

    is_correct = (user_token == next_token_clean)

    if is_correct:
        user_stats["ntp"]["correct"] += 1
    user_stats["ntp"]["total"] += 1

    # Reveal the target token regardless of correctness and advance the cursor.
    ntp_state["revealed_text"] += " " + tokenizer.convert_tokens_to_string([next_token])
    ntp_state["next_token_idx"] += 1

    overall_accuracy = user_stats["ntp"]["correct"] / user_stats["ntp"]["total"] if user_stats["ntp"]["total"] > 0 else 0

    feedback = []
    if is_correct:
        feedback.append(f"✓ Correct! The next token was indeed '{next_token_clean}'")
    else:
        feedback.append(f"✗ Not quite. The actual next token was '{next_token_clean}'")

    feedback.append(f"\nRevealed so far: {masked_text}{ntp_state['revealed_text']}")

    if ntp_state["next_token_idx"] < len(ntp_state["tokens"]):
        feedback.append("\nPredict the next token...")
    else:
        feedback.append(f"\nPrediction complete! Full text was:\n{original_text}")

    feedback.append(f"\nOverall NTP Accuracy: {user_stats['ntp']['correct']}/{user_stats['ntp']['total']} ({overall_accuracy*100:.1f}%)")

    return "\n".join(feedback)

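# Standalone callbacks. check_answer is wired to the UI below; switch_task and
# generate_new_sample are simple variants kept for programmatic use (the UI
# uses the richer unified versions defined inside the Blocks context).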
def switch_task(task):
    """Switch between MLM and NTP tasks."""
    global current_task
    current_task = task
    return gr.update(visible=(task == "mlm")), gr.update(visible=(task == "ntp"))


def generate_new_sample(mask_ratio):
    """Generate a new sample for the currently selected task."""
    ratio = float(mask_ratio) / 100.0
    sample = get_new_sample(current_task, ratio)
    return sample, ""


def check_answer(user_input, task):
    """Route the answer to the checker for the given task."""
    if task == "mlm":
        return check_mlm_answer(user_input)
    return check_ntp_answer(user_input)


def reset_stats():
    """Reset user statistics."""
    global user_stats
    user_stats = {
        "mlm": {"correct": 0, "total": 0},
        "ntp": {"correct": 0, "total": 0}
    }
    return "Statistics have been reset."

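# UI: controls on top (task selector, mask/cut ratio), then the sample
# display, the answer area, and a shared result panel.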
with gr.Blocks(title="MLM and NTP Testing") as demo:
    gr.Markdown("# Language Model Testing: MLM vs NTP")
    gr.Markdown("Test your skills at Masked Language Modeling (MLM) and Next Token Prediction (NTP)")

    with gr.Row():
        task_radio = gr.Radio(
            ["mlm", "ntp"],
            label="Task Type",
            value="mlm",
            info="MLM: Guess the masked words | NTP: Predict what comes next"
        )
        mask_ratio = gr.Slider(
            minimum=5,
            maximum=50,
            value=15,
            step=5,
            label="Mask/Cut Ratio (%)",
            info="Percentage of tokens to mask (MLM) or text to hide (NTP)"
        )

    # Generate the initial sample before building the mask counter so the
    # counter shows the real number of masks instead of a stale "0".
    initial_sample = get_new_sample("mlm", 0.15)
    mask_count = gr.Markdown(f"**Number of [MASK] tokens to guess: {len(masked_tokens)}**")

    sample_text = gr.Textbox(
        label="Text Sample",
        placeholder="Click 'New Sample' to get started",
        value=initial_sample,
        lines=10,
        interactive=False
    )

    with gr.Row():
        new_button = gr.Button("New Sample")
        reset_button = gr.Button("Reset Stats")

    input_area = gr.Group()

    with input_area:
        mlm_instructions = gr.Markdown("""
### MLM Instructions
1. For each [MASK] token, provide your guess for the original word.
2. Separate your answers with commas.
3. Make sure you provide exactly the same number of answers as [MASK] tokens.

**Example format:** `word1, word2, word3` or `word1,word2,word3`
""", visible=True)

        ntp_instructions = gr.Markdown("""
### NTP Instructions
Predict the next word or token that would follow the text.
Type a single word or token for each prediction.
""", visible=False)

        answer_input = gr.Textbox(
            label="Your answer",
            placeholder="For MLM: word1, word2, word3 | For NTP: single word",
            lines=1
        )

        with gr.Row():
            check_button = gr.Button("Check Answer", variant="primary")

        result = gr.Textbox(label="Result", lines=6)

    def switch_task_unified(task):
        """Toggle instructions, the input placeholder, and the mask counter."""
        if task == "mlm":
            mask_text = f"**Number of [MASK] tokens to guess: {len(masked_tokens)}**"
            return (
                gr.update(visible=True),
                gr.update(visible=False),
                gr.update(placeholder="comma-separated answers (e.g., word1, word2, word3)"),
                mask_text
            )
        else:
            return (
                gr.update(visible=False),
                gr.update(visible=True),
                gr.update(placeholder="Type the next word/token you predict"),
                "**Next Token Prediction mode - guess one token at a time**"
            )

    task_radio.change(
        switch_task_unified,
        inputs=[task_radio],
        outputs=[mlm_instructions, ntp_instructions, answer_input, mask_count]
    )

    def new_sample_with_count(mask_ratio_pct, task):
        """Fetch a new sample, update the mask counter, and clear the result box."""
        ratio = float(mask_ratio_pct) / 100.0
        sample = get_new_sample(task, ratio)

        if task == "mlm":
            count = len(masked_tokens)
            mask_count_text = f"**Number of [MASK] tokens to guess: {count}**"
        else:
            mask_count_text = "**Next Token Prediction mode - guess one token at a time**"

        return sample, mask_count_text, ""

    new_button.click(
        new_sample_with_count,
        inputs=[mask_ratio, task_radio],
        outputs=[sample_text, mask_count, result]
    )

    reset_button.click(reset_stats, inputs=None, outputs=[result])

    # check_answer (defined above) already routes answers by task, so the
    # button click and the textbox submit share it directly.
    check_button.click(
        check_answer,
        inputs=[answer_input, task_radio],
        outputs=[result]
    )

    answer_input.submit(
        check_answer,
        inputs=[answer_input, task_radio],
        outputs=[result]
    )

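# demo.launch() also accepts options such as share=True (temporary public URL)
# or server_name="0.0.0.0" (listen on all interfaces) when the default local
# launch is not enough.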
demo.launch()