Spaces:
Running
on
Zero
Running
on
Zero
import json | |
import torch | |
import torch.nn as nn | |
from torch.utils.data import Dataset, DataLoader | |
from transformers import AutoTokenizer, AutoModel, AutoConfig | |
import numpy as np | |
from tqdm import tqdm | |
import re | |
from typing import List, Tuple, Dict, Any | |
import warnings | |
import logging | |
import os | |
from datetime import datetime | |
from sklearn.utils.class_weight import compute_class_weight | |
import torch.nn.functional as F | |
# Disable tokenizer parallelism to avoid forking warnings | |
os.environ["TOKENIZERS_PARALLELISM"] = "false" | |
warnings.filterwarnings('ignore') | |
def set_random_seeds(seed=42): | |
"""Set random seeds for reproducibility""" | |
import random | |
import numpy as np | |
import torch | |
random.seed(seed) | |
np.random.seed(seed) | |
torch.manual_seed(seed) | |
torch.cuda.manual_seed_all(seed) # For multi-GPU | |
# Make CuDNN deterministic (slower but reproducible) | |
torch.backends.cudnn.deterministic = True | |
torch.backends.cudnn.benchmark = False | |
def setup_logging(log_dir='data/logs'): | |
"""Setup logging configuration""" | |
# Create logs directory if it doesn't exist | |
os.makedirs(log_dir, exist_ok=True) | |
# Create timestamp for log file | |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") | |
log_file = os.path.join(log_dir, f'training_log_{timestamp}.log') | |
# Configure logging | |
logging.basicConfig( | |
level=logging.INFO, # Back to INFO level | |
format='%(asctime)s - %(levelname)s - %(message)s', | |
handlers=[ | |
logging.FileHandler(log_file), | |
logging.StreamHandler() # Also print to console | |
] | |
) | |
logger = logging.getLogger(__name__) | |
logger.info(f"Logging initialized. Log file: {log_file}") | |
return logger, log_file | |
def check_gpu_availability(): | |
"""Check and print GPU availability information""" | |
logger = logging.getLogger(__name__) | |
logger.info("=== GPU Availability Check ===") | |
if torch.backends.mps.is_available(): | |
logger.info("β MPS (Apple Silicon GPU) is available") | |
if torch.backends.mps.is_built(): | |
logger.info("β MPS is built into PyTorch") | |
else: | |
logger.info("β MPS is not built into PyTorch") | |
else: | |
logger.info("β MPS (Apple Silicon GPU) is not available") | |
if torch.cuda.is_available(): | |
logger.info(f"β CUDA is available (GPU count: {torch.cuda.device_count()})") | |
else: | |
logger.info("β CUDA is not available") | |
logger.info(f"PyTorch version: {torch.__version__}") | |
logger.info("=" * 50) | |
def calculate_class_weights(dataset): | |
"""Calculate class weights for imbalanced dataset using BERT paper approach""" | |
logger = logging.getLogger(__name__) | |
# Collect all labels from the dataset (BERT approach: only first subtokens have real labels) | |
all_labels = [] | |
for window_data in dataset.processed_data: | |
# Filter out -100 labels (special tokens + subsequent subtokens of same word) | |
# This gives us true word-level class distribution | |
valid_labels = [label for label in window_data['subword_labels'] if label != -100] | |
all_labels.extend(valid_labels) | |
# Convert to numpy array | |
y = np.array(all_labels) | |
# Calculate class weights using sklearn | |
classes = np.unique(y) | |
class_weights = compute_class_weight('balanced', classes=classes, y=y) | |
# Create weight tensor | |
weight_tensor = torch.FloatTensor(class_weights) | |
logger.info(f"Word-level class distribution: {np.bincount(y)}") | |
logger.info(f"Class 0 (Non-instruction words): {np.sum(y == 0)} words ({np.sum(y == 0)/len(y)*100:.1f}%)") | |
logger.info(f"Class 1 (Instruction words): {np.sum(y == 1)} words ({np.sum(y == 1)/len(y)*100:.1f}%)") | |
logger.info(f"Calculated class weights (word-level): {class_weights}") | |
logger.info(f" Weight for class 0 (Non-instruction): {class_weights[0]:.4f}") | |
logger.info(f" Weight for class 1 (Instruction): {class_weights[1]:.4f}") | |
return weight_tensor | |
class FocalLoss(nn.Module): | |
"""Focal Loss for addressing class imbalance""" | |
def __init__(self, alpha=1, gamma=2, ignore_index=-100): | |
super(FocalLoss, self).__init__() | |
self.alpha = alpha | |
self.gamma = gamma | |
self.ignore_index = ignore_index | |
def forward(self, inputs, targets): | |
# Flatten inputs and targets | |
inputs = inputs.view(-1, inputs.size(-1)) | |
targets = targets.view(-1) | |
# Create mask for non-ignored indices | |
mask = targets != self.ignore_index | |
targets = targets[mask] | |
inputs = inputs[mask] | |
if len(targets) == 0: | |
return torch.tensor(0.0, requires_grad=True, device=inputs.device) | |
# Calculate cross entropy | |
ce_loss = F.cross_entropy(inputs, targets, reduction='none') | |
# Calculate pt | |
pt = torch.exp(-ce_loss) | |
# Calculate focal loss | |
focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss | |
return focal_loss.mean() | |
class InstructionDataset(Dataset): | |
def __init__(self, data_path: str, tokenizer, max_length: int = 512, is_training: bool = True, | |
window_size: int = 512, overlap: int = 100): | |
self.tokenizer = tokenizer | |
self.max_length = max_length | |
self.is_training = is_training | |
self.window_size = window_size | |
self.overlap = overlap | |
# Load and process data | |
self.raw_data = self._load_and_process_data(data_path) | |
# Create sliding windows at subword level (eliminates all data loss) | |
self.processed_data = self._create_subword_sliding_windows(self.raw_data) | |
def _load_and_process_data(self, data_path: str) -> List[Dict[str, Any]]: | |
"""Load JSONL data and process it for token classification""" | |
logger = logging.getLogger(__name__) | |
processed_data = [] | |
skipped_count = 0 | |
sanity_check_failed = 0 | |
total_instruction_tokens = 0 | |
total_non_instruction_tokens = 0 | |
logger.info(f"Loading data from: {data_path}") | |
with open(data_path, 'r', encoding='utf-8') as f: | |
for line_num, line in enumerate(f, 1): | |
try: | |
data = json.loads(line.strip()) | |
# Skip data points that failed sanity check | |
sanity_check = data.get('sanity_check', False) # Default to False if not present | |
if sanity_check is False: | |
sanity_check_failed += 1 | |
continue | |
# Extract labeled text | |
labeled_text = data.get('label_text', '') | |
# Remove <text>...</text> tags if present | |
if labeled_text.startswith("<text>") and labeled_text.endswith("</text>"): | |
labeled_text = labeled_text[len("<text>"):-len("</text>")] | |
labeled_text = labeled_text.strip() | |
sample_id = data.get('id', f'sample_{line_num}') | |
# Process the tagged text | |
processed_sample = self._process_tagged_text(labeled_text, sample_id) | |
if processed_sample is not None: | |
processed_data.append(processed_sample) | |
# Count token distribution for debugging | |
labels = processed_sample['labels'] | |
sample_instruction_tokens = sum(1 for label in labels if label == 1) | |
total_instruction_tokens += sample_instruction_tokens | |
total_non_instruction_tokens += len(labels) - sample_instruction_tokens | |
else: | |
skipped_count += 1 | |
except Exception as e: | |
logger.error(f"Error processing line {line_num}: {e}") | |
skipped_count += 1 | |
logger.info(f"Successfully processed {len(processed_data)} samples") | |
logger.info(f"Skipped {skipped_count} samples due to errors or malformed data") | |
logger.info(f"Skipped {sanity_check_failed} samples due to failed sanity check") | |
logger.info(f"Token distribution - Instruction: {total_instruction_tokens}, Non-instruction: {total_non_instruction_tokens}") | |
if total_instruction_tokens == 0: | |
logger.warning("No instruction tokens found! This will cause training issues.") | |
if total_non_instruction_tokens == 0: | |
logger.warning("No non-instruction tokens found! This will cause training issues.") | |
return processed_data | |
def _process_tagged_text(self, labeled_text: str, sample_id: str) -> Dict[str, Any] | None: | |
"""Process tagged text to extract tokens and labels""" | |
logger = logging.getLogger(__name__) | |
try: | |
# Keep original casing since XLM-RoBERTa is case-sensitive | |
# labeled_text = labeled_text.lower() # Removed for cased model | |
# Find all instruction tags | |
instruction_pattern = r'<instruction>(.*?)</instruction>' | |
matches = list(re.finditer(instruction_pattern, labeled_text, re.DOTALL)) | |
# Check for malformed tags or edge cases | |
if '<instruction>' in labeled_text and '</instruction>' not in labeled_text: | |
return None | |
if '</instruction>' in labeled_text and '<instruction>' not in labeled_text: | |
return None | |
# Create character-level labels | |
char_labels = [0] * len(labeled_text) | |
# Mark instruction characters | |
for match in matches: | |
start, end = match.span() | |
# Mark the content inside tags as instruction (1) | |
content_start = start + len('<instruction>') | |
content_end = end - len('</instruction>') | |
for i in range(content_start, content_end): | |
char_labels[i] = 1 | |
# Remove tags and adjust labels | |
clean_text = re.sub(instruction_pattern, r'\1', labeled_text) | |
# Recalculate labels for clean text | |
clean_char_labels = [] | |
original_idx = 0 | |
for char in clean_text: | |
# Skip over tag characters in original text | |
while original_idx < len(labeled_text) and labeled_text[original_idx] in '<>/': | |
if labeled_text[original_idx:original_idx+13] == '<instruction>': | |
original_idx += 13 | |
elif labeled_text[original_idx:original_idx+14] == '</instruction>': | |
original_idx += 14 | |
else: | |
original_idx += 1 | |
if original_idx < len(char_labels): | |
clean_char_labels.append(char_labels[original_idx]) | |
else: | |
clean_char_labels.append(0) | |
original_idx += 1 | |
# Tokenize and align labels | |
tokens = clean_text.split() | |
token_labels = [] | |
char_idx = 0 | |
for token in tokens: | |
# Skip whitespace | |
while char_idx < len(clean_text) and clean_text[char_idx].isspace(): | |
char_idx += 1 | |
# Check if any character in this token is labeled as instruction | |
token_is_instruction = False | |
for i in range(len(token)): | |
if char_idx + i < len(clean_char_labels) and clean_char_labels[char_idx + i] == 1: | |
token_is_instruction = True | |
break | |
token_labels.append(1 if token_is_instruction else 0) | |
char_idx += len(token) | |
return { | |
'id': sample_id, | |
'tokens': tokens, | |
'labels': token_labels, | |
'original_text': clean_text | |
} | |
except Exception as e: | |
logger.error(f"Error processing sample {sample_id}: {e}") | |
return None | |
def _create_subword_sliding_windows(self, raw_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: | |
"""Create sliding windows at subword level - eliminates all data loss and mismatch issues""" | |
logger = logging.getLogger(__name__) | |
windowed_data = [] | |
logger.info(f"Creating subword-level sliding windows:") | |
logger.info(f" Window size: {self.max_length} subword tokens") | |
logger.info(f" Overlap: {self.overlap} subword tokens") | |
logger.info(f" Label strategy: BERT paper approach (first subtoken only)") | |
total_original_samples = len(raw_data) | |
total_windows = 0 | |
samples_with_multiple_windows = 0 | |
# Word split tracking | |
total_words_processed = 0 | |
total_words_split_across_windows = 0 | |
samples_with_split_words = 0 | |
for sample in raw_data: | |
words = sample['tokens'] | |
word_labels = sample['labels'] | |
sample_id = sample['id'] | |
encoded = self.tokenizer( | |
words, | |
is_split_into_words=True, | |
add_special_tokens=True, # Include [CLS], [SEP] | |
truncation=False, # We handle long sequences with sliding windows | |
padding=False, | |
return_tensors='pt' | |
) | |
subword_tokens = encoded['input_ids'][0].tolist() | |
word_ids = encoded.word_ids() | |
# Step 2: Create aligned subword labels (BERT paper approach) | |
# Only the FIRST subtoken of each word gets the real label, rest get -100 | |
subword_labels = [] | |
prev_word_id = None | |
for word_id in word_ids: | |
if word_id is None: | |
subword_labels.append(-100) # Special tokens [CLS], [SEP] | |
elif word_id != prev_word_id: | |
# First subtoken of a new word - assign the real label | |
subword_labels.append(word_labels[word_id]) | |
prev_word_id = word_id | |
else: | |
# Subsequent subtoken of the same word - assign dummy label | |
subword_labels.append(-100) | |
# prev_word_id remains the same | |
# Step 3: Create sliding windows at subword level | |
if len(subword_tokens) <= self.max_length: | |
# Single window - no word splits possible | |
windowed_data.append({ | |
'subword_tokens': subword_tokens, | |
'subword_labels': subword_labels, | |
'original_words': words, | |
'original_labels': word_labels, | |
'sample_id': sample_id, | |
'window_id': 0, | |
'total_windows': 1, | |
'window_start': 0, | |
'window_end': len(subword_tokens), | |
'original_text': sample['original_text'] | |
}) | |
total_windows += 1 | |
total_words_processed += len(words) | |
else: | |
# Multiple windows needed | |
step = self.max_length - self.overlap | |
window_count = 0 | |
split_words_this_sample = set() | |
for start in range(0, len(subword_tokens), step): | |
end = min(start + self.max_length, len(subword_tokens)) | |
# Extract subword window | |
window_subword_tokens = subword_tokens[start:end] | |
window_subword_labels = subword_labels[start:end] | |
# Track word splits for this window | |
window_word_ids = word_ids[start:end] if word_ids else [] | |
window_words_set = set(wid for wid in window_word_ids if wid is not None) | |
# Find which words are split across window boundaries | |
for word_idx in window_words_set: | |
if word_idx is not None: | |
# Check if this word's subwords extend beyond current window | |
word_subword_positions = [i for i, wid in enumerate(word_ids) if wid == word_idx] | |
word_start_pos = min(word_subword_positions) | |
word_end_pos = max(word_subword_positions) | |
# Word is split if it extends beyond current window boundaries | |
if word_start_pos < start or word_end_pos >= end: | |
split_words_this_sample.add(word_idx) | |
# Get original words for this window (for debugging/inspection) | |
window_word_indices = list(window_words_set) | |
window_original_words = [words[i] for i in window_word_indices if i < len(words)] | |
window_original_labels = [word_labels[i] for i in window_word_indices if i < len(words)] | |
windowed_data.append({ | |
'subword_tokens': window_subword_tokens, | |
'subword_labels': window_subword_labels, | |
'original_words': window_original_words, # For reference only | |
'original_labels': window_original_labels, # For reference only | |
'sample_id': sample_id, | |
'window_id': window_count, | |
'total_windows': -1, # Will be filled later | |
'window_start': start, | |
'window_end': end, | |
'original_text': sample['original_text'] | |
}) | |
window_count += 1 | |
total_windows += 1 | |
# Break if we've covered all subword tokens | |
if end >= len(subword_tokens): | |
break | |
# Update total_windows for this sample | |
for i in range(len(windowed_data) - window_count, len(windowed_data)): | |
windowed_data[i]['total_windows'] = window_count | |
# Track word split statistics | |
total_words_processed += len(words) | |
total_words_split_across_windows += len(split_words_this_sample) | |
if len(split_words_this_sample) > 0: | |
samples_with_split_words += 1 | |
if window_count > 1: | |
samples_with_multiple_windows += 1 | |
# Calculate word split statistics | |
word_split_percentage = (total_words_split_across_windows / total_words_processed * 100) if total_words_processed > 0 else 0 | |
logger.info(f"=== Subword Sliding Window Statistics ===") | |
logger.info(f" Original samples: {total_original_samples}") | |
logger.info(f" Total windows created: {total_windows}") | |
logger.info(f" Samples split into multiple windows: {samples_with_multiple_windows}") | |
logger.info(f" Average windows per sample: {total_windows / total_original_samples:.2f}") | |
logger.info(f"=== Word Split Analysis ===") | |
logger.info(f" Total words processed: {total_words_processed:,}") | |
logger.info(f" Words split across windows: {total_words_split_across_windows:,}") | |
logger.info(f" Word split rate: {word_split_percentage:.2f}%") | |
logger.info(f" Samples with split words: {samples_with_split_words} / {total_original_samples}") | |
if word_split_percentage > 10.0: | |
logger.warning(f" HIGH WORD SPLIT RATE: {word_split_percentage:.1f}% - consider larger overlap") | |
elif word_split_percentage > 5.0: | |
logger.warning(f" Moderate word splitting: {word_split_percentage:.1f}% - monitor model performance") | |
else: | |
logger.info(f" Excellent word preservation: {100 - word_split_percentage:.1f}% of words intact") | |
logger.info(f"β ZERO DATA LOSS: All subword tokens processed exactly once") | |
logger.info(f"π BERT PAPER APPROACH: Only first subtokens carry labels for training/evaluation") | |
return windowed_data | |
def __len__(self): | |
return len(self.processed_data) | |
def __getitem__(self, idx): | |
window_data = self.processed_data[idx] | |
subword_tokens = window_data['subword_tokens'] | |
subword_labels = window_data['subword_labels'] | |
# Convert subword tokens to padded tensors (no tokenization needed!) | |
input_ids = subword_tokens[:self.max_length] # Guaranteed to fit | |
# Pad to max_length if needed | |
pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id | |
while len(input_ids) < self.max_length: | |
input_ids.append(pad_token_id) | |
# Create attention mask (1 for real tokens, 0 for padding) | |
attention_mask = [1 if token != pad_token_id else 0 for token in input_ids] | |
# Pad labels to match | |
labels = subword_labels[:self.max_length] | |
while len(labels) < self.max_length: | |
labels.append(-100) # Ignore padding tokens in loss | |
return { | |
'input_ids': torch.tensor(input_ids, dtype=torch.long), | |
'attention_mask': torch.tensor(attention_mask, dtype=torch.long), | |
'labels': torch.tensor(labels, dtype=torch.long), | |
'original_tokens': window_data['original_words'], # Original words for reference | |
'original_labels': window_data['original_labels'], # Original word labels | |
# Add window metadata for evaluation aggregation | |
'sample_id': window_data['sample_id'], | |
'window_id': window_data['window_id'], | |
'total_windows': window_data['total_windows'], | |
'window_start': window_data['window_start'], | |
'window_end': window_data['window_end'] | |
} | |
class TransformerInstructionClassifier(nn.Module): | |
def __init__(self, model_name: str = 'xlm-roberta-base', num_labels: int = 2, | |
class_weights=None, loss_type='weighted_ce', dropout: float = 0.1): | |
super().__init__() | |
self.num_labels = num_labels | |
self.loss_type = loss_type | |
# Load pre-trained transformer model (XLM-RoBERTa, ModernBERT, etc.) | |
self.bert = AutoModel.from_pretrained(model_name) | |
self.dropout = nn.Dropout(dropout) | |
# Classification head | |
self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels) | |
# Setup loss function based on type | |
if loss_type == 'weighted_ce': | |
self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100, weight=class_weights) | |
elif loss_type == 'focal': | |
self.loss_fct = FocalLoss(alpha=1, gamma=2, ignore_index=-100) | |
else: | |
self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100) | |
def forward(self, input_ids, attention_mask, labels=None): | |
# Get BERT outputs | |
outputs = self.bert( | |
input_ids=input_ids, | |
attention_mask=attention_mask | |
) | |
# Get last hidden state | |
last_hidden_state = outputs.last_hidden_state | |
# Apply dropout | |
last_hidden_state = self.dropout(last_hidden_state) | |
# Classification | |
logits = self.classifier(last_hidden_state) | |
loss = None | |
if labels is not None: | |
logger = logging.getLogger(__name__) | |
# Check for NaN in inputs before loss calculation | |
if torch.isnan(logits).any(): | |
logger.warning("NaN detected in logits!") | |
if torch.isnan(labels.float()).any(): | |
logger.warning("NaN detected in labels!") | |
loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) | |
# Check if loss is NaN | |
if torch.isnan(loss): | |
logger.warning("NaN loss detected!") | |
logger.warning(f"Logits stats: min={logits.min()}, max={logits.max()}, mean={logits.mean()}") | |
logger.warning(f"Labels unique values: {torch.unique(labels[labels != -100])}") | |
return { | |
'loss': loss, | |
'logits': logits | |
} | |
def collate_fn(batch): | |
"""Custom collate function for DataLoader""" | |
input_ids = torch.stack([item['input_ids'] for item in batch]) | |
attention_mask = torch.stack([item['attention_mask'] for item in batch]) | |
labels = torch.stack([item['labels'] for item in batch]) | |
return { | |
'input_ids': input_ids, | |
'attention_mask': attention_mask, | |
'labels': labels, | |
'original_tokens': [item['original_tokens'] for item in batch], | |
'original_labels': [item['original_labels'] for item in batch], | |
# Add window metadata | |
'sample_ids': [item['sample_id'] for item in batch], | |
'window_ids': [item['window_id'] for item in batch], | |
'total_windows': [item['total_windows'] for item in batch], | |
'window_starts': [item['window_start'] for item in batch], | |
'window_ends': [item['window_end'] for item in batch] | |
} | |
def predict_instructions(model, tokenizer, text: str, device=None): | |
"""Predict instructions in a given text""" | |
# Auto-detect device if not provided | |
if device is None: | |
if torch.backends.mps.is_available(): | |
device = torch.device('mps') | |
elif torch.cuda.is_available(): | |
device = torch.device('cuda') | |
else: | |
device = torch.device('cpu') | |
model.eval() | |
# Keep original casing since XLM-RoBERTa is case-sensitive | |
# text = text.lower() # Removed for cased model | |
tokens = text.split() | |
# Tokenize | |
encoded = tokenizer( | |
tokens, | |
is_split_into_words=True, | |
padding='max_length', | |
truncation=True, | |
max_length=512, | |
return_tensors='pt' | |
) | |
input_ids = encoded['input_ids'].to(device) | |
attention_mask = encoded['attention_mask'].to(device) | |
with torch.no_grad(): | |
outputs = model(input_ids=input_ids, attention_mask=attention_mask) | |
predictions = torch.argmax(outputs['logits'], dim=-1) | |
# Align predictions with original tokens | |
word_ids = encoded.word_ids() | |
word_predictions = [] | |
prev_word_id = None | |
for i, word_id in enumerate(word_ids): | |
if word_id is not None and word_id != prev_word_id: | |
if word_id < len(tokens): | |
word_predictions.append(predictions[0][i].item()) | |
prev_word_id = word_id | |
return tokens, word_predictions | |
def get_device(): | |
"""Get the best available device""" | |
if torch.backends.mps.is_available(): | |
return torch.device('mps') | |
elif torch.cuda.is_available(): | |
return torch.device('cuda') | |
else: | |
return torch.device('cpu') |