import json import torch import torch.nn as nn from torch.utils.data import Dataset from transformers import AutoModel import re from typing import List, Dict, Any import warnings import logging import os # Disable tokenizer parallelism to avoid forking warnings os.environ["TOKENIZERS_PARALLELISM"] = "false" warnings.filterwarnings('ignore') class InstructionDataset(Dataset): def __init__(self, data_path: str, tokenizer, max_length: int = 512, is_training: bool = True, window_size: int = 512, overlap: int = 100): self.tokenizer = tokenizer self.max_length = max_length self.is_training = is_training self.window_size = window_size self.overlap = overlap # Load and process data self.raw_data = self._load_and_process_data(data_path) # Create sliding windows at subword level (eliminates all data loss) self.processed_data = self._create_subword_sliding_windows(self.raw_data) def _load_and_process_data(self, data_path: str) -> List[Dict[str, Any]]: """Load JSONL data and process it for token classification""" logger = logging.getLogger(__name__) processed_data = [] skipped_count = 0 sanity_check_failed = 0 total_instruction_tokens = 0 total_non_instruction_tokens = 0 logger.info(f"Loading data from: {data_path}") with open(data_path, 'r', encoding='utf-8') as f: for line_num, line in enumerate(f, 1): try: data = json.loads(line.strip()) # Skip data points that failed sanity check sanity_check = data.get('sanity_check', False) # Default to False if not present if sanity_check is False: sanity_check_failed += 1 continue # Extract labeled text labeled_text = data.get('label_text', '') # Remove ... tags if present if labeled_text.startswith("") and labeled_text.endswith(""): labeled_text = labeled_text[len(""):-len("")] labeled_text = labeled_text.strip() sample_id = data.get('id', f'sample_{line_num}') # Process the tagged text processed_sample = self._process_tagged_text(labeled_text, sample_id) if processed_sample is not None: processed_data.append(processed_sample) # Count token distribution for debugging labels = processed_sample['labels'] sample_instruction_tokens = sum(1 for label in labels if label == 1) total_instruction_tokens += sample_instruction_tokens total_non_instruction_tokens += len(labels) - sample_instruction_tokens else: skipped_count += 1 except Exception as e: logger.error(f"Error processing line {line_num}: {e}") skipped_count += 1 logger.info(f"Successfully processed {len(processed_data)} samples") logger.info(f"Skipped {skipped_count} samples due to errors or malformed data") logger.info(f"Skipped {sanity_check_failed} samples due to failed sanity check") logger.info(f"Token distribution - Instruction: {total_instruction_tokens}, Non-instruction: {total_non_instruction_tokens}") if total_instruction_tokens == 0: logger.warning("No instruction tokens found! This will cause training issues.") if total_non_instruction_tokens == 0: logger.warning("No non-instruction tokens found! This will cause training issues.") return processed_data def _process_tagged_text(self, labeled_text: str, sample_id: str) -> Dict[str, Any] | None: """Process tagged text to extract tokens and labels""" logger = logging.getLogger(__name__) try: # Keep original casing since XLM-RoBERTa is case-sensitive # labeled_text = labeled_text.lower() # Removed for cased model # Find all instruction tags instruction_pattern = r'(.*?)' matches = list(re.finditer(instruction_pattern, labeled_text, re.DOTALL)) # Check for malformed tags or edge cases if '' in labeled_text and '' not in labeled_text: return None if '' in labeled_text and '' not in labeled_text: return None # Create character-level labels char_labels = [0] * len(labeled_text) # Mark instruction characters for match in matches: start, end = match.span() # Mark the content inside tags as instruction (1) content_start = start + len('') content_end = end - len('') for i in range(content_start, content_end): char_labels[i] = 1 # Remove tags and adjust labels clean_text = re.sub(instruction_pattern, r'\1', labeled_text) # Recalculate labels for clean text clean_char_labels = [] original_idx = 0 for char in clean_text: # Skip over tag characters in original text while original_idx < len(labeled_text) and labeled_text[original_idx] in '<>/': if labeled_text[original_idx:original_idx+13] == '': original_idx += 13 elif labeled_text[original_idx:original_idx+14] == '': original_idx += 14 else: original_idx += 1 if original_idx < len(char_labels): clean_char_labels.append(char_labels[original_idx]) else: clean_char_labels.append(0) original_idx += 1 # Tokenize and align labels tokens = clean_text.split() token_labels = [] char_idx = 0 for token in tokens: # Skip whitespace while char_idx < len(clean_text) and clean_text[char_idx].isspace(): char_idx += 1 # Check if any character in this token is labeled as instruction token_is_instruction = False for i in range(len(token)): if char_idx + i < len(clean_char_labels) and clean_char_labels[char_idx + i] == 1: token_is_instruction = True break token_labels.append(1 if token_is_instruction else 0) char_idx += len(token) return { 'id': sample_id, 'tokens': tokens, 'labels': token_labels, 'original_text': clean_text } except Exception as e: logger.error(f"Error processing sample {sample_id}: {e}") return None def _create_subword_sliding_windows(self, raw_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """Create sliding windows at subword level - eliminates all data loss and mismatch issues""" logger = logging.getLogger(__name__) windowed_data = [] logger.info(f"Creating subword-level sliding windows:") logger.info(f" Window size: {self.max_length} subword tokens") logger.info(f" Overlap: {self.overlap} subword tokens") logger.info(f" Label strategy: BERT paper approach (first subtoken only)") total_original_samples = len(raw_data) total_windows = 0 samples_with_multiple_windows = 0 # Word split tracking total_words_processed = 0 total_words_split_across_windows = 0 samples_with_split_words = 0 for sample in raw_data: words = sample['tokens'] word_labels = sample['labels'] sample_id = sample['id'] encoded = self.tokenizer( words, is_split_into_words=True, add_special_tokens=True, # Include [CLS], [SEP] truncation=False, # We handle long sequences with sliding windows padding=False, return_tensors='pt' ) subword_tokens = encoded['input_ids'][0].tolist() word_ids = encoded.word_ids() # Step 2: Create aligned subword labels (BERT paper approach) # Only the FIRST subtoken of each word gets the real label, rest get -100 subword_labels = [] prev_word_id = None for word_id in word_ids: if word_id is None: subword_labels.append(-100) # Special tokens [CLS], [SEP] elif word_id != prev_word_id: # First subtoken of a new word - assign the real label subword_labels.append(word_labels[word_id]) prev_word_id = word_id else: # Subsequent subtoken of the same word - assign dummy label subword_labels.append(-100) # prev_word_id remains the same # Step 3: Create sliding windows at subword level if len(subword_tokens) <= self.max_length: # Single window - no word splits possible windowed_data.append({ 'subword_tokens': subword_tokens, 'subword_labels': subword_labels, 'original_words': words, 'original_labels': word_labels, 'sample_id': sample_id, 'window_id': 0, 'total_windows': 1, 'window_start': 0, 'window_end': len(subword_tokens), 'original_text': sample['original_text'] }) total_windows += 1 total_words_processed += len(words) else: # Multiple windows needed step = self.max_length - self.overlap window_count = 0 split_words_this_sample = set() for start in range(0, len(subword_tokens), step): end = min(start + self.max_length, len(subword_tokens)) # Extract subword window window_subword_tokens = subword_tokens[start:end] window_subword_labels = subword_labels[start:end] # Track word splits for this window window_word_ids = word_ids[start:end] if word_ids else [] window_words_set = set(wid for wid in window_word_ids if wid is not None) # Find which words are split across window boundaries for word_idx in window_words_set: if word_idx is not None: # Check if this word's subwords extend beyond current window word_subword_positions = [i for i, wid in enumerate(word_ids) if wid == word_idx] word_start_pos = min(word_subword_positions) word_end_pos = max(word_subword_positions) # Word is split if it extends beyond current window boundaries if word_start_pos < start or word_end_pos >= end: split_words_this_sample.add(word_idx) # Get original words for this window (for debugging/inspection) window_word_indices = list(window_words_set) window_original_words = [words[i] for i in window_word_indices if i < len(words)] window_original_labels = [word_labels[i] for i in window_word_indices if i < len(words)] windowed_data.append({ 'subword_tokens': window_subword_tokens, 'subword_labels': window_subword_labels, 'original_words': window_original_words, # For reference only 'original_labels': window_original_labels, # For reference only 'sample_id': sample_id, 'window_id': window_count, 'total_windows': -1, # Will be filled later 'window_start': start, 'window_end': end, 'original_text': sample['original_text'] }) window_count += 1 total_windows += 1 # Break if we've covered all subword tokens if end >= len(subword_tokens): break # Update total_windows for this sample for i in range(len(windowed_data) - window_count, len(windowed_data)): windowed_data[i]['total_windows'] = window_count # Track word split statistics total_words_processed += len(words) total_words_split_across_windows += len(split_words_this_sample) if len(split_words_this_sample) > 0: samples_with_split_words += 1 if window_count > 1: samples_with_multiple_windows += 1 # Calculate word split statistics word_split_percentage = (total_words_split_across_windows / total_words_processed * 100) if total_words_processed > 0 else 0 logger.info(f"=== Subword Sliding Window Statistics ===") logger.info(f" Original samples: {total_original_samples}") logger.info(f" Total windows created: {total_windows}") logger.info(f" Samples split into multiple windows: {samples_with_multiple_windows}") logger.info(f" Average windows per sample: {total_windows / total_original_samples:.2f}") logger.info(f"=== Word Split Analysis ===") logger.info(f" Total words processed: {total_words_processed:,}") logger.info(f" Words split across windows: {total_words_split_across_windows:,}") logger.info(f" Word split rate: {word_split_percentage:.2f}%") logger.info(f" Samples with split words: {samples_with_split_words} / {total_original_samples}") if word_split_percentage > 10.0: logger.warning(f" HIGH WORD SPLIT RATE: {word_split_percentage:.1f}% - consider larger overlap") elif word_split_percentage > 5.0: logger.warning(f" Moderate word splitting: {word_split_percentage:.1f}% - monitor model performance") else: logger.info(f" Excellent word preservation: {100 - word_split_percentage:.1f}% of words intact") logger.info(f"✅ ZERO DATA LOSS: All subword tokens processed exactly once") logger.info(f"📋 BERT PAPER APPROACH: Only first subtokens carry labels for training/evaluation") return windowed_data def __len__(self): return len(self.processed_data) def __getitem__(self, idx): window_data = self.processed_data[idx] subword_tokens = window_data['subword_tokens'] subword_labels = window_data['subword_labels'] # Convert subword tokens to padded tensors (no tokenization needed!) input_ids = subword_tokens[:self.max_length] # Guaranteed to fit # Pad to max_length if needed pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id while len(input_ids) < self.max_length: input_ids.append(pad_token_id) # Create attention mask (1 for real tokens, 0 for padding) attention_mask = [1 if token != pad_token_id else 0 for token in input_ids] # Pad labels to match labels = subword_labels[:self.max_length] while len(labels) < self.max_length: labels.append(-100) # Ignore padding tokens in loss return { 'input_ids': torch.tensor(input_ids, dtype=torch.long), 'attention_mask': torch.tensor(attention_mask, dtype=torch.long), 'labels': torch.tensor(labels, dtype=torch.long), 'original_tokens': window_data['original_words'], # Original words for reference 'original_labels': window_data['original_labels'], # Original word labels # Add window metadata for evaluation aggregation 'sample_id': window_data['sample_id'], 'window_id': window_data['window_id'], 'total_windows': window_data['total_windows'], 'window_start': window_data['window_start'], 'window_end': window_data['window_end'] } class TransformerInstructionClassifier(nn.Module): def __init__(self, model_name: str = 'xlm-roberta-base', num_labels: int = 2, class_weights=None, loss_type='weighted_ce', dropout: float = 0.1): super().__init__() self.num_labels = num_labels self.loss_type = loss_type # Load pre-trained transformer model (XLM-RoBERTa, ModernBERT, etc.) self.bert = AutoModel.from_pretrained(model_name) self.dropout = nn.Dropout(dropout) # Classification head self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels) # Setup loss function based on type if loss_type == 'weighted_ce': self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100, weight=class_weights) else: self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100) def forward(self, input_ids, attention_mask, labels=None): # Get BERT outputs outputs = self.bert( input_ids=input_ids, attention_mask=attention_mask ) # Get last hidden state last_hidden_state = outputs.last_hidden_state # Apply dropout last_hidden_state = self.dropout(last_hidden_state) # Classification logits = self.classifier(last_hidden_state) loss = None if labels is not None: logger = logging.getLogger(__name__) # Check for NaN in inputs before loss calculation if torch.isnan(logits).any(): logger.warning("NaN detected in logits!") if torch.isnan(labels.float()).any(): logger.warning("NaN detected in labels!") loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) # Check if loss is NaN if torch.isnan(loss): logger.warning("NaN loss detected!") logger.warning(f"Logits stats: min={logits.min()}, max={logits.max()}, mean={logits.mean()}") logger.warning(f"Labels unique values: {torch.unique(labels[labels != -100])}") return { 'loss': loss, 'logits': logits } def collate_fn(batch): """Custom collate function for DataLoader""" input_ids = torch.stack([item['input_ids'] for item in batch]) attention_mask = torch.stack([item['attention_mask'] for item in batch]) labels = torch.stack([item['labels'] for item in batch]) return { 'input_ids': input_ids, 'attention_mask': attention_mask, 'labels': labels, 'original_tokens': [item['original_tokens'] for item in batch], 'original_labels': [item['original_labels'] for item in batch], # Add window metadata 'sample_ids': [item['sample_id'] for item in batch], 'window_ids': [item['window_id'] for item in batch], 'total_windows': [item['total_windows'] for item in batch], 'window_starts': [item['window_start'] for item in batch], 'window_ends': [item['window_end'] for item in batch] } def predict_instructions(model, tokenizer, text: str, device=None, threshold=0.4): """Predict instructions in a given text Args: model: The trained instruction classifier model tokenizer: The tokenizer for the model text: Input text to analyze device: Device to run inference on threshold: Probability threshold for classifying tokens as INSTRUCTION. Lower values = more aggressive detection (default: 0.4) Returns: tuple: (tokens, predictions) where predictions are 0=OTHER, 1=INSTRUCTION """ # Auto-detect device if not provided if device is None: if torch.backends.mps.is_available(): device = torch.device('mps') elif torch.cuda.is_available(): device = torch.device('cuda') else: device = torch.device('cpu') model.eval() # Keep original casing since XLM-RoBERTa is case-sensitive # text = text.lower() # Removed for cased model tokens = text.split() # Tokenize encoded = tokenizer( tokens, is_split_into_words=True, padding='max_length', truncation=True, max_length=512, return_tensors='pt' ) input_ids = encoded['input_ids'].to(device) attention_mask = encoded['attention_mask'].to(device) with torch.no_grad(): outputs = model(input_ids=input_ids, attention_mask=attention_mask) # Convert logits to probabilities probs = torch.softmax(outputs['logits'], dim=-1) # Use threshold on probability of class 1 (INSTRUCTION) instead of argmax # This makes the classifier more aggressive - tokens are classified as INSTRUCTION # if their probability of being INSTRUCTION is above the threshold predictions = (probs[:, :, 1] > threshold).long() # Align predictions with original tokens word_ids = encoded.word_ids() word_predictions = [] prev_word_id = None for i, word_id in enumerate(word_ids): if word_id is not None and word_id != prev_word_id: if word_id < len(tokens): word_predictions.append(predictions[0][i].item()) prev_word_id = word_id return tokens, word_predictions def get_device(): """Get the best available device""" if torch.backends.mps.is_available(): return torch.device('mps') elif torch.cuda.is_available(): return torch.device('cuda') else: return torch.device('cpu')