import json
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from transformers import AutoModel
import re
from typing import List, Dict, Any
import warnings
import logging
import os
# Disable tokenizer parallelism to avoid forking warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings('ignore')
class InstructionDataset(Dataset):
def __init__(self, data_path: str, tokenizer, max_length: int = 512, is_training: bool = True,
window_size: int = 512, overlap: int = 100):
self.tokenizer = tokenizer
self.max_length = max_length
self.is_training = is_training
self.window_size = window_size
self.overlap = overlap
# Load and process data
self.raw_data = self._load_and_process_data(data_path)
# Create sliding windows at subword level (eliminates all data loss)
self.processed_data = self._create_subword_sliding_windows(self.raw_data)
def _load_and_process_data(self, data_path: str) -> List[Dict[str, Any]]:
"""Load JSONL data and process it for token classification"""
logger = logging.getLogger(__name__)
processed_data = []
skipped_count = 0
sanity_check_failed = 0
total_instruction_tokens = 0
total_non_instruction_tokens = 0
logger.info(f"Loading data from: {data_path}")
with open(data_path, 'r', encoding='utf-8') as f:
for line_num, line in enumerate(f, 1):
try:
data = json.loads(line.strip())
# Skip data points that failed sanity check
sanity_check = data.get('sanity_check', False) # Default to False if not present
if sanity_check is False:
sanity_check_failed += 1
continue
# Extract labeled text
labeled_text = data.get('label_text', '')
# Remove ... tags if present
if labeled_text.startswith("") and labeled_text.endswith(""):
labeled_text = labeled_text[len(""):-len("")]
labeled_text = labeled_text.strip()
sample_id = data.get('id', f'sample_{line_num}')
# Process the tagged text
processed_sample = self._process_tagged_text(labeled_text, sample_id)
if processed_sample is not None:
processed_data.append(processed_sample)
# Count token distribution for debugging
labels = processed_sample['labels']
sample_instruction_tokens = sum(1 for label in labels if label == 1)
total_instruction_tokens += sample_instruction_tokens
total_non_instruction_tokens += len(labels) - sample_instruction_tokens
else:
skipped_count += 1
except Exception as e:
logger.error(f"Error processing line {line_num}: {e}")
skipped_count += 1
logger.info(f"Successfully processed {len(processed_data)} samples")
logger.info(f"Skipped {skipped_count} samples due to errors or malformed data")
logger.info(f"Skipped {sanity_check_failed} samples due to failed sanity check")
logger.info(f"Token distribution - Instruction: {total_instruction_tokens}, Non-instruction: {total_non_instruction_tokens}")
if total_instruction_tokens == 0:
logger.warning("No instruction tokens found! This will cause training issues.")
if total_non_instruction_tokens == 0:
logger.warning("No non-instruction tokens found! This will cause training issues.")
return processed_data
def _process_tagged_text(self, labeled_text: str, sample_id: str) -> Dict[str, Any] | None:
"""Process tagged text to extract tokens and labels"""
logger = logging.getLogger(__name__)
try:
# Keep original casing since XLM-RoBERTa is case-sensitive
# labeled_text = labeled_text.lower() # Removed for cased model
# Find all instruction tags
instruction_pattern = r'(.*?)'
matches = list(re.finditer(instruction_pattern, labeled_text, re.DOTALL))
# Check for malformed tags or edge cases
if '' in labeled_text and '' not in labeled_text:
return None
if '' in labeled_text and '' not in labeled_text:
return None
# Create character-level labels
char_labels = [0] * len(labeled_text)
# Mark instruction characters
for match in matches:
start, end = match.span()
# Mark the content inside tags as instruction (1)
content_start = start + len('')
content_end = end - len('')
for i in range(content_start, content_end):
char_labels[i] = 1
# Remove tags and adjust labels
clean_text = re.sub(instruction_pattern, r'\1', labeled_text)
# Recalculate labels for clean text
clean_char_labels = []
original_idx = 0
for char in clean_text:
# Skip over tag characters in original text
while original_idx < len(labeled_text) and labeled_text[original_idx] in '<>/':
if labeled_text[original_idx:original_idx+13] == '':
original_idx += 13
elif labeled_text[original_idx:original_idx+14] == '':
original_idx += 14
else:
original_idx += 1
if original_idx < len(char_labels):
clean_char_labels.append(char_labels[original_idx])
else:
clean_char_labels.append(0)
original_idx += 1
# Tokenize and align labels
tokens = clean_text.split()
token_labels = []
char_idx = 0
for token in tokens:
# Skip whitespace
while char_idx < len(clean_text) and clean_text[char_idx].isspace():
char_idx += 1
# Check if any character in this token is labeled as instruction
token_is_instruction = False
for i in range(len(token)):
if char_idx + i < len(clean_char_labels) and clean_char_labels[char_idx + i] == 1:
token_is_instruction = True
break
token_labels.append(1 if token_is_instruction else 0)
char_idx += len(token)
return {
'id': sample_id,
'tokens': tokens,
'labels': token_labels,
'original_text': clean_text
}
except Exception as e:
logger.error(f"Error processing sample {sample_id}: {e}")
return None
def _create_subword_sliding_windows(self, raw_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Create sliding windows at subword level - eliminates all data loss and mismatch issues"""
logger = logging.getLogger(__name__)
windowed_data = []
logger.info(f"Creating subword-level sliding windows:")
logger.info(f" Window size: {self.max_length} subword tokens")
logger.info(f" Overlap: {self.overlap} subword tokens")
logger.info(f" Label strategy: BERT paper approach (first subtoken only)")
total_original_samples = len(raw_data)
total_windows = 0
samples_with_multiple_windows = 0
# Word split tracking
total_words_processed = 0
total_words_split_across_windows = 0
samples_with_split_words = 0
for sample in raw_data:
words = sample['tokens']
word_labels = sample['labels']
sample_id = sample['id']
encoded = self.tokenizer(
words,
is_split_into_words=True,
add_special_tokens=True, # Include [CLS], [SEP]
truncation=False, # We handle long sequences with sliding windows
padding=False,
return_tensors='pt'
)
subword_tokens = encoded['input_ids'][0].tolist()
word_ids = encoded.word_ids()
# Step 2: Create aligned subword labels (BERT paper approach)
# Only the FIRST subtoken of each word gets the real label, rest get -100
subword_labels = []
prev_word_id = None
for word_id in word_ids:
if word_id is None:
subword_labels.append(-100) # Special tokens [CLS], [SEP]
elif word_id != prev_word_id:
# First subtoken of a new word - assign the real label
subword_labels.append(word_labels[word_id])
prev_word_id = word_id
else:
# Subsequent subtoken of the same word - assign dummy label
subword_labels.append(-100)
# prev_word_id remains the same
# Step 3: Create sliding windows at subword level
if len(subword_tokens) <= self.max_length:
# Single window - no word splits possible
windowed_data.append({
'subword_tokens': subword_tokens,
'subword_labels': subword_labels,
'original_words': words,
'original_labels': word_labels,
'sample_id': sample_id,
'window_id': 0,
'total_windows': 1,
'window_start': 0,
'window_end': len(subword_tokens),
'original_text': sample['original_text']
})
total_windows += 1
total_words_processed += len(words)
else:
# Multiple windows needed
step = self.max_length - self.overlap
window_count = 0
split_words_this_sample = set()
for start in range(0, len(subword_tokens), step):
end = min(start + self.max_length, len(subword_tokens))
# Extract subword window
window_subword_tokens = subword_tokens[start:end]
window_subword_labels = subword_labels[start:end]
# Track word splits for this window
window_word_ids = word_ids[start:end] if word_ids else []
window_words_set = set(wid for wid in window_word_ids if wid is not None)
# Find which words are split across window boundaries
for word_idx in window_words_set:
if word_idx is not None:
# Check if this word's subwords extend beyond current window
word_subword_positions = [i for i, wid in enumerate(word_ids) if wid == word_idx]
word_start_pos = min(word_subword_positions)
word_end_pos = max(word_subword_positions)
# Word is split if it extends beyond current window boundaries
if word_start_pos < start or word_end_pos >= end:
split_words_this_sample.add(word_idx)
# Get original words for this window (for debugging/inspection)
window_word_indices = list(window_words_set)
window_original_words = [words[i] for i in window_word_indices if i < len(words)]
window_original_labels = [word_labels[i] for i in window_word_indices if i < len(words)]
windowed_data.append({
'subword_tokens': window_subword_tokens,
'subword_labels': window_subword_labels,
'original_words': window_original_words, # For reference only
'original_labels': window_original_labels, # For reference only
'sample_id': sample_id,
'window_id': window_count,
'total_windows': -1, # Will be filled later
'window_start': start,
'window_end': end,
'original_text': sample['original_text']
})
window_count += 1
total_windows += 1
# Break if we've covered all subword tokens
if end >= len(subword_tokens):
break
# Update total_windows for this sample
for i in range(len(windowed_data) - window_count, len(windowed_data)):
windowed_data[i]['total_windows'] = window_count
# Track word split statistics
total_words_processed += len(words)
total_words_split_across_windows += len(split_words_this_sample)
if len(split_words_this_sample) > 0:
samples_with_split_words += 1
if window_count > 1:
samples_with_multiple_windows += 1
# Calculate word split statistics
word_split_percentage = (total_words_split_across_windows / total_words_processed * 100) if total_words_processed > 0 else 0
logger.info(f"=== Subword Sliding Window Statistics ===")
logger.info(f" Original samples: {total_original_samples}")
logger.info(f" Total windows created: {total_windows}")
logger.info(f" Samples split into multiple windows: {samples_with_multiple_windows}")
logger.info(f" Average windows per sample: {total_windows / total_original_samples:.2f}")
logger.info(f"=== Word Split Analysis ===")
logger.info(f" Total words processed: {total_words_processed:,}")
logger.info(f" Words split across windows: {total_words_split_across_windows:,}")
logger.info(f" Word split rate: {word_split_percentage:.2f}%")
logger.info(f" Samples with split words: {samples_with_split_words} / {total_original_samples}")
if word_split_percentage > 10.0:
logger.warning(f" HIGH WORD SPLIT RATE: {word_split_percentage:.1f}% - consider larger overlap")
elif word_split_percentage > 5.0:
logger.warning(f" Moderate word splitting: {word_split_percentage:.1f}% - monitor model performance")
else:
logger.info(f" Excellent word preservation: {100 - word_split_percentage:.1f}% of words intact")
logger.info(f"✅ ZERO DATA LOSS: All subword tokens processed exactly once")
logger.info(f"📋 BERT PAPER APPROACH: Only first subtokens carry labels for training/evaluation")
return windowed_data
def __len__(self):
return len(self.processed_data)
def __getitem__(self, idx):
window_data = self.processed_data[idx]
subword_tokens = window_data['subword_tokens']
subword_labels = window_data['subword_labels']
# Convert subword tokens to padded tensors (no tokenization needed!)
input_ids = subword_tokens[:self.max_length] # Guaranteed to fit
# Pad to max_length if needed
pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
while len(input_ids) < self.max_length:
input_ids.append(pad_token_id)
# Create attention mask (1 for real tokens, 0 for padding)
attention_mask = [1 if token != pad_token_id else 0 for token in input_ids]
# Pad labels to match
labels = subword_labels[:self.max_length]
while len(labels) < self.max_length:
labels.append(-100) # Ignore padding tokens in loss
return {
'input_ids': torch.tensor(input_ids, dtype=torch.long),
'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
'labels': torch.tensor(labels, dtype=torch.long),
'original_tokens': window_data['original_words'], # Original words for reference
'original_labels': window_data['original_labels'], # Original word labels
# Add window metadata for evaluation aggregation
'sample_id': window_data['sample_id'],
'window_id': window_data['window_id'],
'total_windows': window_data['total_windows'],
'window_start': window_data['window_start'],
'window_end': window_data['window_end']
}
class TransformerInstructionClassifier(nn.Module):
def __init__(self, model_name: str = 'xlm-roberta-base', num_labels: int = 2,
class_weights=None, loss_type='weighted_ce', dropout: float = 0.1):
super().__init__()
self.num_labels = num_labels
self.loss_type = loss_type
# Load pre-trained transformer model (XLM-RoBERTa, ModernBERT, etc.)
self.bert = AutoModel.from_pretrained(model_name)
self.dropout = nn.Dropout(dropout)
# Classification head
self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)
# Setup loss function based on type
if loss_type == 'weighted_ce':
self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100, weight=class_weights)
else:
self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
def forward(self, input_ids, attention_mask, labels=None):
# Get BERT outputs
outputs = self.bert(
input_ids=input_ids,
attention_mask=attention_mask
)
# Get last hidden state
last_hidden_state = outputs.last_hidden_state
# Apply dropout
last_hidden_state = self.dropout(last_hidden_state)
# Classification
logits = self.classifier(last_hidden_state)
loss = None
if labels is not None:
logger = logging.getLogger(__name__)
# Check for NaN in inputs before loss calculation
if torch.isnan(logits).any():
logger.warning("NaN detected in logits!")
if torch.isnan(labels.float()).any():
logger.warning("NaN detected in labels!")
loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
# Check if loss is NaN
if torch.isnan(loss):
logger.warning("NaN loss detected!")
logger.warning(f"Logits stats: min={logits.min()}, max={logits.max()}, mean={logits.mean()}")
logger.warning(f"Labels unique values: {torch.unique(labels[labels != -100])}")
return {
'loss': loss,
'logits': logits
}
def collate_fn(batch):
"""Custom collate function for DataLoader"""
input_ids = torch.stack([item['input_ids'] for item in batch])
attention_mask = torch.stack([item['attention_mask'] for item in batch])
labels = torch.stack([item['labels'] for item in batch])
return {
'input_ids': input_ids,
'attention_mask': attention_mask,
'labels': labels,
'original_tokens': [item['original_tokens'] for item in batch],
'original_labels': [item['original_labels'] for item in batch],
# Add window metadata
'sample_ids': [item['sample_id'] for item in batch],
'window_ids': [item['window_id'] for item in batch],
'total_windows': [item['total_windows'] for item in batch],
'window_starts': [item['window_start'] for item in batch],
'window_ends': [item['window_end'] for item in batch]
}
def predict_instructions(model, tokenizer, text: str, device=None, threshold=0.4):
"""Predict instructions in a given text
Args:
model: The trained instruction classifier model
tokenizer: The tokenizer for the model
text: Input text to analyze
device: Device to run inference on
threshold: Probability threshold for classifying tokens as INSTRUCTION.
Lower values = more aggressive detection (default: 0.4)
Returns:
tuple: (tokens, predictions) where predictions are 0=OTHER, 1=INSTRUCTION
"""
# Auto-detect device if not provided
if device is None:
if torch.backends.mps.is_available():
device = torch.device('mps')
elif torch.cuda.is_available():
device = torch.device('cuda')
else:
device = torch.device('cpu')
model.eval()
# Keep original casing since XLM-RoBERTa is case-sensitive
# text = text.lower() # Removed for cased model
tokens = text.split()
# Tokenize
encoded = tokenizer(
tokens,
is_split_into_words=True,
padding='max_length',
truncation=True,
max_length=512,
return_tensors='pt'
)
input_ids = encoded['input_ids'].to(device)
attention_mask = encoded['attention_mask'].to(device)
with torch.no_grad():
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
# Convert logits to probabilities
probs = torch.softmax(outputs['logits'], dim=-1)
# Use threshold on probability of class 1 (INSTRUCTION) instead of argmax
# This makes the classifier more aggressive - tokens are classified as INSTRUCTION
# if their probability of being INSTRUCTION is above the threshold
predictions = (probs[:, :, 1] > threshold).long()
# Align predictions with original tokens
word_ids = encoded.word_ids()
word_predictions = []
prev_word_id = None
for i, word_id in enumerate(word_ids):
if word_id is not None and word_id != prev_word_id:
if word_id < len(tokens):
word_predictions.append(predictions[0][i].item())
prev_word_id = word_id
return tokens, word_predictions
def get_device():
"""Get the best available device"""
if torch.backends.mps.is_available():
return torch.device('mps')
elif torch.cuda.is_available():
return torch.device('cuda')
else:
return torch.device('cpu')