import gradio as gr
import random
import re
import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer
# Load tokenizer
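# bert-base-uncased uses WordPiece and lowercases its input; its vocabulary
# includes the literal "[MASK]" token that the MLM task below relies on.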
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# Initialize variables to track stats
user_stats = {
"mlm": {"correct": 0, "total": 0},
"ntp": {"correct": 0, "total": 0}
}
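# Note: these counters are module-level globals, so they are shared by every
# visitor to the app in the same process and reset whenever it restarts.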
# Function to load and sample from cc_news dataset
def load_sample_data(sample_size=100):
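    # streaming=True avoids downloading the full cc_news dump; examples are
    # yielded lazily as the train split is iterated below.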
dataset = load_dataset("vblagoje/cc_news", streaming=True)
# Sample from the dataset
samples = []
for i, example in enumerate(dataset["train"]):
if i >= sample_size:
break
# Only use text field
if "text" in example and example["text"]:
# Clean text by removing extra whitespaces
text = re.sub(r'\s+', ' ', example["text"]).strip()
# Only include longer texts to make the task meaningful
if len(text.split()) > 50:
samples.append(text)
return samples
# Load data at startup
data_samples = load_sample_data(100)
current_sample = None
masked_text = ""
original_text = ""
ntp_hidden_text = ""  # hidden NTP continuation, set in get_new_sample
masked_indices = []
masked_tokens = []
current_task = "mlm"
def prepare_mlm_sample(text, mask_ratio=0.15):
"""Prepare a text sample for MLM by masking random tokens."""
global masked_indices, masked_tokens, original_text
tokens = tokenizer.tokenize(text)
# Only mask whole words, not special tokens or punctuation
maskable_indices = [i for i, token in enumerate(tokens)
if not token.startswith("##") and not token.startswith("[") and not token.endswith("]")
and token not in [".", ",", "!", "?", ";", ":", "'", "\"", "-"]]
# Calculate how many tokens to mask
num_to_mask = max(1, int(len(maskable_indices) * mask_ratio))
# Randomly select indices to mask
indices_to_mask = random.sample(maskable_indices, min(num_to_mask, len(maskable_indices)))
# Create a copy of tokens to mask
masked_tokens_list = tokens.copy()
original_tokens = []
# Replace selected tokens with [MASK]
for idx in indices_to_mask:
original_tokens.append(masked_tokens_list[idx])
masked_tokens_list[idx] = "[MASK]"
# Save info for evaluation
masked_indices = indices_to_mask
masked_tokens = original_tokens
original_text = text
# Convert back to text with masks
masked_text = tokenizer.convert_tokens_to_string(masked_tokens_list)
return masked_text, indices_to_mask, original_tokens
def prepare_ntp_sample(text, cut_ratio=0.3):
"""Prepare a text sample for NTP by cutting off the end."""
# Tokenize text to ensure reasonable cutting
tokens = tokenizer.tokenize(text)
# Calculate cutoff point (70% of tokens if cut_ratio is 0.3)
cutoff = int(len(tokens) * (1 - cut_ratio))
# Get the visible part
visible_tokens = tokens[:cutoff]
# Get the hidden part (to be predicted)
hidden_tokens = tokens[cutoff:]
# Convert back to text
visible_text = tokenizer.convert_tokens_to_string(visible_tokens)
hidden_text = tokenizer.convert_tokens_to_string(hidden_tokens)
return visible_text, hidden_text
def get_new_sample(task, mask_ratio=0.15):
"""Get a new text sample based on the task."""
    global current_sample, masked_text, masked_indices, masked_tokens, original_text, ntp_hidden_text
# Select a random sample
current_sample = random.choice(data_samples)
if task == "mlm":
# Prepare MLM sample
masked_text, masked_indices, masked_tokens = prepare_mlm_sample(current_sample, mask_ratio)
return masked_text
    else:  # NTP
        # Prepare NTP sample
        visible_text, hidden_text = prepare_ntp_sample(current_sample, mask_ratio)
        # Store the visible prefix and the hidden continuation for later scoring
        original_text = current_sample
        masked_text = visible_text
        ntp_hidden_text = hidden_text
        return visible_text
def check_mlm_answer(user_answers):
"""Check user MLM answers against the masked tokens."""
global user_stats
    # Split user answers on spaces or commas, dropping empty entries produced
    # by leading/trailing separators
    user_tokens = [token.strip().lower() for token in re.split(r'[,\s]+', user_answers.strip()) if token.strip()]
# Ensure we have the same number of answers as masks
if len(user_tokens) != len(masked_tokens):
return f"Please provide {len(masked_tokens)} answers. You provided {len(user_tokens)}."
# Compare each answer
correct = 0
feedback = []
for i, (user_token, orig_token) in enumerate(zip(user_tokens, masked_tokens)):
orig_token = orig_token.lower()
# Remove ## from subword tokens for comparison
if orig_token.startswith("##"):
orig_token = orig_token[2:]
if user_token == orig_token:
correct += 1
feedback.append(f"βœ“ Token {i+1}: '{user_token}' is correct!")
else:
feedback.append(f"βœ— Token {i+1}: '{user_token}' should be '{orig_token}'")
# Update stats
user_stats["mlm"]["correct"] += correct
user_stats["mlm"]["total"] += len(masked_tokens)
# Calculate accuracy
accuracy = correct / len(masked_tokens) if masked_tokens else 0
accuracy_percentage = accuracy * 100
# Add overall accuracy to feedback
feedback.insert(0, f"Your accuracy: {correct}/{len(masked_tokens)} ({accuracy_percentage:.1f}%)")
# Calculate overall stats
overall_accuracy = user_stats["mlm"]["correct"] / user_stats["mlm"]["total"] if user_stats["mlm"]["total"] > 0 else 0
feedback.append(f"\nOverall MLM Accuracy: {user_stats['mlm']['correct']}/{user_stats['mlm']['total']} ({overall_accuracy*100:.1f}%)")
return "\n".join(feedback)
def check_ntp_answer(user_continuation):
    """Check user NTP answer against the hidden continuation."""
    global user_stats
    # Use the continuation stored when the sample was generated; slicing
    # original_text by len(masked_text) is unreliable because the visible text
    # has been re-detokenized (and lowercased) by the tokenizer.
    hidden_text = ntp_hidden_text.strip()
user_text = user_continuation.strip()
# Tokenize for better comparison
hidden_tokens = tokenizer.tokenize(hidden_text)
user_tokens = tokenizer.tokenize(user_text)
# Calculate overlap using first few tokens (more lenient)
max_compare = min(10, len(hidden_tokens), len(user_tokens))
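    # Only the first (up to) 10 tokens are compared position-by-position, so a
    # plausible paraphrase that diverges early will still score zero.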
if max_compare == 0:
return "Error: No hidden tokens to compare with."
correct = 0
    for i in range(max_compare):
        hidden_token = hidden_tokens[i].lower()
        # max_compare is capped by len(user_tokens), so direct indexing is safe
        user_token = user_tokens[i].lower()
# Remove ## from subword tokens
if hidden_token.startswith("##"):
hidden_token = hidden_token[2:]
if user_token.startswith("##"):
user_token = user_token[2:]
if user_token == hidden_token:
correct += 1
# Update stats
user_stats["ntp"]["correct"] += correct
user_stats["ntp"]["total"] += max_compare
# Calculate accuracy
accuracy = correct / max_compare
accuracy_percentage = accuracy * 100
feedback = [f"Your prediction accuracy: {correct}/{max_compare} ({accuracy_percentage:.1f}%)"]
# Show original continuation
feedback.append(f"\nActual continuation:\n{hidden_text}")
# Calculate overall stats
overall_accuracy = user_stats["ntp"]["correct"] / user_stats["ntp"]["total"] if user_stats["ntp"]["total"] > 0 else 0
feedback.append(f"\nOverall NTP Accuracy: {user_stats['ntp']['correct']}/{user_stats['ntp']['total']} ({overall_accuracy*100:.1f}%)")
return "\n".join(feedback)
def switch_task(task):
"""Switch between MLM and NTP tasks."""
global current_task
current_task = task
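    # Return one update per output component (mlm_group, ntp_group) so that
    # only the input box for the selected task is visible.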
return gr.update(visible=(task == "mlm")), gr.update(visible=(task == "ntp"))
def generate_new_sample(mask_ratio):
"""Generate a new sample based on current task."""
ratio = float(mask_ratio) / 100.0 # Convert percentage to ratio
sample = get_new_sample(current_task, ratio)
return sample, ""
def check_answer(mlm_input, ntp_input, task):
    """Check the user's answer for the currently selected task."""
    if task == "mlm":
        return check_mlm_answer(mlm_input)
    else:  # NTP
        return check_ntp_answer(ntp_input)
def reset_stats():
"""Reset user statistics."""
global user_stats
user_stats = {
"mlm": {"correct": 0, "total": 0},
"ntp": {"correct": 0, "total": 0}
}
return "Statistics have been reset."
# Set up Gradio interface
with gr.Blocks(title="MLM and NTP Testing") as demo:
gr.Markdown("# Language Model Testing: MLM vs NTP")
gr.Markdown("Test your skills at Masked Language Modeling (MLM) and Next Token Prediction (NTP)")
with gr.Row():
task_radio = gr.Radio(
["mlm", "ntp"],
label="Task Type",
value="mlm",
info="MLM: Guess the masked words | NTP: Predict what comes next"
)
mask_ratio = gr.Slider(
minimum=5,
maximum=50,
value=15,
step=5,
label="Mask/Cut Ratio (%)",
info="Percentage of tokens to mask (MLM) or text to hide (NTP)"
)
sample_text = gr.Textbox(
label="Text Sample",
placeholder="Click 'New Sample' to get started",
value=get_new_sample("mlm", 0.15),
lines=10,
interactive=False
)
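    # The initial value above is generated once when the script starts, so every
    # new session first sees the same MLM sample at the default 15% mask ratio.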
with gr.Row():
new_button = gr.Button("New Sample")
reset_button = gr.Button("Reset Stats")
with gr.Group() as mlm_group:
mlm_answer = gr.Textbox(
label="Your MLM answers (separated by spaces or commas)",
placeholder="Type your guesses for the masked words",
lines=1
)
with gr.Group(visible=False) as ntp_group:
ntp_answer = gr.Textbox(
label="Your NTP continuation",
placeholder="Predict how the text continues...",
lines=3
)
with gr.Row():
check_button = gr.Button("Check Answer")
result = gr.Textbox(label="Result", lines=6)
# Set up event handlers
task_radio.change(switch_task, inputs=[task_radio], outputs=[mlm_group, ntp_group])
new_button.click(generate_new_sample, inputs=[mask_ratio], outputs=[sample_text, result])
reset_button.click(reset_stats, inputs=None, outputs=[result])
    # Pass both answer boxes plus the selected task; check_answer picks the
    # relevant one. (Constructing a new gr.Textbox inside `inputs` would not
    # read the user's typed answer.)
    check_button.click(
        check_answer,
        inputs=[mlm_answer, ntp_answer, task_radio],
        outputs=[result]
    )
mlm_answer.submit(check_mlm_answer, inputs=[mlm_answer], outputs=[result])
ntp_answer.submit(check_ntp_answer, inputs=[ntp_answer], outputs=[result])
demo.launch()