import gradio as gr
import random
import re
from datasets import load_dataset
from transformers import AutoTokenizer

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Initialize variables to track stats
user_stats = {
    "mlm": {"correct": 0, "total": 0},
    "ntp": {"correct": 0, "total": 0}
}


def load_sample_data(sample_size=100):
    """Load and sample two-sentence snippets from the requested dataset."""
    try:
        # Try to load the requested dataset
        dataset = load_dataset("mlfoundations/dclm-baseline-1.0-parquet", streaming=True)
        dataset_field = "text"  # Assuming the field name is "text"
    except Exception as e:
        print(f"Error loading requested dataset: {e}")
        # Fall back to cc_news if there's an issue
        dataset = load_dataset("vblagoje/cc_news", streaming=True)
        dataset_field = "text"

    # Sample from the dataset. Note that sample_size bounds the number of
    # examples *scanned*, so fewer than sample_size snippets may be kept.
    samples = []
    for i, example in enumerate(dataset["train"]):
        if i >= sample_size:
            break
        # Get text from the appropriate field
        if dataset_field in example and example[dataset_field]:
            # Collapse runs of whitespace
            text = re.sub(r'\s+', ' ', example[dataset_field]).strip()
            # Only include longer texts to make the task meaningful
            if len(text.split()) > 20:
                # Truncate to the first two sentences
                sentences = re.split(r'(?<=[.!?])\s+', text)
                if len(sentences) >= 2:
                    samples.append(' '.join(sentences[:2]))
    return samples


# Load data at startup
data_samples = load_sample_data(100)
current_sample = None
masked_text = ""
original_text = ""
masked_indices = []
masked_tokens = []
current_task = "mlm"


def prepare_mlm_sample(text, mask_ratio=0.15):
    """Prepare a text sample for MLM by masking random tokens."""
    global masked_indices, masked_tokens, original_text

    tokens = tokenizer.tokenize(text)
    print(f"Text length: {len(text)} characters, {len(tokens)} tokens")

    # Only mask word-initial tokens, not subwords, special tokens, or punctuation
    maskable_indices = [
        i for i, token in enumerate(tokens)
        if not token.startswith("##")
        and not token.startswith("[")
        and not token.endswith("]")
        and token not in [".", ",", "!", "?", ";", ":", "'", "\"", "-"]
    ]
    print(f"Maskable indices count: {len(maskable_indices)}")
    print(f"Mask ratio: {mask_ratio}")

    # Number of tokens to mask is proportional to the mask ratio (no arbitrary cap)
    num_to_mask = max(1, int(len(maskable_indices) * mask_ratio))
    print(f"Number of tokens to mask: {num_to_mask}")

    # Randomly select indices to mask, then sort them so they stay in order
    indices_to_mask = random.sample(maskable_indices, min(num_to_mask, len(maskable_indices)))
    indices_to_mask.sort()

    # Replace the selected tokens with [MASK], remembering the originals
    masked_tokens_list = tokens.copy()
    original_tokens = []
    for idx in indices_to_mask:
        original_tokens.append(masked_tokens_list[idx])
        masked_tokens_list[idx] = "[MASK]"

    # Save info for evaluation
    masked_indices = indices_to_mask
    masked_tokens = original_tokens
    original_text = text

    # Convert back to text with masks
    masked = tokenizer.convert_tokens_to_string(masked_tokens_list)

    print(f"Original tokens: {original_tokens}")
    print(f"Masked indices: {indices_to_mask}")
    print(f"Number of masks: {len(original_tokens)}")

    return masked, indices_to_mask, original_tokens

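# A minimal usage sketch for prepare_mlm_sample (illustrative only, not part
# of the app flow; the sentence below is made up). With the default 15% ratio
# the three return values stay aligned one-to-one, which the checker relies on:
#
#   demo_masked, demo_indices, demo_originals = prepare_mlm_sample(
#       "The quick brown fox jumps over the lazy dog near the old river bank."
#   )
#   assert demo_masked.count("[MASK]") == len(demo_indices) == len(demo_originals)
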
def prepare_ntp_sample(text, cut_ratio=0.3):
    """Prepare a text sample for NTP by cutting off the end."""
    # Tokenize so the cut lands on a token boundary
    tokens = tokenizer.tokenize(text)

    print(f"NTP preparation - Text length: {len(text)} characters, {len(tokens)} tokens")
    print(f"Cut ratio: {cut_ratio}")

    # Ensure we have enough tokens
    if len(tokens) < 5:
        return text, ""  # Return original if too short

    # Cutoff keeps at least 3 tokens visible and leaves at least 1 to predict
    cutoff = max(3, int(len(tokens) * (1 - cut_ratio)))
    cutoff = min(cutoff, len(tokens) - 1)
    print(f"Cutoff point: {cutoff} (keeping {cutoff} tokens, cutting {len(tokens) - cutoff} tokens)")

    # Split into the visible part and the hidden part to be predicted
    visible_tokens = tokens[:cutoff]
    hidden_tokens = tokens[cutoff:]

    visible_text = tokenizer.convert_tokens_to_string(visible_tokens)
    hidden_text = tokenizer.convert_tokens_to_string(hidden_tokens)

    print(f"Visible text length: {len(visible_text)} chars")
    print(f"Hidden text length: {len(hidden_text)} chars")

    return visible_text, hidden_text


def get_new_sample(task, mask_ratio=0.15):
    """Get a new text sample based on the task."""
    global current_sample, masked_text, masked_indices, masked_tokens, original_text, ntp_state, current_task

    current_task = task
    current_sample = random.choice(data_samples)
    print(f"Getting new sample for task: {task} with mask ratio: {mask_ratio}")

    if task == "mlm":
        masked_text, masked_indices, masked_tokens = prepare_mlm_sample(current_sample, mask_ratio)
        return masked_text
    else:  # NTP
        visible_text, hidden_text = prepare_ntp_sample(current_sample, mask_ratio)

        # Store original and visible for comparison
        original_text = current_sample
        masked_text = visible_text

        # Reset NTP state for the new round
        ntp_state = {
            "full_text": "",
            "revealed_text": "",
            "next_token_idx": 0,
            "tokens": []
        }

        # Prepare for token-by-token prediction, passing the hidden text
        # directly rather than re-deriving it from character offsets
        prepare_next_token_prediction(hidden_text)

        return visible_text


def check_mlm_answer(user_answers):
    """Check user MLM answers against the masked tokens."""
    global user_stats

    print(f"Original user input: '{user_answers}'")

    # Handle empty input
    if not user_answers or user_answers.isspace():
        return "Please provide your answers. No input was detected."

    # Basic cleanup - trim and lowercase
    user_answers = user_answers.strip().lower()
    print(f"After basic cleanup: '{user_answers}'")

    # Split on commas when present (dropping empty entries), otherwise on whitespace
    if ',' in user_answers:
        user_tokens = [token.strip() for token in user_answers.split(',')]
        user_tokens = [token for token in user_tokens if token]
    else:
        user_tokens = [token for token in user_answers.split() if token]

    print(f"Parsed tokens: {user_tokens}, count: {len(user_tokens)}")
    print(f"Expected tokens: {masked_tokens}, count: {len(masked_tokens)}")

    # Require exactly one answer per mask
    if len(user_tokens) != len(masked_tokens):
        return (f"Please provide exactly {len(masked_tokens)} answers (one for each [MASK]). "
                f"You provided {len(user_tokens)}.\n\nFormat example: word1, word2, word3")

    # Compare each answer
    correct = 0
    feedback = []
    for i, (user_token, orig_token) in enumerate(zip(user_tokens, masked_tokens)):
        orig_token = orig_token.lower()
        # Strip the ## subword prefix for comparison
        if orig_token.startswith("##"):
            orig_token = orig_token[2:]
        if user_token == orig_token:
            correct += 1
            feedback.append(f"✓ Token {i+1}: '{user_token}' is correct!")
        else:
            feedback.append(f"✗ Token {i+1}: '{user_token}' should be '{orig_token}'")

    # Update stats
    user_stats["mlm"]["correct"] += correct
    user_stats["mlm"]["total"] += len(masked_tokens)

    # Per-round accuracy
    accuracy = correct / len(masked_tokens) if masked_tokens else 0
    feedback.insert(0, f"Your accuracy: {correct}/{len(masked_tokens)} ({accuracy * 100:.1f}%)")

    # Running accuracy across all rounds
    overall = user_stats["mlm"]["correct"] / user_stats["mlm"]["total"] if user_stats["mlm"]["total"] > 0 else 0
    feedback.append(f"\nOverall MLM Accuracy: {user_stats['mlm']['correct']}/{user_stats['mlm']['total']} ({overall * 100:.1f}%)")

    return "\n".join(feedback)

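# Parsing contract for check_mlm_answer (a sketch; the inputs are made up):
# comma-separated and whitespace-separated answers parse identically, and the
# comparison is against lowercased WordPiece tokens with any "##" prefix
# stripped, so a guess of "ing" matches an original subword token "##ing".
#
#   "cat, sat,mat" -> ["cat", "sat", "mat"]
#   "cat sat mat"  -> ["cat", "sat", "mat"]
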
# Mutable state for the token-by-token NTP round
ntp_state = {
    "full_text": "",
    "revealed_text": "",
    "next_token_idx": 0,
    "tokens": []
}


def prepare_next_token_prediction(hidden_text=None):
    """Set up the NTP state for token-by-token prediction."""
    global ntp_state

    # If the hidden text wasn't passed in, fall back to slicing it off the
    # original text by character offset. This fallback assumes the tokenizer
    # round-trip preserved the visible prefix exactly, which bert-base-uncased
    # (lowercasing, spacing around punctuation) does not guarantee, so prefer
    # passing hidden_text from prepare_ntp_sample.
    if hidden_text is None:
        hidden_text = original_text[len(masked_text):].strip()

    # Tokenize the hidden part
    hidden_tokens = tokenizer.tokenize(hidden_text)

    print("NTP State setup:")
    print(f"  Full text: '{original_text}'")
    print(f"  Visible text: '{masked_text}'")
    print(f"  Hidden text: '{hidden_text}'")
    print(f"  Hidden tokens: {hidden_tokens}")

    ntp_state["tokens"] = hidden_tokens
    ntp_state["full_text"] = hidden_text
    ntp_state["revealed_text"] = ""
    ntp_state["next_token_idx"] = 0

    # If there is nothing to predict, draw a new sample with a higher cut
    # ratio; get_new_sample calls back into this function itself.
    if not ntp_state["tokens"]:
        print("Warning: No tokens to predict, will try another sample")
        get_new_sample("ntp", 0.4)


def check_ntp_answer(user_continuation):
    """Check the user's prediction for the next token only."""
    global user_stats, ntp_state, masked_text

    # If the NTP state hasn't been set up yet, do it now
    if not ntp_state["tokens"]:
        prepare_next_token_prediction()

    print("Current NTP state:")
    print(f"  Next token index: {ntp_state['next_token_idx']}")
    print(f"  Total tokens: {len(ntp_state['tokens'])}")
    print(f"  User input: '{user_continuation}'")

    # No more tokens to predict
    if ntp_state["next_token_idx"] >= len(ntp_state["tokens"]):
        return "You've completed this prediction! Click 'New Sample' for another."

    # The next token to predict
    next_token = ntp_state["tokens"][ntp_state["next_token_idx"]]
    print(f"  Expected next token: '{next_token}'")

    # Tokenize the user's prediction and keep only its first token
    user_text = user_continuation.strip()
    user_tokens = tokenizer.tokenize(user_text)
    user_token = user_tokens[0].lower() if user_tokens else ""
    print(f"  User's tokenized input: {user_tokens}")

    # Strip subword prefixes before comparing
    next_token_clean = next_token.lower()
    if next_token_clean.startswith("##"):
        next_token_clean = next_token_clean[2:]
    if user_token.startswith("##"):
        user_token = user_token[2:]

    is_correct = (user_token == next_token_clean)
    print(f"  Comparison: '{user_token}' vs '{next_token_clean}' -> {'Correct' if is_correct else 'Incorrect'}")

    # Update stats
    if is_correct:
        user_stats["ntp"]["correct"] += 1
    user_stats["ntp"]["total"] += 1

    # Reveal this token and advance to the next one
    ntp_state["revealed_text"] += tokenizer.convert_tokens_to_string([next_token])
    ntp_state["next_token_idx"] += 1

    overall = user_stats["ntp"]["correct"] / user_stats["ntp"]["total"] if user_stats["ntp"]["total"] > 0 else 0

    feedback = []
    if is_correct:
        feedback.append(f"✓ Correct! The next token was indeed '{next_token_clean}'")
    else:
        feedback.append(f"✗ Not quite. The actual next token was '{next_token_clean}'")

    # Show progress
    feedback.append(f"\nText so far: {masked_text}{ntp_state['revealed_text']}")

    # Prompt for the next token, or wrap up the round
    if ntp_state["next_token_idx"] < len(ntp_state["tokens"]):
        feedback.append("\nPredict the next token...")
    else:
        feedback.append(f"\nPrediction complete! Full text was:\n{original_text}")

    # Show overall stats
    feedback.append(f"\nOverall NTP Accuracy: {user_stats['ntp']['correct']}/{user_stats['ntp']['total']} ({overall * 100:.1f}%)")

    return "\n".join(feedback)

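# Shape of one NTP round (a hypothetical transcript, assuming the sample's
# hidden tail tokenizes to ["over", "the", "fence"]): each call consumes
# exactly one hidden token, reveals it whether or not the guess was right,
# and advances next_token_idx, so three calls always finish the round.
#
#   check_ntp_answer("over")   # ✓ correct, reveals "over"
#   check_ntp_answer("a")      # ✗ incorrect, still reveals "the"
#   check_ntp_answer("fence")  # ✓ correct -> "Prediction complete!"
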
# Note: the Blocks UI below wires its own unified handlers; the three helpers
# before reset_stats are kept for programmatic use.
def switch_task(task):
    """Switch between MLM and NTP tasks."""
    global current_task
    current_task = task
    return gr.update(visible=(task == "mlm")), gr.update(visible=(task == "ntp"))


def generate_new_sample(mask_ratio):
    """Generate a new sample for the current task."""
    ratio = float(mask_ratio) / 100.0  # Convert percentage to ratio
    sample = get_new_sample(current_task, ratio)
    return sample, ""


def check_answer(user_input, task):
    """Check the user's answer with the checker for the given task."""
    if task == "mlm":
        return check_mlm_answer(user_input)
    else:  # NTP
        return check_ntp_answer(user_input)


def reset_stats():
    """Reset user statistics."""
    global user_stats
    user_stats = {
        "mlm": {"correct": 0, "total": 0},
        "ntp": {"correct": 0, "total": 0}
    }
    return "Statistics have been reset."

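# The slider works in whole percentages while the sampling helpers take a
# 0-1 ratio; generate_new_sample above and new_sample_with_count below do the
# same conversion, e.g. a slider value of 15 becomes 15 / 100.0 = 0.15, which
# prepare_mlm_sample then applies as max(1, int(num_maskable * 0.15)) masks.
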
# Set up Gradio interface
with gr.Blocks(title="MLM and NTP Testing") as demo:
    gr.Markdown("# Language Model Testing: MLM vs NTP")
    gr.Markdown("Test your skills at Masked Language Modeling (MLM) and Next Token Prediction (NTP)")

    with gr.Row():
        task_radio = gr.Radio(
            ["mlm", "ntp"],
            label="Task Type",
            value="mlm",
            info="MLM: Guess the masked words | NTP: Predict what comes next"
        )
        mask_ratio = gr.Slider(
            minimum=5,
            maximum=50,
            value=15,
            step=5,
            label="Mask/Cut Ratio (%)",
            info="Percentage of tokens to mask (MLM) or text to hide (NTP)"
        )

    # Generate the initial sample before building the components that show it,
    # so the mask count below matches the sample instead of starting at 0
    initial_sample = get_new_sample("mlm", 0.15)

    # Count of [MASK] tokens for user reference
    mask_count = gr.Markdown(f"**Number of [MASK] tokens to guess: {len(masked_tokens)}**")

    sample_text = gr.Textbox(
        label="Text Sample",
        placeholder="Click 'New Sample' to get started",
        value=initial_sample,
        lines=10,
        interactive=False
    )

    with gr.Row():
        new_button = gr.Button("New Sample", variant="primary")
        reset_button = gr.Button("Reset Stats")

    # Consolidated input area - only one set of instructions visible at a time
    input_area = gr.Group()
    with input_area:
        # Task-specific input instructions
        mlm_instructions = gr.Markdown("""
### MLM Instructions
1. For each [MASK] token, provide your guess for the original word.
2. Separate your answers with commas.
3. Make sure you provide exactly the same number of answers as [MASK] tokens.

**Example format:** `word1, word2, word3` or `word1,word2,word3`
""", visible=True)

        ntp_instructions = gr.Markdown("""
### NTP Instructions
Predict the next word or token that would follow the text.
Type a single word or token for each prediction.
""", visible=False)

        # Unified input box
        answer_input = gr.Textbox(
            label="Your answer",
            placeholder="For MLM: word1, word2, word3 | For NTP: single word",
            lines=1
        )

    with gr.Row():
        check_button = gr.Button("Check Answer", variant="primary")

    result = gr.Textbox(label="Result", lines=6)

    # Swap the visible instructions, placeholder, and mask count when the task changes
    def switch_task_unified(task):
        if task == "mlm":
            mask_text = f"**Number of [MASK] tokens to guess: {len(masked_tokens)}**"
            return (
                gr.update(visible=True),   # mlm_instructions
                gr.update(visible=False),  # ntp_instructions
                gr.update(placeholder="comma-separated answers (e.g., word1, word2, word3)"),
                mask_text
            )
        else:  # ntp
            return (
                gr.update(visible=False),  # mlm_instructions
                gr.update(visible=True),   # ntp_instructions
                gr.update(placeholder="Type the next word/token you predict"),
                "**Next Token Prediction mode - guess one token at a time**"
            )

    task_radio.change(
        switch_task_unified,
        inputs=[task_radio],
        outputs=[mlm_instructions, ntp_instructions, answer_input, mask_count]
    )

    # When the ratio changes, only show a hint; a new sample is not generated
    # until 'New Sample' is clicked
    def update_on_ratio_change(mask_ratio_pct, task):
        print(f"Ratio changed to {mask_ratio_pct}%")
        return f"Current mask/cut ratio: {mask_ratio_pct}%. Click 'New Sample' to apply."
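    # Remaining wiring (below): mask_ratio.change posts the hint above,
    # new_button.click draws a fresh sample and updates the mask count, and
    # both check_button.click and answer_input.submit route the answer
    # through the same unified checker.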
    mask_ratio.change(
        update_on_ratio_change,
        inputs=[mask_ratio, task_radio],
        outputs=[result]
    )

    # Generate a new sample and refresh the mask count
    def new_sample_with_count(mask_ratio_pct, task):
        print(f"Generating new sample with mask ratio: {mask_ratio_pct}% for task: {task}")
        ratio = float(mask_ratio_pct) / 100.0
        sample = get_new_sample(task, ratio)

        if task == "mlm":
            count = len(masked_tokens)
            mask_count_text = f"**Number of [MASK] tokens to guess: {count}**"
            print(f"Generated MLM sample with {count} masks at ratio {ratio}")
        else:
            mask_count_text = "**Next Token Prediction mode - guess one token at a time**"
            print(f"Generated NTP sample with cut ratio {ratio}")

        return sample, mask_count_text, ""

    new_button.click(
        new_sample_with_count,
        inputs=[mask_ratio, task_radio],
        outputs=[sample_text, mask_count, result]
    )

    reset_button.click(reset_stats, inputs=None, outputs=[result])

    # Route the answer to the checker for the active task
    def unified_check_answer(user_input, task):
        if task == "mlm":
            return check_mlm_answer(user_input)
        else:  # ntp
            return check_ntp_answer(user_input)

    check_button.click(
        unified_check_answer,
        inputs=[answer_input, task_radio],
        outputs=[result]
    )
    answer_input.submit(
        unified_check_answer,
        inputs=[answer_input, task_radio],
        outputs=[result]
    )

demo.launch()
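# Launch options (assumptions, not exercised above): Blocks.launch() also
# accepts e.g. share=True for a temporary public URL or server_port=7860 to
# pin the port, if either is needed when deploying elsewhere.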