# orionweller's picture
# Update app.py
# 565fb95 verified
import gradio as gr
import random
import re
import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer
# Shared bert-base-uncased tokenizer used by every task helper below.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# Running per-task score counters: the answer checkers add to them and
# reset_stats() replaces the whole dict.
user_stats = {task_name: {"correct": 0, "total": 0} for task_name in ("mlm", "ntp")}
def _collect_two_sentence_samples(examples, field, sample_size, max_examples):
    """Return up to *sample_size* two-sentence snippets taken from *examples*.

    At most *max_examples* items are inspected. An item is kept when its
    *field* is present and non-empty and, after collapsing whitespace, the
    text has more than 20 words and splits into at least two sentences; only
    the first two sentences are kept.
    """
    samples = []
    for i, example in enumerate(examples):
        # Stop once enough samples are collected or the scan budget is spent.
        if len(samples) >= sample_size or i >= max_examples:
            break
        if field not in example or not example[field]:
            continue
        # Clean text by collapsing runs of whitespace
        text = re.sub(r'\s+', ' ', example[field]).strip()
        # Only include longer texts to make the task meaningful
        if len(text.split()) <= 20:
            continue
        sentences = re.split(r'(?<=[.!?])\s+', text)
        if len(sentences) < 2:
            continue
        # Truncate to the first two sentences
        samples.append(' '.join(sentences[:2]))
    return samples


# Function to load and sample from the requested dataset
def load_sample_data(sample_size=100, max_examples=None):
    """Stream a text dataset and collect up to *sample_size* two-sentence snippets.

    Tries ``mlfoundations/dclm-baseline-1.0-parquet`` first and falls back to
    ``vblagoje/cc_news`` if that fails or yields no usable samples.

    Args:
        sample_size: Number of accepted snippets to collect (previously this
            counted streamed examples, so filtering could leave far fewer).
        max_examples: Maximum number of streamed examples to inspect per
            dataset; defaults to ``20 * sample_size`` so a dataset with few
            usable documents cannot make startup scan indefinitely.

    Returns:
        A list of at most *sample_size* strings.
    """
    if max_examples is None:
        max_examples = 20 * sample_size
    try:
        # Try to load the requested dataset. With streaming=True, errors may
        # also surface while iterating, so the scan is inside the try too.
        dataset = load_dataset("mlfoundations/dclm-baseline-1.0-parquet", streaming=True)
        samples = _collect_two_sentence_samples(dataset["train"], "text", sample_size, max_examples)
        if samples:
            return samples
        print("Requested dataset yielded no usable samples; falling back to cc_news")
    except Exception as e:
        print(f"Error loading requested dataset: {e}")
    # Fallback to cc_news if there's an issue
    dataset = load_dataset("vblagoje/cc_news", streaming=True)
    return _collect_two_sentence_samples(dataset["train"], "text", sample_size, max_examples)
# Load data at startup
data_samples = load_sample_data(100)
if not data_samples:
    # Every sample is drawn with random.choice(data_samples); fail fast with a
    # clear message instead of an IndexError while the UI is being built.
    raise RuntimeError("No usable text samples could be loaded from the dataset; cannot start the app.")
# Shared state for the current sample: written by get_new_sample() and the
# prepare_* helpers, read by the answer checkers.
current_sample = None
masked_text = ""
original_text = ""
masked_indices = []
masked_tokens = []
current_task = "mlm"
def prepare_mlm_sample(text, mask_ratio=0.15):
    """Mask a random subset of whole-word tokens in *text* for the MLM task.

    Also records the masked positions, the replaced tokens and *text* in the
    module globals ``masked_indices``, ``masked_tokens`` and ``original_text``.

    Args:
        text: Raw sample text.
        mask_ratio: Fraction of maskable tokens to replace with ``[MASK]``
            (at least one whenever any token is maskable).

    Returns:
        ``(masked_text, indices_to_mask, original_tokens)``: the detokenized
        text with ``[MASK]`` placeholders, the sorted masked positions, and
        the tokens that were replaced, in position order.
    """
    global masked_indices, masked_tokens, original_text
    tokens = tokenizer.tokenize(text)
    print(f"Text length: {len(text)} characters, {len(tokens)} tokens")
    # Only mask whole words: skip "##" continuation pieces, bracketed special
    # tokens such as [UNK], and any token without a letter or digit. The last
    # check covers every punctuation/symbol token ("(", ")", "$", "%", "/", ...)
    # rather than a hand-picked list.
    maskable_indices = [
        i for i, token in enumerate(tokens)
        if not token.startswith("##")
        and not token.startswith("[")
        and not token.endswith("]")
        and any(ch.isalnum() for ch in token)
    ]
    print(f"Maskable indices count: {len(maskable_indices)}")
    print(f"Mask ratio: {mask_ratio}")
    # Calculate how many tokens to mask based on the mask ratio (no cap)
    num_to_mask = max(1, int(len(maskable_indices) * mask_ratio))
    print(f"Number of tokens to mask: {num_to_mask}")
    # Randomly select indices to mask, then sort them into text order
    indices_to_mask = random.sample(maskable_indices, min(num_to_mask, len(maskable_indices)))
    indices_to_mask.sort()
    masked_tokens_list = tokens.copy()
    original_tokens = []
    for idx in indices_to_mask:
        original_tokens.append(masked_tokens_list[idx])
        masked_tokens_list[idx] = "[MASK]"
    # Save info for evaluation
    masked_indices = indices_to_mask
    masked_tokens = original_tokens
    original_text = text
    # Convert back to text with masks
    masked_text = tokenizer.convert_tokens_to_string(masked_tokens_list)
    print(f"Original tokens: {original_tokens}")
    print(f"Masked indices: {indices_to_mask}")
    print(f"Number of masks: {len(original_tokens)}")
    return masked_text, indices_to_mask, original_tokens
def prepare_ntp_sample(text, cut_ratio=0.3):
    """Split *text* into a visible prefix and a hidden continuation for NTP.

    Args:
        text: Raw sample text.
        cut_ratio: Approximate fraction of tokens to hide at the end.

    Returns:
        ``(visible_text, hidden_text)`` as strings rebuilt by the module
        tokenizer. Texts with fewer than 5 tokens are returned unchanged with
        an empty hidden part.
    """
    # Tokenize text so the cut happens on token boundaries
    tokens = tokenizer.tokenize(text)
    print(f"NTP preparation - Text length: {len(text)} characters, {len(tokens)} tokens")
    print(f"Cut ratio: {cut_ratio}")
    if len(tokens) < 5:
        return text, ""  # Return original if too short
    # Aim to keep at least 3 tokens visible and always hide at least 1.
    cutoff = max(3, int(len(tokens) * (1 - cut_ratio)))
    cutoff = min(cutoff, len(tokens) - 1)
    # Move the cut back to the start of a word. Otherwise the visible text can
    # end in a fragment (e.g. "play" from "play ##ing") that reads like a
    # complete word, and the first hidden "token" is a word suffix.
    while cutoff > 1 and tokens[cutoff].startswith("##"):
        cutoff -= 1
    print(f"Cutoff point: {cutoff} (keeping {cutoff} tokens, cutting {len(tokens) - cutoff} tokens)")
    visible_tokens = tokens[:cutoff]
    hidden_tokens = tokens[cutoff:]
    visible_text = tokenizer.convert_tokens_to_string(visible_tokens)
    hidden_text = tokenizer.convert_tokens_to_string(hidden_tokens)
    print(f"Visible text length: {len(visible_text)} chars")
    print(f"Hidden text length: {len(hidden_text)} chars")
    return visible_text, hidden_text
def get_new_sample(task, mask_ratio=0.15):
    """Pick a random entry from ``data_samples`` and prepare it for *task*.

    Side effects: sets the globals ``current_task`` and ``current_sample``.
    For ``"mlm"`` it also stores the masked text, masked positions and the
    replaced tokens. For any other task it sets up next-token prediction by
    setting ``original_text`` and ``masked_text``, resetting ``ntp_state`` and
    calling ``prepare_next_token_prediction()``.

    Args:
        task: ``"mlm"`` for masked-word guessing; anything else means NTP.
        mask_ratio: Mask ratio for MLM, or cut ratio for NTP.

    Returns:
        Text to show the user: the masked text (MLM) or the visible prefix (NTP).
    """
    global current_sample, masked_text, masked_indices, masked_tokens, original_text, ntp_state, current_task
    current_task = task
    current_sample = random.choice(data_samples)
    print(f"Getting new sample for task: {task} with mask ratio: {mask_ratio}")

    if task != "mlm":
        # The ratio doubles as the fraction of tokens hidden at the end.
        visible, _hidden = prepare_ntp_sample(current_sample, mask_ratio)
        original_text, masked_text = current_sample, visible
        # Fresh NTP bookkeeping for this sample
        ntp_state = {
            "full_text": "",
            "revealed_text": "",
            "next_token_idx": 0,
            "tokens": [],
        }
        prepare_next_token_prediction()
        return visible

    prepared = prepare_mlm_sample(current_sample, mask_ratio)
    masked_text, masked_indices, masked_tokens = prepared
    return masked_text
def check_mlm_answer(user_answers):
    """Grade the user's guesses for the current ``[MASK]`` tokens.

    Guesses are lowercased and compared position by position with the module
    global ``masked_tokens``. They are split on commas when the input contains
    one, otherwise on whitespace. Graded rounds are added to
    ``user_stats["mlm"]``.

    Returns:
        A multi-line feedback string, or a prompt when the input is empty or
        has the wrong number of answers.
    """
    print(f"Original user input: '{user_answers}'")
    if not user_answers or user_answers.isspace():
        return "Please provide your answers. No input was detected."

    cleaned = user_answers.strip().lower()
    print(f"After basic cleanup: '{cleaned}'")

    # Prefer commas as separators; fall back to whitespace. Empty pieces
    # (e.g. from "a,,b") are dropped.
    pieces = cleaned.split(',') if ',' in cleaned else cleaned.split()
    guesses = [piece.strip() for piece in pieces if piece.strip()]
    print(f"Parsed tokens: {guesses}, count: {len(guesses)}")
    print(f"Expected tokens: {masked_tokens}, count: {len(masked_tokens)}")

    expected_count = len(masked_tokens)
    if len(guesses) != expected_count:
        return f"Please provide exactly {expected_count} answers (one for each [MASK]). You provided {len(guesses)}.\n\nFormat example: word1, word2, word3"

    num_correct = 0
    lines = []
    for position, (guess, answer) in enumerate(zip(guesses, masked_tokens), start=1):
        answer = answer.lower()
        # Subword pieces are compared without their "##" prefix
        answer = answer[2:] if answer.startswith("##") else answer
        if guess == answer:
            num_correct += 1
            lines.append(f"✓ Token {position}: '{guess}' is correct!")
        else:
            lines.append(f"✗ Token {position}: '{guess}' should be '{answer}'")

    mlm = user_stats["mlm"]
    mlm["correct"] += num_correct
    mlm["total"] += expected_count

    accuracy_pct = (num_correct / expected_count) * 100 if masked_tokens else 0
    header = f"Your accuracy: {num_correct}/{expected_count} ({accuracy_pct:.1f}%)"
    overall = mlm["correct"] / mlm["total"] if mlm["total"] > 0 else 0
    footer = f"\nOverall MLM Accuracy: {mlm['correct']}/{mlm['total']} ({overall*100:.1f}%)"
    return "\n".join([header, *lines, footer])
# Next-token-prediction bookkeeping: get_new_sample() replaces it,
# prepare_next_token_prediction() fills it and check_ntp_answer() advances it.
ntp_state = dict(full_text="", revealed_text="", next_token_idx=0, tokens=[])
def prepare_next_token_prediction():
    """Fill ``ntp_state`` with the hidden tokens that follow the visible text.

    Reads the module globals ``original_text`` (the full sample) and
    ``masked_text`` (the visible prefix shown to the user). If nothing is left
    to predict, a new NTP sample is drawn via ``get_new_sample("ntp", 0.4)``,
    which calls this function again for that sample.
    NOTE(review): that fallback changes the globals but cannot refresh the
    text box in the UI; callers only see the new text in later feedback.
    """
    global ntp_state, masked_text, original_text
    # Align at the token level. ``masked_text`` is detokenizer output
    # (lowercased, punctuation split off, e.g. "don ' t"), so its character
    # length does not match the corresponding prefix of ``original_text``;
    # slicing the raw text by ``len(masked_text)`` skipped or split words.
    # Re-tokenizing the visible text is expected to reproduce the tokens it
    # was built from, so the remaining tokens of the full text are the hidden ones.
    full_tokens = tokenizer.tokenize(original_text)
    visible_tokens = tokenizer.tokenize(masked_text)
    hidden_tokens = full_tokens[len(visible_tokens):]
    full_hidden = tokenizer.convert_tokens_to_string(hidden_tokens)
    print("NTP State setup:")
    print(f" Full text: '{original_text}'")
    print(f" Visible text: '{masked_text}'")
    print(f" Hidden text: '{full_hidden}'")
    print(f" Hidden tokens: {hidden_tokens}")
    # Set up the NTP state
    ntp_state["tokens"] = hidden_tokens
    ntp_state["full_text"] = full_hidden
    ntp_state["revealed_text"] = ""
    ntp_state["next_token_idx"] = 0
    if not hidden_tokens:
        print("Warning: No tokens to predict, will try another sample")
        # get_new_sample("ntp", ...) resets ntp_state and re-runs this setup
        # itself, so no second call is needed afterwards.
        get_new_sample("ntp", 0.4)  # Use higher cut ratio
def check_ntp_answer(user_continuation):
    """Grade the user's guess for the next hidden NTP token, then reveal it.

    Only the first token of the tokenized input is compared with the next
    token in ``ntp_state["tokens"]`` (both lowercased, leading "##" ignored).
    A graded guess updates ``user_stats["ntp"]`` and advances
    ``ntp_state["next_token_idx"]``. An empty guess is rejected without
    consuming a token, mirroring how ``check_mlm_answer`` treats empty input.

    Args:
        user_continuation: Text typed by the user (may be empty).

    Returns:
        A multi-line feedback string.
    """
    global user_stats, ntp_state, masked_text
    # If we haven't set up NTP state yet, do it now
    if not ntp_state["tokens"]:
        prepare_next_token_prediction()
    print("Current NTP state:")
    print(f" Next token index: {ntp_state['next_token_idx']}")
    print(f" Total tokens: {len(ntp_state['tokens'])}")
    print(f" User input: '{user_continuation}'")

    hidden_tokens = ntp_state["tokens"]
    idx = ntp_state["next_token_idx"]
    # No more tokens to predict
    if idx >= len(hidden_tokens):
        return "You've completed this prediction! Click 'New Sample' for another."

    user_text = (user_continuation or "").strip()
    if not user_text:
        # Don't count a miss or reveal a token for an empty submission.
        return "Please provide your prediction. No input was detected."

    next_token = hidden_tokens[idx]
    print(f" Expected next token: '{next_token}'")
    # Tokenize user's prediction and compare only its first token
    user_tokens = tokenizer.tokenize(user_text)
    user_token = user_tokens[0].lower() if user_tokens else ""
    print(f" User's tokenized input: {user_tokens}")
    next_token_clean = next_token.lower()
    if next_token_clean.startswith("##"):
        next_token_clean = next_token_clean[2:]
    if user_token.startswith("##"):
        user_token = user_token[2:]
    is_correct = (user_token == next_token_clean)
    print(f" Comparison: '{user_token}' vs '{next_token_clean}' -> {'Correct' if is_correct else 'Incorrect'}")

    if is_correct:
        user_stats["ntp"]["correct"] += 1
    user_stats["ntp"]["total"] += 1

    # Reveal this token. Rebuild the strings from the token lists instead of
    # appending one converted token at a time, which dropped the spaces
    # between words and left "##" markers in the displayed text.
    idx += 1
    ntp_state["next_token_idx"] = idx
    revealed_tokens = hidden_tokens[:idx]
    ntp_state["revealed_text"] = tokenizer.convert_tokens_to_string(revealed_tokens)
    text_so_far = tokenizer.convert_tokens_to_string(tokenizer.tokenize(masked_text) + revealed_tokens)

    overall_accuracy = user_stats["ntp"]["correct"] / user_stats["ntp"]["total"] if user_stats["ntp"]["total"] > 0 else 0
    feedback = []
    if is_correct:
        feedback.append(f"✓ Correct! The next token was indeed '{next_token_clean}'")
    else:
        feedback.append(f"✗ Not quite. The actual next token was '{next_token_clean}'")
    feedback.append(f"\nText so far: {text_so_far}")
    if idx < len(hidden_tokens):
        feedback.append("\nPredict the next token...")
    else:
        feedback.append(f"\nPrediction complete! Full text was:\n{original_text}")
    feedback.append(f"\nOverall NTP Accuracy: {user_stats['ntp']['correct']}/{user_stats['ntp']['total']} ({overall_accuracy*100:.1f}%)")
    return "\n".join(feedback)
def switch_task(task):
    """Make *task* the active task.

    Returns:
        Visibility updates for the MLM panel and the NTP panel, in that order.
    """
    global current_task
    current_task = task
    show_mlm = task == "mlm"
    show_ntp = task == "ntp"
    return gr.update(visible=show_mlm), gr.update(visible=show_ntp)
def generate_new_sample(mask_ratio):
    """Draw a new sample for ``current_task``.

    Args:
        mask_ratio: Mask/cut ratio expressed as a percentage (e.g. 15).

    Returns:
        ``(sample_text, "")``; the empty string clears the answer field.
    """
    fraction = float(mask_ratio) / 100.0
    return get_new_sample(current_task, fraction), ""
def check_answer(user_input, task):
    """Grade *user_input* with the MLM checker when task is "mlm", else the NTP checker."""
    grader = check_mlm_answer if task == "mlm" else check_ntp_answer
    return grader(user_input)
def reset_stats():
    """Replace the running MLM/NTP counters with fresh zeroed ones.

    Returns:
        A confirmation message for the result box.
    """
    global user_stats
    user_stats = {task_name: {"correct": 0, "total": 0} for task_name in ("mlm", "ntp")}
    return "Statistics have been reset."
# Set up Gradio interface
with gr.Blocks(title="MLM and NTP Testing") as demo:
    gr.Markdown("# Language Model Testing: MLM vs NTP")
    gr.Markdown("Test your skills at Masked Language Modeling (MLM) and Next Token Prediction (NTP)")

    with gr.Row():
        task_radio = gr.Radio(
            ["mlm", "ntp"],
            label="Task Type",
            value="mlm",
            info="MLM: Guess the masked words | NTP: Predict what comes next"
        )
        mask_ratio = gr.Slider(
            minimum=5,
            maximum=50,
            value=15,
            step=5,
            label="Mask/Cut Ratio (%)",
            info="Percentage of tokens to mask (MLM) or text to hide (NTP)"
        )

    # Draw the initial MLM sample *before* creating the counter so the counter
    # shows the real number of [MASK] tokens (it used to start at a
    # hard-coded 0 even though the first sample already contained masks).
    initial_sample = get_new_sample("mlm", 0.15)

    # Count the visible [MASK] tokens for user reference
    mask_count = gr.Markdown(f"**Number of [MASK] tokens to guess: {len(masked_tokens)}**")
    sample_text = gr.Textbox(
        label="Text Sample",
        placeholder="Click 'New Sample' to get started",
        value=initial_sample,
        lines=10,
        interactive=False
    )

    with gr.Row():
        new_button = gr.Button("New Sample", variant="primary")
        reset_button = gr.Button("Reset Stats")

    # Consolidated input area - only one set of instructions is visible at a time
    input_area = gr.Group()
    with input_area:
        # Task-specific input instructions
        mlm_instructions = gr.Markdown("""
### MLM Instructions
1. For each [MASK] token, provide your guess for the original word.
2. Separate your answers with commas.
3. Make sure you provide exactly the same number of answers as [MASK] tokens.
**Example format:** `word1, word2, word3` or `word1,word2,word3`
""", visible=True)
        ntp_instructions = gr.Markdown("""
### NTP Instructions
Predict the next word or token that would follow the text.
Type a single word or token for each prediction.
""", visible=False)
        # Unified input box
        answer_input = gr.Textbox(
            label="Your answer",
            placeholder="For MLM: word1, word2, word3 | For NTP: single word",
            lines=1
        )

    with gr.Row():
        check_button = gr.Button("Check Answer", variant="primary")
    result = gr.Textbox(label="Result", lines=6)

    def new_sample_with_count(mask_ratio_pct, task):
        """Draw a new sample for *task* at *mask_ratio_pct* percent.

        Returns the sample text, the counter/mode Markdown, and "" to clear
        the result box.
        """
        print(f"Generating new sample with mask ratio: {mask_ratio_pct}% for task: {task}")
        ratio = float(mask_ratio_pct) / 100.0
        sample = get_new_sample(task, ratio)
        if task == "mlm":
            count = len(masked_tokens)
            mask_count_text = f"**Number of [MASK] tokens to guess: {count}**"
            print(f"Generated MLM sample with {count} masks at ratio {ratio}")
        else:
            mask_count_text = "**Next Token Prediction mode - guess one token at a time**"
            print(f"Generated NTP sample with cut ratio {ratio}")
        return sample, mask_count_text, ""

    def switch_task_unified(task, mask_ratio_pct):
        """Switch the UI to *task* and draw a fresh sample prepared for it.

        A new sample is required: the text on screen (and the module state
        behind it) was prepared for the previous task, so e.g. after NTP the
        MLM counter would report masks of an older, no-longer-shown sample.
        """
        sample, mask_count_text, result_text = new_sample_with_count(mask_ratio_pct, task)
        is_mlm = task == "mlm"
        if is_mlm:
            placeholder = "comma-separated answers (e.g., word1, word2, word3)"
        else:
            placeholder = "Type the next word/token you predict"
        return (
            gr.update(visible=is_mlm),  # mlm_instructions
            gr.update(visible=not is_mlm),  # ntp_instructions
            gr.update(value="", placeholder=placeholder),  # answer_input
            mask_count_text,
            sample,
            result_text,
        )

    def update_on_ratio_change(mask_ratio_pct, task):
        """Report the new ratio; the sample is only regenerated on 'New Sample'."""
        print(f"Ratio changed to {mask_ratio_pct}%")
        return f"Current mask/cut ratio: {mask_ratio_pct}%. Click 'New Sample' to apply."

    def unified_check_answer(user_input, task):
        """Grade *user_input* with the checker for the selected task."""
        if task == "mlm":
            return check_mlm_answer(user_input)
        return check_ntp_answer(user_input)

    # Set up event handlers
    task_radio.change(
        switch_task_unified,
        inputs=[task_radio, mask_ratio],
        outputs=[mlm_instructions, ntp_instructions, answer_input, mask_count, sample_text, result]
    )
    mask_ratio.change(
        update_on_ratio_change,
        inputs=[mask_ratio, task_radio],
        outputs=[result]
    )
    new_button.click(
        new_sample_with_count,
        inputs=[mask_ratio, task_radio],
        outputs=[sample_text, mask_count, result]
    )
    reset_button.click(reset_stats, inputs=None, outputs=[result])
    check_button.click(
        unified_check_answer,
        inputs=[answer_input, task_radio],
        outputs=[result]
    )
    answer_input.submit(
        unified_check_answer,
        inputs=[answer_input, task_radio],
        outputs=[result]
    )

demo.launch()