# Hugging Face Spaces page header (scrape residue): Space status "Runtime error".
import gradio as gr | |
import random | |
import re | |
import numpy as np | |
from datasets import load_dataset | |
from transformers import AutoTokenizer | |
# BERT (bert-base-uncased) tokenizer shared by every function below that
# splits text into tokens or joins tokens back into text.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Running score per task: how many answers were judged and how many of
# those were correct.
user_stats = {
    task_name: {"correct": 0, "total": 0}
    for task_name in ("mlm", "ntp")
}
# Function to load and sample from cc_news dataset
def load_sample_data(sample_size=100):
    """Stream the cc_news dataset and collect cleaned article texts.

    Args:
        sample_size: Maximum number of texts to return.

    Returns:
        A list of at most ``sample_size`` texts. Each text has runs of
        whitespace collapsed to single spaces and contains more than 50
        whitespace-separated words. The list can be shorter (or empty) when
        the scanned part of the stream holds too few usable articles.
    """
    dataset = load_dataset("vblagoje/cc_news", streaming=True)

    # Empty and short articles are skipped, so more than sample_size records
    # may have to be read to fill the list. Cap the scan so a stream with few
    # usable articles cannot keep the app reading indefinitely at startup.
    max_examined = sample_size * 10

    samples = []
    for examined, example in enumerate(dataset["train"]):
        if len(samples) >= sample_size or examined >= max_examined:
            break
        # Only use text field
        if "text" not in example or not example["text"]:
            continue
        # Clean text by removing extra whitespaces
        text = re.sub(r'\s+', ' ', example["text"]).strip()
        # Only include longer texts to make the task meaningful
        if len(text.split()) > 50:
            samples.append(text)
    return samples
# Pool of article texts, built once when the module is loaded.
data_samples = load_sample_data(100)

# State of the question currently on screen. get_new_sample and
# prepare_mlm_sample overwrite these values; the check_* functions read
# them back when an answer is submitted.
current_task = "mlm"
current_sample = None
original_text = ""
masked_text = ""
masked_indices = []
masked_tokens = []
def prepare_mlm_sample(text, mask_ratio=0.15):
    """Prepare a text sample for MLM by masking random tokens.

    Besides returning the result, this stores the chosen positions, the
    hidden tokens and the source text in the module globals
    ``masked_indices``, ``masked_tokens`` and ``original_text`` so an answer
    can be checked later.

    Args:
        text: The text to tokenize and mask.
        mask_ratio: Fraction of the maskable tokens to replace with
            ``[MASK]``. At least one token is masked whenever any token is
            maskable.

    Returns:
        A tuple ``(masked_text, indices_to_mask, original_tokens)``. The
        indices are token positions in ascending order and
        ``original_tokens[k]`` is the token hidden at ``indices_to_mask[k]``.
    """
    global masked_indices, masked_tokens, original_text

    tokens = tokenizer.tokenize(text)
    punctuation = {".", ",", "!", "?", ";", ":", "'", "\"", "-"}

    def is_maskable(position):
        """Return True when the token at ``position`` is a whole plain word."""
        token = tokens[position]
        # Skip continuation pieces, bracketed special tokens and punctuation.
        if token.startswith(("##", "[")) or token.endswith("]"):
            return False
        if token in punctuation:
            return False
        # A token followed by a "##" piece is only the start of a longer
        # word; masking it would show "[MASK]" glued to the rest of the word.
        following = position + 1
        return not (following < len(tokens) and tokens[following].startswith("##"))

    maskable_indices = [i for i in range(len(tokens)) if is_maskable(i)]

    # Calculate how many tokens to mask
    num_to_mask = max(1, int(len(maskable_indices) * mask_ratio))

    # random.sample returns positions in random order. Sort them so the k-th
    # stored answer belongs to the k-th [MASK] as the user reads the text.
    indices_to_mask = sorted(
        random.sample(maskable_indices, min(num_to_mask, len(maskable_indices)))
    )

    # Replace selected tokens with [MASK], remembering what was hidden.
    masked_tokens_list = tokens.copy()
    original_tokens = []
    for idx in indices_to_mask:
        original_tokens.append(masked_tokens_list[idx])
        masked_tokens_list[idx] = "[MASK]"

    # Save info for evaluation
    masked_indices = indices_to_mask
    masked_tokens = original_tokens
    original_text = text

    # Convert back to text with masks
    masked_text = tokenizer.convert_tokens_to_string(masked_tokens_list)
    return masked_text, indices_to_mask, original_tokens
def prepare_ntp_sample(text, cut_ratio=0.3):
    """Split a text into a visible prefix and a hidden continuation for NTP.

    Args:
        text: The text to split.
        cut_ratio: Fraction of the tokens, counted from the end, to hide.

    Returns:
        A tuple ``(visible_text, hidden_text)``, both rebuilt from tokens by
        the tokenizer.
    """
    pieces = tokenizer.tokenize(text)

    # With cut_ratio 0.3 the first 70% of the tokens stay visible.
    split_at = int(len(pieces) * (1 - cut_ratio))

    shown, concealed = (
        tokenizer.convert_tokens_to_string(part)
        for part in (pieces[:split_at], pieces[split_at:])
    )
    return shown, concealed
def get_new_sample(task, mask_ratio=0.15):
    """Get a new text sample based on the task.

    Picks a random text from ``data_samples`` and prepares it for ``task``,
    updating the module-level question state that the check functions read.

    Args:
        task: ``"mlm"`` for a masked sample; any other value is treated as
            NTP.
        mask_ratio: Fraction of tokens to mask (MLM) or to hide at the end
            of the text (NTP).

    Returns:
        The text to show to the user. When no samples are available, a
        message saying so is returned instead.
    """
    global current_sample, masked_text, masked_indices, masked_tokens, original_text

    if not data_samples:
        # random.choice raises IndexError on an empty pool, which would crash
        # the app while the interface is being built. Clear the question
        # state and report the problem in the sample box instead.
        current_sample = None
        masked_text = ""
        original_text = ""
        masked_indices = []
        masked_tokens = []
        return (
            "No text samples are available. The dataset could not be loaded "
            "or contained no usable articles."
        )

    # Select a random sample
    current_sample = random.choice(data_samples)

    if task == "mlm":
        masked_text, masked_indices, masked_tokens = prepare_mlm_sample(current_sample, mask_ratio)
        return masked_text

    # NTP: only the visible prefix is kept here; the hidden part is unused.
    visible_text, _hidden_text = prepare_ntp_sample(current_sample, mask_ratio)
    original_text = current_sample
    masked_text = visible_text
    # Drop answers left over from an earlier MLM question so they cannot be
    # scored against a sample that is no longer on screen.
    masked_indices = []
    masked_tokens = []
    return visible_text
def check_mlm_answer(user_answers):
    """Check user MLM answers against the masked tokens.

    Reads the module-level ``masked_tokens`` (the tokens hidden in the sample
    on screen) and adds the outcome to ``user_stats["mlm"]``.

    Args:
        user_answers: Guesses typed by the user, separated by spaces and/or
            commas. They are compared case-insensitively, position by
            position, with the hidden tokens.

    Returns:
        A multi-line feedback string. When there is nothing to check, or the
        number of guesses differs from the number of masks, a short prompt is
        returned instead and the statistics are left untouched.
    """
    expected_tokens = list(masked_tokens)
    if not expected_tokens:
        return "There are no masked tokens to check. Click 'New Sample' first."

    # re.split yields empty strings for an empty input and for leading or
    # trailing separators; drop them so "a, b " counts as two answers.
    user_tokens = [
        piece.lower()
        for piece in re.split(r'[,\s]+', user_answers or "")
        if piece
    ]

    # Ensure we have the same number of answers as masks
    if len(user_tokens) != len(expected_tokens):
        return f"Please provide {len(expected_tokens)} answers. You provided {len(user_tokens)}."

    # Compare each answer
    correct = 0
    feedback = []
    for i, (user_token, orig_token) in enumerate(zip(user_tokens, expected_tokens)):
        orig_token = orig_token.lower()
        # Remove ## from subword tokens for comparison
        if orig_token.startswith("##"):
            orig_token = orig_token[2:]
        if user_token == orig_token:
            correct += 1
            # The two markers were an identical garbled character; use
            # distinct check and cross marks so right and wrong differ.
            feedback.append(f"✓ Token {i+1}: '{user_token}' is correct!")
        else:
            feedback.append(f"✗ Token {i+1}: '{user_token}' should be '{orig_token}'")

    # Update stats
    mlm_stats = user_stats["mlm"]
    mlm_stats["correct"] += correct
    mlm_stats["total"] += len(expected_tokens)

    # Put this round's accuracy first
    accuracy_percentage = correct / len(expected_tokens) * 100
    feedback.insert(0, f"Your accuracy: {correct}/{len(expected_tokens)} ({accuracy_percentage:.1f}%)")

    # Finish with the accuracy over all MLM rounds so far
    overall_accuracy = mlm_stats["correct"] / mlm_stats["total"] if mlm_stats["total"] > 0 else 0
    feedback.append(f"\nOverall MLM Accuracy: {mlm_stats['correct']}/{mlm_stats['total']} ({overall_accuracy*100:.1f}%)")
    return "\n".join(feedback)
def check_ntp_answer(user_continuation):
    """Check user NTP answer against the original text.

    Reads the module-level ``original_text`` (full sample) and
    ``masked_text`` (visible prefix shown to the user), compares up to the
    first 10 tokens of the user's continuation with the hidden tokens, and
    adds the outcome to ``user_stats["ntp"]``.

    Args:
        user_continuation: The continuation typed by the user.

    Returns:
        A multi-line feedback string including the actual continuation, or a
        short message (statistics untouched) when there is nothing hidden or
        nothing was entered.
    """

    def normalise(token):
        """Lower-case a token and drop the "##" subword prefix."""
        token = token.lower()
        return token[2:] if token.startswith("##") else token

    # The visible prefix was rebuilt from tokens, so its character length
    # does not match the corresponding span of the original string. Align on
    # token counts instead of slicing the original by characters.
    # NOTE(review): assumes re-tokenising the visible text yields as many
    # tokens as the prefix it was built from - confirm for unusual inputs.
    all_tokens = tokenizer.tokenize(original_text)
    visible_count = len(tokenizer.tokenize(masked_text))
    hidden_tokens = all_tokens[visible_count:]
    hidden_text = tokenizer.convert_tokens_to_string(hidden_tokens)

    user_tokens = tokenizer.tokenize((user_continuation or "").strip())

    if not hidden_tokens:
        return "Error: No hidden tokens to compare with."
    if not user_tokens:
        # Previously an empty answer was reported as missing hidden tokens.
        return "Please enter a continuation to check."

    # Calculate overlap using first few tokens (more lenient)
    max_compare = min(10, len(hidden_tokens), len(user_tokens))
    correct = sum(
        1
        for hidden_token, user_token in zip(hidden_tokens[:max_compare], user_tokens)
        if normalise(hidden_token) == normalise(user_token)
    )

    # Update stats
    ntp_stats = user_stats["ntp"]
    ntp_stats["correct"] += correct
    ntp_stats["total"] += max_compare

    accuracy_percentage = correct / max_compare * 100
    feedback = [f"Your prediction accuracy: {correct}/{max_compare} ({accuracy_percentage:.1f}%)"]

    # Show original continuation
    feedback.append(f"\nActual continuation:\n{hidden_text}")

    # Finish with the accuracy over all NTP rounds so far
    overall_accuracy = ntp_stats["correct"] / ntp_stats["total"] if ntp_stats["total"] > 0 else 0
    feedback.append(f"\nOverall NTP Accuracy: {ntp_stats['correct']}/{ntp_stats['total']} ({overall_accuracy*100:.1f}%)")
    return "\n".join(feedback)
def switch_task(task):
    """Record the selected task and show only that task's answer group.

    Args:
        task: The selected task name, ``"mlm"`` or ``"ntp"``.

    Returns:
        Two Gradio updates, in order: visibility of the MLM group, then
        visibility of the NTP group.
    """
    global current_task
    current_task = task

    show_mlm = task == "mlm"
    show_ntp = task == "ntp"
    return gr.update(visible=show_mlm), gr.update(visible=show_ntp)
def generate_new_sample(mask_ratio):
    """Produce a fresh sample for the task currently selected.

    Args:
        mask_ratio: The ratio as a percentage (for example 15 for 15%).

    Returns:
        A tuple of the new sample text and an empty string; the empty string
        clears the result box it is wired to.
    """
    # The slider works in percent; the sample helpers expect a fraction.
    fraction = float(mask_ratio) / 100.0
    return get_new_sample(current_task, fraction), ""
def check_answer(user_input, task):
    """Score an answer with the checker that belongs to ``task``.

    Args:
        user_input: The text the user entered.
        task: ``"mlm"`` selects the MLM checker; any other value selects the
            NTP checker.

    Returns:
        The feedback string produced by the selected checker.
    """
    checker = check_mlm_answer if task == "mlm" else check_ntp_answer
    return checker(user_input)
def reset_stats():
    """Zero the correct/total counters of both tasks.

    Returns:
        A confirmation message for the result box.
    """
    global user_stats
    user_stats = {
        task_name: dict(correct=0, total=0)
        for task_name in ("mlm", "ntp")
    }
    return "Statistics have been reset."
# Set up Gradio interface: task and ratio controls on top, the sample text in
# the middle, one answer box per task (only the active one is visible), and a
# result box at the bottom.
with gr.Blocks(title="MLM and NTP Testing") as demo:
    gr.Markdown("# Language Model Testing: MLM vs NTP")
    gr.Markdown("Test your skills at Masked Language Modeling (MLM) and Next Token Prediction (NTP)")

    with gr.Row():
        task_radio = gr.Radio(
            ["mlm", "ntp"],
            label="Task Type",
            value="mlm",
            info="MLM: Guess the masked words | NTP: Predict what comes next"
        )
        mask_ratio = gr.Slider(
            minimum=5,
            maximum=50,
            value=15,
            step=5,
            label="Mask/Cut Ratio (%)",
            info="Percentage of tokens to mask (MLM) or text to hide (NTP)"
        )

    # The first sample is generated once, while the interface is built.
    sample_text = gr.Textbox(
        label="Text Sample",
        placeholder="Click 'New Sample' to get started",
        value=get_new_sample("mlm", 0.15),
        lines=10,
        interactive=False
    )

    with gr.Row():
        new_button = gr.Button("New Sample")
        reset_button = gr.Button("Reset Stats")

    with gr.Group() as mlm_group:
        mlm_answer = gr.Textbox(
            label="Your MLM answers (separated by spaces or commas)",
            placeholder="Type your guesses for the masked words",
            lines=1
        )

    with gr.Group(visible=False) as ntp_group:
        ntp_answer = gr.Textbox(
            label="Your NTP continuation",
            placeholder="Predict how the text continues...",
            lines=3
        )

    with gr.Row():
        check_button = gr.Button("Check Answer")

    result = gr.Textbox(label="Result", lines=6)

    def _check_active_answer(mlm_text, ntp_text, task):
        """Pass the answer box that belongs to ``task`` on to check_answer."""
        return check_answer(mlm_text if task == "mlm" else ntp_text, task)

    # Set up event handlers
    # After the task changes, load a sample for the new task so the text on
    # screen always matches the visible answer box.
    task_radio.change(
        switch_task, inputs=[task_radio], outputs=[mlm_group, ntp_group]
    ).then(
        generate_new_sample, inputs=[mask_ratio], outputs=[sample_text, result]
    )
    new_button.click(generate_new_sample, inputs=[mask_ratio], outputs=[sample_text, result])
    reset_button.click(reset_stats, inputs=None, outputs=[result])

    # Event inputs must be the components the user actually types into. A
    # Textbox created inside the inputs list is a separate, empty field and
    # never carries the typed answer, so both answer boxes are passed and the
    # handler picks the active one.
    check_button.click(
        _check_active_answer,
        inputs=[mlm_answer, ntp_answer, task_radio],
        outputs=[result]
    )
    mlm_answer.submit(check_mlm_answer, inputs=[mlm_answer], outputs=[result])
    ntp_answer.submit(check_ntp_answer, inputs=[ntp_answer], outputs=[result])

demo.launch()