import logging
import os
import sys
from datetime import datetime
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import wandb
from torch.utils.data import DataLoader
from transformers import XLMRobertaTokenizer, get_linear_schedule_with_warmup

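# Make the local `model` package importable no matter which directory the
# script is launched from.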
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from model.language_aware_transformer import LanguageAwareTransformer
from model.train import ToxicDataset
from model.training_config import TrainingConfig

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


def setup_plot_style():
    """Configure plot styling."""
    # 'seaborn-darkgrid' was renamed to 'seaborn-v0_8-darkgrid' in
    # matplotlib 3.6; try the new name first and fall back for older versions.
    try:
        plt.style.use('seaborn-v0_8-darkgrid')
    except OSError:
        plt.style.use('seaborn-darkgrid')
    plt.rcParams['figure.figsize'] = (12, 12)
    plt.rcParams['font.size'] = 12


def setup_wandb():
    """Initialize wandb for validation tracking."""
    try:
        wandb.init(
            project="toxic-comment-classification",
            name=f"validation-analysis-{datetime.now().strftime('%Y%m%d-%H%M%S')}",
            config={
                "analysis_type": "validation_loss",
                "timestamp": datetime.now().strftime('%Y%m%d-%H%M%S')
            }
        )
        logger.info("Initialized wandb logging")
    except Exception as e:
        logger.error(f"Error initializing wandb: {e}")
        raise


def load_model_and_data():
    """Load the model and the combined validation/test data."""
    try:
        # Recreate the training-time hyperparameters so the values logged
        # below are directly comparable with the original training run.
        config = TrainingConfig(
            batch_size=16,
            num_workers=16,
            lr=2e-5,
            weight_decay=0.01,
            max_grad_norm=1.0,
            warmup_ratio=0.1,
            label_smoothing=0.01,
            mixed_precision="fp16",
            activation_checkpointing=True,
            epochs=1
        )

        logger.info("Loading validation and test data...")
        val_df = pd.read_csv("dataset/split/val.csv")
        test_df = pd.read_csv("dataset/split/test.csv")
        # Reset the index so the concatenated frame has unique positional
        # labels for the dataset to index into.
        combined_df = pd.concat([val_df, test_df], ignore_index=True)

        tokenizer = XLMRobertaTokenizer.from_pretrained(config.model_name)
        combined_dataset = ToxicDataset(combined_df, tokenizer, config, mode='combined')

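        # shuffle=True only changes batch order here; the per-epoch average
        # over the full combined set is unaffected.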
        combined_loader = DataLoader(
            combined_dataset,
            batch_size=config.batch_size,
            shuffle=True,
            num_workers=config.num_workers,
            pin_memory=True,
            drop_last=False
        )

        if wandb.run is not None:
            wandb.config.update({
                'shuffle': True,
                'drop_last': False,
                'total_validation_steps': len(combined_loader),
                'total_validation_samples': len(combined_dataset)
            })

logger.info("Loading model...") |
|
model = LanguageAwareTransformer( |
|
num_labels=len(config.toxicity_labels), |
|
model_name=config.model_name |
|
) |
|
|
|
|
|
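        # Load the fine-tuned weights from disk; fail loudly rather than
        # silently analyzing an untrained model.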
        checkpoint_path = Path('weights/toxic_classifier_xlm-roberta-large/pytorch_model.bin')
        if checkpoint_path.exists():
            checkpoint = torch.load(checkpoint_path, map_location='cpu')
            model.load_state_dict(checkpoint)
            logger.info("Loaded model checkpoint")
        else:
            raise FileNotFoundError(f"No checkpoint found at {checkpoint_path}")

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = model.to(device)

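        # The optimizer, scheduler, and scaler are rebuilt only so the
        # learning-rate schedule can be logged next to the losses; nothing in
        # this script ever steps them.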
        param_groups = config.get_param_groups(model)
        optimizer = torch.optim.AdamW(param_groups)

        total_steps = len(combined_loader) * config.epochs
        warmup_steps = int(total_steps * config.warmup_ratio)

        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps
        )

        scaler = torch.cuda.amp.GradScaler(enabled=config.mixed_precision == "fp16")

        if wandb.run is not None:
            wandb.config.update({
                'model_name': config.model_name,
                'batch_size': config.batch_size,
                'learning_rate': config.lr,
                'weight_decay': config.weight_decay,
                'max_grad_norm': config.max_grad_norm,
                'warmup_ratio': config.warmup_ratio,
                'label_smoothing': config.label_smoothing,
                'mixed_precision': config.mixed_precision,
                'num_workers': config.num_workers,
                'activation_checkpointing': config.activation_checkpointing,
                'validation_epochs': config.epochs
            })

        return model, combined_loader, device, optimizer, scheduler, scaler, config

    except Exception as e:
        logger.error(f"Error loading model and data: {e}")
        raise


def collect_validation_losses(model, combined_loader, device, optimizer, scheduler, scaler, config):
    """Run validation and collect step losses across multiple epochs."""
    try:
        logger.warning("This is an analysis run on combined val+test data - model will not be saved or updated")

        # Freeze the model: eval mode plus disabled gradients makes every
        # forward pass read-only.
        model.eval()
        for param in model.parameters():
            param.requires_grad = False

        epoch_losses = []

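        # With the model frozen, repeated epochs differ only through batch
        # order and mixed-precision numerics.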
        for epoch in range(config.epochs):
            logger.info(f"\nStarting validation epoch {epoch+1}/{config.epochs}")
            total_loss = 0
            num_batches = len(combined_loader)
            epoch_start_time = datetime.now()

            with torch.no_grad():
                for step, batch in enumerate(combined_loader):
                    # Move tensors to the target device; non-tensor fields
                    # pass through unchanged.
                    batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                             for k, v in batch.items()}

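                    # Forward pass under autocast so the measured loss matches
                    # the mixed-precision setting used in training.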
                    with torch.cuda.amp.autocast(enabled=config.mixed_precision != "no"):
                        outputs = model(**batch)
                        loss = outputs['loss'].item()

                    total_loss += loss
                    avg_loss = total_loss / (step + 1)

                    # The scheduler is never stepped, so this records the
                    # schedule's initial learning rate for reference.
                    lrs = [group['lr'] for group in optimizer.param_groups]

                    wandb.log({
                        'val/step_loss': loss,
                        'val/running_avg_loss': avg_loss,
                        'val/progress': (step + 1) / num_batches * 100,
                        'val/learning_rate': lrs[0],
                        'val/batch_size': config.batch_size,
                        'val/epoch': epoch + 1,
                        'val/global_step': epoch * num_batches + step
                    })

                    if step % 10 == 0:
                        elapsed_time = datetime.now() - epoch_start_time
                        steps_per_sec = (step + 1) / elapsed_time.total_seconds()
                        remaining_steps = num_batches - (step + 1)
                        eta_seconds = remaining_steps / steps_per_sec if steps_per_sec > 0 else 0

                        logger.info(
                            f"Epoch [{epoch+1}/{config.epochs}] "
                            f"Step [{step+1}/{num_batches}] "
                            f"Loss: {loss:.4f} "
                            f"Avg Loss: {avg_loss:.4f} "
                            f"LR: {lrs[0]:.2e} "
                            f"ETA: {int(eta_seconds)}s"
                        )

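            # Summarize the epoch so runs with several epochs can be compared
            # at a glance.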
            epoch_avg_loss = total_loss / num_batches
            epoch_losses.append({
                'epoch': epoch + 1,
                'avg_loss': epoch_avg_loss,
                'elapsed_time': (datetime.now() - epoch_start_time).total_seconds()
            })

            wandb.log({
                'val/epoch_avg_loss': epoch_avg_loss,
                'val/epoch_number': epoch + 1,
                'val/epoch_time': epoch_losses[-1]['elapsed_time']
            })

            # Release cached GPU memory between epochs (a no-op on CPU).
            torch.cuda.empty_cache()

        return epoch_losses

    except Exception as e:
        logger.error(f"Error collecting validation losses: {e}")
        raise


def plot_validation_losses(epoch_losses):
    """Plot validation epoch losses."""
    try:
        setup_plot_style()

        fig, ax = plt.subplots()

        epochs = [d['epoch'] for d in epoch_losses]
        losses = [d['avg_loss'] for d in epoch_losses]

        ax.plot(epochs, losses, 'go-', label='Epoch Average Loss', linewidth=2, markersize=8)

        # A linear trend needs at least two points; with a single epoch the
        # fit would be degenerate.
        if len(epochs) >= 2:
            z = np.polyfit(epochs, losses, 1)
            p = np.poly1d(z)
            ax.plot(epochs, p(epochs), "r--", alpha=0.8, label='Loss Trend')

        ax.set_title('Validation Epoch Losses')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Average Loss')
        ax.legend()
        ax.grid(True, linestyle='--', alpha=0.7)

        plt.tight_layout()

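        # Timestamped filename so repeated runs don't overwrite earlier plots.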
        output_dir = Path('analysis/plots')
        output_dir.mkdir(parents=True, exist_ok=True)

        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        output_path = output_dir / f'validation_losses_{timestamp}.png'
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        logger.info(f"Plot saved to {output_path}")

        wandb.log({
            "val/loss_plot": wandb.Image(str(output_path))
        })

        plt.show()

    except Exception as e:
        logger.error(f"Error plotting validation losses: {e}")
        raise


def calculate_loss_statistics(epoch_losses):
    """Calculate and print loss statistics."""
    try:
        losses = [d['avg_loss'] for d in epoch_losses]

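        # Aggregate the per-epoch averages; 'Best Epoch' is the epoch with
        # the lowest average loss.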
        stats = {
            'Mean Loss': np.mean(losses),
            'Std Loss': np.std(losses),
            'Min Loss': np.min(losses),
            'Max Loss': np.max(losses),
            'Best Epoch': epoch_losses[np.argmin(losses)]['epoch']
        }

        wandb.log({
            'val/mean_loss': stats['Mean Loss'],
            'val/std_loss': stats['Std Loss'],
            'val/min_loss': stats['Min Loss'],
            'val/max_loss': stats['Max Loss'],
            'val/best_epoch': stats['Best Epoch']
        })

        print("\nValidation Loss Statistics:")
        for metric_name, value in stats.items():
            if metric_name == 'Best Epoch':
                print(f"{metric_name}: {int(value)}")
            else:
                print(f"{metric_name}: {value:.4f}")

        return stats

    except Exception as e:
        logger.error(f"Error calculating statistics: {e}")
        raise


def main():
    try:
        setup_wandb()

        logger.info("Loading model and data...")
        model, combined_loader, device, optimizer, scheduler, scaler, config = load_model_and_data()

        logger.info("Collecting validation losses...")
        epoch_losses = collect_validation_losses(
            model, combined_loader, device, optimizer, scheduler, scaler, config
        )

        logger.info("Plotting validation losses...")
        plot_validation_losses(epoch_losses)

        logger.info("Calculating statistics...")
        calculate_loss_statistics(epoch_losses)

    except Exception as e:
        logger.error(f"Error in main: {e}")
        raise
    finally:
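        # Always release cached GPU memory and close the wandb run, even on
        # failure, so interrupted runs don't leave dangling resources.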
        torch.cuda.empty_cache()
        wandb.finish()


if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        logger.error(f"Script failed: {e}")
        raise