# -*- coding: utf-8 -*-
"""
Created on Fri Dec 20 09:32:12 2024
This script contains the LWM pre-training and task-specific fine-tuning functions.
@author: Sadjad Alikhani
"""
import torch
import torch.nn as nn
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
import csv
from utils import count_parameters
import time
#%% LOSS FUNCTION
def nmse_loss(y_pred, y_true):
    """Per-sample normalized MSE: ||y_true - y_pred||^2 / ||y_true||^2, flattened over all non-batch dims."""
    y_pred_flat = y_pred.view(y_pred.size(0), -1)
    y_true_flat = y_true.view(y_true.size(0), -1)
    mse = torch.sum((y_true_flat - y_pred_flat)**2, dim=-1)
    normalization = torch.sum(y_true_flat**2, dim=-1)
    return mse / normalization
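# Quick sanity check (illustrative only, not part of training; the shapes below are arbitrary):
#   y_true = torch.randn(4, 16, 32)
#   y_pred = y_true + 0.1 * torch.randn_like(y_true)
#   per_sample = nmse_loss(y_pred, y_true)   # shape (4,), one NMSE value per sample
#   batch_nmse = per_sample.mean()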
#%%
def train_lwm(model, train_loaders, val_loaders, optimizer, scheduler, epochs, device, save_dir="models", log_file="training_log.csv"):
    os.makedirs(save_dir, exist_ok=True)
# Initialize CSV log
if not os.path.exists(log_file):
with open(log_file, mode='w', newline='') as file:
writer = csv.writer(file)
writer.writerow(["Epoch", "Train NMSE", "Validation NMSE", "Learning Rate", "Best Model"])
train_nmse_losses = []
val_nmse_losses = []
best_val_nmse = float('inf')
for epoch in range(epochs):
model.train()
train_nmse = 0.0
train_samples = 0
# Training loop across all buckets
print(f"\nEpoch {epoch + 1}/{epochs} [Training]")
for length, train_loader in train_loaders.items():
print(f"Processing sequences of length {length}")
with tqdm(train_loader, desc=f"Length {length} [Training]", unit="batch") as t:
for batch in t:
# train_batches += 1
optimizer.zero_grad()
# Move data to device
input_ids, masked_tokens, masked_pos = [b.to(device) for b in batch]
# Forward pass
logits_lm, _, _ = model(input_ids, masked_pos)
# Compute NMSE
                    loss = torch.sum(nmse_loss(logits_lm, masked_tokens))  # prediction first, ground truth second
loss.backward()
optimizer.step()
scheduler.step()
train_nmse += loss.item()
train_samples += input_ids.shape[0]
# Update progress bar
t.set_postfix({"nmse": train_nmse/train_samples, "lr": scheduler.get_last_lr()[0]})
        # Average per-sample NMSE over the training set
        train_nmse /= max(train_samples, 1)
train_nmse_losses.append(train_nmse)
if epoch % 2 == 0:
# Validation loop across all buckets
model.eval()
val_nmse = 0.0
val_samples = 0
with torch.no_grad():
print(f"\nEpoch {epoch + 1}/{epochs} [Validation]")
for length, val_loader in val_loaders.items():
print(f"Processing sequences of length {length}")
with tqdm(val_loader, desc=f"Length {length} [Validation]", unit="batch") as t:
for batch in t:
# val_batches += 1
# Move data to device
input_ids, masked_tokens, masked_pos = [b.to(device) for b in batch]
# Forward pass
logits_lm, _, _ = model(input_ids, masked_pos)
# Compute NMSE
                            loss = torch.sum(nmse_loss(logits_lm, masked_tokens))  # prediction first, ground truth second
val_nmse += loss.item()
val_samples += input_ids.shape[0]
# Update progress bar
t.set_postfix({"nmse": val_nmse/val_samples})
# Average NMSE across validation batches
val_nmse /= max(val_samples, 1)
val_nmse_losses.append(val_nmse)
# Save model if validation NMSE improves
is_best_model = False
if val_nmse < best_val_nmse:
best_val_nmse = val_nmse
model_path = os.path.join(save_dir, f"lwm_epoch{epoch+1}_train{train_nmse:.4f}_val{val_nmse:.4f}.pth")
torch.save(model.state_dict(), model_path)
print(f"Model saved: {model_path}")
is_best_model = True
# Log the results
print(f" Train NMSE: {train_nmse:.4f}")
print(f" Validation NMSE: {val_nmse:.4f}")
print(f" Learning Rate: {scheduler.get_last_lr()[0]:.6e}")
# Append to CSV log
with open(log_file, mode='a', newline='') as file:
writer = csv.writer(file)
writer.writerow([epoch + 1, train_nmse, val_nmse, scheduler.get_last_lr()[0], is_best_model])
# Plot losses after each epoch
plt.figure(figsize=(10, 6))
plt.plot(range(1, len(train_nmse_losses) + 1), train_nmse_losses, label="Train NMSE")
plt.plot(range(1, len(val_nmse_losses) + 1), val_nmse_losses, label="Validation NMSE")
plt.xlabel("Epochs")
plt.ylabel("NMSE")
plt.title("Training and Validation NMSE Loss")
plt.legend()
plt.grid(True)
plt.show()
print("Training and validation complete.")
return model
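# Example usage of train_lwm (an illustrative sketch, not a definitive recipe; `lwm`, the bucketed
# loaders, and the hyper-parameters below are assumptions, not objects defined in this file).
# `train_loaders`/`val_loaders` are dicts mapping sequence length -> DataLoader yielding
# (input_ids, masked_tokens, masked_pos) batches, as unpacked in the loops above.
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   model = lwm().to(device)
#   optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
#   scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)
#   model = train_lwm(model, train_loaders, val_loaders, optimizer, scheduler,
#                     epochs=50, device=device)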
#%% FINE-TUNE
from torch.cuda.amp import GradScaler, autocast
# Define the ClassificationHead
class ClassificationHead(nn.Module):
def __init__(self, input_dim, num_classes):
super().__init__()
self.fc = nn.Linear(input_dim, num_classes)
def forward(self, x):
return self.fc(x)
# Define the RegressionHead
class RegressionHead(nn.Module):
def __init__(self, input_dim):
super().__init__()
self.fc = nn.Linear(input_dim, 1)
def forward(self, x):
return self.fc(x)
class CustomClassificationHead(nn.Module):
def __init__(self, input_dim, num_classes):
super().__init__()
self.classifier = nn.Sequential(
nn.Linear(input_dim, 512),
nn.BatchNorm1d(512),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(512, 256),
nn.BatchNorm1d(256),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(256, 128),
nn.BatchNorm1d(128),
nn.ReLU(),
# nn.Dropout(0.1),
nn.Linear(128, num_classes)
)
def forward(self, x):
return self.classifier(x)
class CustomRegressionHead(nn.Module):
def __init__(self, input_dim, output_dim):
super().__init__()
self.regressor = nn.Sequential(
nn.Linear(input_dim, 512),
nn.BatchNorm1d(512),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(512, 256),
nn.BatchNorm1d(256),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(256, output_dim)
)
def forward(self, x):
return self.regressor(x)
def custom_heads(input_dim, num_classes=None, output_dim=None, task_type="classification"):
"""
Creates a custom head for classification or regression tasks.
Users should modify the class implementations for further customization.
Args:
input_dim (int): Input dimension of the head.
        num_classes (int): Number of classes for classification tasks. Ignored for regression.
        output_dim (int): Output dimension for regression tasks. Ignored for classification.
task_type (str): "classification" or "regression".
Returns:
nn.Module: Custom head for the specified task.
"""
if task_type == "classification":
if num_classes is None:
raise ValueError("num_classes must be specified for classification tasks.")
return CustomClassificationHead(input_dim=input_dim, num_classes=num_classes)
    elif task_type == "regression":
        if output_dim is None:
            raise ValueError("output_dim must be specified for regression tasks.")
        return CustomRegressionHead(input_dim=input_dim, output_dim=output_dim)
else:
raise ValueError("Invalid task_type. Choose 'classification' or 'regression'.")
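# Example usage of custom_heads (illustrative; the dimensions below are placeholders, not values used elsewhere in this file):
#   cls_head = custom_heads(input_dim=128, num_classes=64, task_type="classification")
#   reg_head = custom_heads(input_dim=128, output_dim=2, task_type="regression")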
#%%
# Fine-tuning wrapper for the base model
class FineTuningWrapper(nn.Module):
def __init__(self, model, task_head, fine_tune_layers="full"):
super().__init__()
self.model = model
self.task_head = task_head
# Freeze all layers initially
for param in self.model.parameters():
param.requires_grad = False
# Handle fine-tuning layers
if fine_tune_layers is not None:
            if fine_tune_layers == "full":
                # Unfreeze all layers when "full" is specified
for param in self.model.parameters():
param.requires_grad = True
else:
# Get a list of all available layer names in the model
available_layers = [name for name, _ in self.model.named_parameters()]
# Validate that specified layers exist in the model
for layer in fine_tune_layers:
if not any(layer in lname for lname in available_layers):
raise ValueError(
f"Layer '{layer}' not found in the model. "
f"Available layers: {available_layers}"
)
# Unfreeze only the specified layers
for name, param in self.model.named_parameters():
if any(layer in name for layer in fine_tune_layers):
param.requires_grad = True
def forward(self, x, input_type="cls_emb"):
if input_type == "raw":
task_input = x.view(x.size(0), -1)
else:
embeddings, attn_maps = self.model(x) # Get embeddings from the base model
if input_type == "cls_emb":
task_input = embeddings[:, 0, :] # CLS token
elif input_type == "channel_emb":
chs_emb = embeddings[:, 1:, :]
                task_input = chs_emb.view(chs_emb.size(0), -1)  # flatten the channel embeddings (alternative: chs_emb.mean(dim=1) for mean pooling)
return self.task_head(task_input), 0 if input_type=="raw" else attn_maps
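# Example usage of FineTuningWrapper (illustrative; `lwm_model` stands for a pre-trained LWM backbone
# and `input_ids` for a batch of tokenized channels, neither of which is defined in this file):
#   head = ClassificationHead(input_dim=128, num_classes=64)
#   wrapper = FineTuningWrapper(lwm_model, head, fine_tune_layers="full")
#   logits, attn_maps = wrapper(input_ids, input_type="cls_emb")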
#%%
# Fine-tuning function
from sklearn.metrics import f1_score
def finetune(
base_model,
train_loader,
val_loader=None,
task_type="classification",
input_type="cls_emb",
num_classes=None,
output_dim=None,
use_custom_head=False,
fine_tune_layers=None,
optimizer_config=None,
criterion=None,
epochs=10,
device="cuda",
task="Beam Prediction"
):
"""
Configures and fine-tunes the base model with user-defined settings, saving results and models.
"""
# Create results folder
time_now = f"{time.time():.0f}"
results_folder = f"results/{task}/{time_now}"
os.makedirs(results_folder, exist_ok=True)
log_file = os.path.join(results_folder, "training_log.csv")
# Initialize the CSV log
with open(log_file, mode='w', newline='') as file:
writer = csv.writer(file)
writer.writerow(["Task", "Input", "Epoch", "Train Loss", "Validation Loss", "F1-Score (Classification)", "Learning Rate", "Time"])
    # Grab one batch to infer input dimensions (fall back to the train loader if no validation loader is given)
    sample_loader = val_loader if val_loader is not None else train_loader
    input_data, targets = [t.to(device) for t in next(iter(sample_loader))[:2]]
if input_type == "cls_emb":
n_patches = 1
patch_size = 128
elif input_type == "channel_emb":
n_patches = input_data.shape[1]-1
patch_size = 128
elif input_type == "raw":
n_patches = input_data.shape[1]
patch_size = 32
# patch_size = 1
if use_custom_head:
custom_head = custom_heads(input_dim=n_patches*patch_size,
num_classes=num_classes,
output_dim=output_dim,
task_type=task_type)
# Handle DataParallel models
if isinstance(base_model, nn.DataParallel):
base_model = base_model.module
# Set up the task-specific head
if use_custom_head:
task_head = custom_head
elif task_type == "classification":
if num_classes is None:
raise ValueError("num_classes must be specified for classification tasks.")
task_head = ClassificationHead(input_dim=n_patches*patch_size, num_classes=num_classes) # input_dim=base_model.embedding.d_model
elif task_type == "regression":
task_head = RegressionHead(input_dim=n_patches*patch_size) # input_dim=base_model.embedding.d_model
else:
raise ValueError("Invalid task_type. Choose 'classification' or 'regression'.")
# Wrap the model with the fine-tuning head
wrapper = FineTuningWrapper(base_model, task_head, fine_tune_layers=fine_tune_layers)
wrapper = wrapper.to(device)
    print(f'Number of trainable parameters: {count_parameters(wrapper)}')
# Set default optimizer config if not provided
if optimizer_config is None:
optimizer_config = {"lr": 1e-4}
# Set up the optimizer
optimizer = torch.optim.Adam(wrapper.parameters(), **optimizer_config)
# Set up the scheduler for learning rate decay
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.8)  # multiply the LR by 0.8 every 20 epochs
# Set up the loss criterion
if criterion is None:
criterion = nn.CrossEntropyLoss() if task_type == "classification" else nn.MSELoss()
scaler = GradScaler()
train_losses, val_losses, f1_scores = [], [], []
best_val_loss = float("inf")
best_model_path = None
for epoch in range(epochs):
# Training loop
wrapper.train()
epoch_loss = 0.0
with tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}") as progress_bar:
for batch in progress_bar:
input_data, targets = batch[0].to(device), batch[1].to(device)
optimizer.zero_grad()
with autocast():
outputs, attn_maps = wrapper(input_data, input_type=input_type)
loss = criterion(outputs, targets)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
epoch_loss += loss.item()
progress_bar.set_postfix({"Loss": loss.item()})
avg_train_loss = epoch_loss / len(train_loader)
train_losses.append(avg_train_loss)
# Validation loop
if val_loader:
wrapper.eval()
val_loss = 0.0
all_preds, all_targets = [], []
with torch.no_grad():
for batch in val_loader:
input_data, targets = batch[0].to(device), batch[1].to(device)
with autocast():
outputs, _ = wrapper(input_data, input_type=input_type)
loss = criterion(outputs, targets)
val_loss += loss.item()
if task_type == "classification":
preds = torch.argmax(outputs, dim=1).cpu().numpy()
all_preds.extend(preds)
all_targets.extend(targets.cpu().numpy())
avg_val_loss = val_loss / len(val_loader)
val_losses.append(avg_val_loss)
time_now = f"{time.time():.0f}"
# Save the best model
if avg_val_loss < best_val_loss:
best_val_loss = avg_val_loss
best_model_path = os.path.join(results_folder, f"{input_type}_epoch{epoch+1}_valLoss{avg_val_loss:.4f}_{time_now}.pth")
torch.save(wrapper.state_dict(), best_model_path)
print(f"Model saved at {best_model_path} with validation loss: {best_val_loss:.4f}")
# Compute F1-score for classification tasks
f1 = None
if task_type == "classification":
f1 = f1_score(all_targets, all_preds, average="weighted")
print(f"Epoch {epoch + 1}, Validation F1-Score: {f1:.4f}")
f1_scores.append(f1)
scheduler.step()
# Log results
with open(log_file, mode='a', newline='') as file:
writer = csv.writer(file)
writer.writerow([task, input_type, epoch + 1, avg_train_loss, avg_val_loss, f1 if f1 is not None else "-", scheduler.get_last_lr()[0], f"{time_now}"])
# Plot training and validation losses
    plt.figure(figsize=(10, 6))
    plt.plot(range(1, len(train_losses) + 1), train_losses, label="Training Loss")
    if val_losses:
        plt.plot(range(1, len(val_losses) + 1), val_losses, label="Validation Loss", linestyle="--")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Training and Validation Loss")
plt.legend()
plt.grid(True)
# plt.savefig(os.path.join(results_folder, "loss_curve.png"))
plt.show()
    return wrapper, best_model_path, train_losses, val_losses, (f1_scores if task_type == "classification" else 0), attn_maps
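# Example usage of finetune (an illustrative sketch; `lwm_model` and the loaders are assumptions,
# not objects defined in this file):
#   wrapper, best_path, train_losses, val_losses, f1_scores, attn_maps = finetune(
#       base_model=lwm_model,
#       train_loader=train_loader,
#       val_loader=val_loader,
#       task_type="classification",
#       input_type="cls_emb",
#       num_classes=64,
#       fine_tune_layers=None,   # head-only training; pass "full" to also unfreeze the backbone
#       epochs=10,
#       device="cuda",
#       task="Beam Prediction",
#   )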