|
|
|
""" |
|
Created on Fri Dec 20 09:32:12 2024 |
|
|
|
This script contains the LWM pre-training and task-specific fine-tuning functions. |
|
|
|
@author: Sadjad Alikhani |
|
""" |
|
import torch |
|
import torch.nn as nn |
|
from tqdm import tqdm |
|
import matplotlib.pyplot as plt |
|
import os |
|
import csv |
|
from utils import count_parameters |
|
import time |
|
|
|
def nmse_loss(y_pred, y_true):
    """Per-sample normalized mean squared error (NMSE).

    Each sample is flattened; the squared error is summed over features and
    divided by the squared norm (power) of the ground-truth sample.

    Args:
        y_pred: Predicted tensor of shape (batch, ...).
        y_true: Ground-truth tensor of the same shape.

    Returns:
        Tensor of shape (batch,) with one NMSE value per sample.
    """
    batch = y_pred.size(0)
    pred_2d = y_pred.view(batch, -1)
    true_2d = y_true.view(batch, -1)
    squared_error = ((true_2d - pred_2d) ** 2).sum(dim=-1)
    true_power = (true_2d ** 2).sum(dim=-1)
    return squared_error / true_power
|
|
|
def train_lwm(model, train_loaders, val_loaders, optimizer, scheduler, epochs, device, save_dir="models", log_file="training_log.csv"):
    """Pre-train the LWM model with a masked-token NMSE objective.

    Iterates over several train/val loaders (one per sequence length), logs
    per-epoch metrics to a CSV file, checkpoints the model whenever the
    validation NMSE improves, and plots the loss curves at the end.
    Validation (and hence checkpointing/logging) runs every 2nd epoch.

    Args:
        model: Called as ``model(input_ids, masked_pos)``; expected to return
            a 3-tuple whose first element is the masked-token predictions.
        train_loaders: Dict mapping sequence length -> training loader
            yielding ``(input_ids, masked_tokens, masked_pos)`` batches.
        val_loaders: Dict mapping sequence length -> validation loader.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: LR scheduler, stepped once per training batch.
        epochs: Number of training epochs.
        device: Device the batches are moved to.
        save_dir: Directory where checkpoints are written.
        log_file: CSV file appended with per-epoch metrics.

    Returns:
        The trained model.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Write the CSV header only when creating the log for the first time.
    if not os.path.exists(log_file):
        with open(log_file, mode='w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(["Epoch", "Train NMSE", "Validation NMSE", "Learning Rate", "Best Model"])

    train_nmse_losses = []
    val_nmse_losses = []
    best_val_nmse = float('inf')

    for epoch in range(epochs):
        model.train()
        train_nmse = 0.0
        train_samples = 0

        print(f"\nEpoch {epoch + 1}/{epochs} [Training]")
        for length, train_loader in train_loaders.items():
            print(f"Processing sequences of length {length}")
            with tqdm(train_loader, desc=f"Length {length} [Training]", unit="batch") as t:
                for batch in t:
                    optimizer.zero_grad()
                    input_ids, masked_tokens, masked_pos = [b.to(device) for b in batch]
                    logits_lm, _, _ = model(input_ids, masked_pos)
                    # BUGFIX: pass predictions first and targets second so the
                    # NMSE is normalized by the ground-truth power, not by the
                    # power of the model's own predictions.
                    loss = torch.sum(nmse_loss(logits_lm, masked_tokens))
                    loss.backward()
                    optimizer.step()
                    scheduler.step()  # per-batch LR schedule

                    train_nmse += loss.item()
                    train_samples += input_ids.shape[0]
                    # max(..., 1) guards against a division by zero on an
                    # empty first batch.
                    t.set_postfix({"nmse": train_nmse / max(train_samples, 1), "lr": scheduler.get_last_lr()[0]})

        train_nmse /= max(train_samples, 1)
        train_nmse_losses.append(train_nmse)

        # Validate (and checkpoint/log) only every other epoch.
        if epoch % 2 == 0:
            model.eval()
            val_nmse = 0.0
            val_samples = 0
            with torch.no_grad():
                print(f"\nEpoch {epoch + 1}/{epochs} [Validation]")
                for length, val_loader in val_loaders.items():
                    print(f"Processing sequences of length {length}")
                    with tqdm(val_loader, desc=f"Length {length} [Validation]", unit="batch") as t:
                        for batch in t:
                            input_ids, masked_tokens, masked_pos = [b.to(device) for b in batch]
                            logits_lm, _, _ = model(input_ids, masked_pos)
                            # Same prediction/target order as in training.
                            loss = torch.sum(nmse_loss(logits_lm, masked_tokens))
                            val_nmse += loss.item()
                            val_samples += input_ids.shape[0]
                            t.set_postfix({"nmse": val_nmse / max(val_samples, 1)})

            val_nmse /= max(val_samples, 1)
            val_nmse_losses.append(val_nmse)

            # Checkpoint whenever validation improves.
            is_best_model = False
            if val_nmse < best_val_nmse:
                best_val_nmse = val_nmse
                model_path = os.path.join(save_dir, f"lwm_epoch{epoch+1}_train{train_nmse:.4f}_val{val_nmse:.4f}.pth")
                torch.save(model.state_dict(), model_path)
                print(f"Model saved: {model_path}")
                is_best_model = True

            print(f" Train NMSE: {train_nmse:.4f}")
            print(f" Validation NMSE: {val_nmse:.4f}")
            print(f" Learning Rate: {scheduler.get_last_lr()[0]:.6e}")

            with open(log_file, mode='a', newline='') as file:
                writer = csv.writer(file)
                writer.writerow([epoch + 1, train_nmse, val_nmse, scheduler.get_last_lr()[0], is_best_model])

    # NOTE(review): the validation curve has ~half as many points as the
    # training curve (validation every 2nd epoch), so its x-axis is compressed.
    plt.figure(figsize=(10, 6))
    plt.plot(range(1, len(train_nmse_losses) + 1), train_nmse_losses, label="Train NMSE")
    plt.plot(range(1, len(val_nmse_losses) + 1), val_nmse_losses, label="Validation NMSE")
    plt.xlabel("Epochs")
    plt.ylabel("NMSE")
    plt.title("Training and Validation NMSE Loss")
    plt.legend()
    plt.grid(True)
    plt.show()

    print("Training and validation complete.")
    return model
|
|
|
from torch.cuda.amp import GradScaler, autocast |
|
|
|
|
|
class ClassificationHead(nn.Module):
    """Minimal head: a single linear map from features to class logits."""

    def __init__(self, input_dim, num_classes):
        super().__init__()
        # Attribute kept as `fc` so state_dict keys stay unchanged.
        self.fc = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        logits = self.fc(x)
        return logits
|
|
|
|
|
|
|
class RegressionHead(nn.Module):
    """Minimal head: a single linear map from features to one scalar output."""

    def __init__(self, input_dim):
        super().__init__()
        # Attribute kept as `fc` so state_dict keys stay unchanged.
        self.fc = nn.Linear(input_dim, 1)

    def forward(self, x):
        prediction = self.fc(x)
        return prediction
|
|
|
class CustomClassificationHead(nn.Module):
    """MLP classification head.

    Three hidden blocks (Linear -> BatchNorm1d -> ReLU, with Dropout(0.1)
    after the first two) followed by a linear output layer producing the
    class logits.
    """

    def __init__(self, input_dim, num_classes):
        super().__init__()
        hidden_widths = (512, 256, 128)
        layers = []
        prev_width = input_dim
        for idx, width in enumerate(hidden_widths):
            layers.append(nn.Linear(prev_width, width))
            layers.append(nn.BatchNorm1d(width))
            layers.append(nn.ReLU())
            if idx < len(hidden_widths) - 1:
                # No dropout after the last hidden block, matching the
                # original layout (and identical Sequential indices, so
                # state_dict keys are unchanged).
                layers.append(nn.Dropout(0.1))
            prev_width = width
        layers.append(nn.Linear(prev_width, num_classes))
        self.classifier = nn.Sequential(*layers)

    def forward(self, x):
        return self.classifier(x)
|
|
|
class CustomRegressionHead(nn.Module):
    """MLP regression head.

    Two hidden blocks (Linear -> BatchNorm1d -> ReLU -> Dropout(0.1))
    followed by a linear output layer of size ``output_dim``.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        blocks = []
        prev_width = input_dim
        for width in (512, 256):
            blocks.extend([
                nn.Linear(prev_width, width),
                nn.BatchNorm1d(width),
                nn.ReLU(),
                nn.Dropout(0.1),
            ])
            prev_width = width
        # Same Sequential indices as the original, so state_dict keys match.
        blocks.append(nn.Linear(prev_width, output_dim))
        self.regressor = nn.Sequential(*blocks)

    def forward(self, x):
        return self.regressor(x)
|
|
|
|
|
def custom_heads(input_dim, num_classes=None, output_dim=None, task_type="classification"):
    """
    Creates a custom head for classification or regression tasks.

    Users should modify the CustomClassificationHead / CustomRegressionHead
    class implementations for further customization.

    Args:
        input_dim (int): Input dimension of the head.
        num_classes (int): Number of classes for classification tasks. Ignored for regression.
        output_dim (int): Output dimension for regression tasks. Ignored for classification.
        task_type (str): "classification" or "regression".

    Returns:
        nn.Module: Custom head for the specified task.

    Raises:
        ValueError: If the dimension required by the chosen task is missing,
            or if ``task_type`` is not recognized.
    """
    if task_type == "classification":
        if num_classes is None:
            raise ValueError("num_classes must be specified for classification tasks.")
        return CustomClassificationHead(input_dim=input_dim, num_classes=num_classes)
    elif task_type == "regression":
        # BUGFIX: fail fast with a clear message instead of letting nn.Linear
        # raise a confusing TypeError on a None output dimension.
        if output_dim is None:
            raise ValueError("output_dim must be specified for regression tasks.")
        return CustomRegressionHead(input_dim=input_dim, output_dim=output_dim)
    else:
        raise ValueError("Invalid task_type. Choose 'classification' or 'regression'.")
|
|
|
|
|
class FineTuningWrapper(nn.Module):
    """Wraps a pre-trained backbone with a task head for fine-tuning.

    All backbone parameters are frozen first; then either every parameter
    (``fine_tune_layers="full"``) or only those whose names contain one of
    the substrings in ``fine_tune_layers`` are unfrozen.
    ``fine_tune_layers=None`` keeps the backbone fully frozen (head-only
    training).
    """

    def __init__(self, model, task_head, fine_tune_layers="full"):
        """
        Args:
            model: Pre-trained backbone; called as ``model(x)`` and expected
                to return ``(embeddings, attn_maps)``.
            task_head: Module applied to the flattened embedding features.
            fine_tune_layers: "full", None, or an iterable of parameter-name
                substrings selecting the backbone parameters to unfreeze.

        Raises:
            ValueError: If a requested layer substring matches no parameter.
        """
        super().__init__()
        self.model = model
        self.task_head = task_head

        # Freeze everything first; selectively unfreeze below.
        for param in self.model.parameters():
            param.requires_grad = False

        if fine_tune_layers is not None:
            if fine_tune_layers == "full":
                for param in self.model.parameters():
                    param.requires_grad = True
            else:
                available_layers = [name for name, _ in self.model.named_parameters()]
                # Validate every requested substring before unfreezing anything.
                for layer in fine_tune_layers:
                    if not any(layer in lname for lname in available_layers):
                        raise ValueError(
                            f"Layer '{layer}' not found in the model. "
                            f"Available layers: {available_layers}"
                        )
                for name, param in self.model.named_parameters():
                    if any(layer in name for layer in fine_tune_layers):
                        param.requires_grad = True

    def forward(self, x, input_type="cls_emb"):
        """Run the backbone (skipped for raw input) and the task head.

        Args:
            x: Input batch; flattened per sample when ``input_type="raw"``.
            input_type: "raw", "cls_emb" (first embedding position), or
                "channel_emb" (all remaining positions, flattened).

        Returns:
            Tuple ``(head_output, attn_maps)``; ``attn_maps`` is 0 for raw
            input since the backbone is not invoked.
        """
        if input_type == "raw":
            task_input = x.view(x.size(0), -1)
        elif input_type in ("cls_emb", "channel_emb"):
            embeddings, attn_maps = self.model(x)
            if input_type == "cls_emb":
                task_input = embeddings[:, 0, :]
            else:
                chs_emb = embeddings[:, 1:, :]
                task_input = chs_emb.view(chs_emb.size(0), -1)
        else:
            # BUGFIX: an unknown input_type previously fell through to an
            # UnboundLocalError on task_input; fail with a clear message.
            raise ValueError("input_type must be 'raw', 'cls_emb', or 'channel_emb'.")

        return self.task_head(task_input), 0 if input_type == "raw" else attn_maps
|
|
|
|
|
from sklearn.metrics import f1_score |
|
def finetune(
    base_model,
    train_loader,
    val_loader=None,
    task_type="classification",
    input_type="cls_emb",
    num_classes=None,
    output_dim=None,
    use_custom_head=False,
    fine_tune_layers=None,
    optimizer_config=None,
    criterion=None,
    epochs=10,
    device="cuda",
    task="Beam Prediction"
):
    """
    Configures and fine-tunes the base model with user-defined settings, saving results and models.

    Args:
        base_model: Pre-trained backbone to wrap with a task head.
        train_loader: Loader yielding (input, target) training batches.
        val_loader: Optional loader for validation; when given, the model with
            the lowest validation loss is checkpointed.
        task_type: "classification" or "regression".
        input_type: "cls_emb", "channel_emb", or "raw".
        num_classes: Number of classes (classification only).
        output_dim: Output dimension (custom regression head only).
        use_custom_head: Use the deeper custom heads instead of a single linear layer.
        fine_tune_layers: Passed to FineTuningWrapper ("full", None, or name substrings).
        optimizer_config: kwargs for torch.optim.Adam (default: {"lr": 1e-4}).
        criterion: Loss function; defaults to CrossEntropyLoss / MSELoss by task.
        epochs: Number of fine-tuning epochs.
        device: Torch device string.
        task: Task name used for the results folder path.

    Returns:
        Tuple of (wrapper, best_model_path, train_losses, val_losses,
        f1_scores if classification else 0, attn_maps).

    Raises:
        ValueError: On an unrecognized input_type or task_type, or a missing
            num_classes for classification.
    """
    time_now = f"{time.time():.0f}"
    results_folder = f"results/{task}/{time_now}"
    os.makedirs(results_folder, exist_ok=True)
    log_file = os.path.join(results_folder, "training_log.csv")

    with open(log_file, mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(["Task", "Input", "Epoch", "Train Loss", "Validation Loss", "F1-Score (Classification)", "Learning Rate", "Time"])

    # Peek one batch to size the task head. BUGFIX: fall back to the train
    # loader so a missing val_loader (its default) no longer crashes here.
    peek_loader = val_loader if val_loader is not None else train_loader
    for batch in peek_loader:
        input_data, targets = batch[0].to(device), batch[1].to(device)
        break

    # NOTE(review): patch sizes 128/32 look model-specific — confirm against
    # the LWM embedding dimension and raw patch layout.
    if input_type == "cls_emb":
        n_patches = 1
        patch_size = 128
    elif input_type == "channel_emb":
        n_patches = input_data.shape[1] - 1  # drop the CLS position
        patch_size = 128
    elif input_type == "raw":
        n_patches = input_data.shape[1]
        patch_size = 32
    else:
        # BUGFIX: previously an unknown input_type left n_patches undefined.
        raise ValueError("input_type must be 'cls_emb', 'channel_emb', or 'raw'.")

    if use_custom_head:
        custom_head = custom_heads(input_dim=n_patches*patch_size,
                                   num_classes=num_classes,
                                   output_dim=output_dim,
                                   task_type=task_type)

    # Unwrap DataParallel so parameter names match fine_tune_layers.
    if isinstance(base_model, nn.DataParallel):
        base_model = base_model.module

    if use_custom_head:
        task_head = custom_head
    elif task_type == "classification":
        if num_classes is None:
            raise ValueError("num_classes must be specified for classification tasks.")
        task_head = ClassificationHead(input_dim=n_patches*patch_size, num_classes=num_classes)
    elif task_type == "regression":
        task_head = RegressionHead(input_dim=n_patches*patch_size)
    else:
        raise ValueError("Invalid task_type. Choose 'classification' or 'regression'.")

    wrapper = FineTuningWrapper(base_model, task_head, fine_tune_layers=fine_tune_layers)
    wrapper = wrapper.to(device)

    print(f'Number of head parameters: {count_parameters(wrapper)}')

    if optimizer_config is None:
        optimizer_config = {"lr": 1e-4}

    optimizer = torch.optim.Adam(wrapper.parameters(), **optimizer_config)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.8)

    if criterion is None:
        criterion = nn.CrossEntropyLoss() if task_type == "classification" else nn.MSELoss()

    scaler = GradScaler()
    train_losses, val_losses, f1_scores = [], [], []
    best_val_loss = float("inf")
    best_model_path = None
    avg_val_loss = None  # BUGFIX: stays None (logged as "-") without a val_loader
    f1 = None
    attn_maps = None  # BUGFIX: defined even if epochs == 0

    for epoch in range(epochs):
        wrapper.train()
        epoch_loss = 0.0

        with tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}") as progress_bar:
            for batch in progress_bar:
                input_data, targets = batch[0].to(device), batch[1].to(device)
                optimizer.zero_grad()

                # Mixed-precision forward pass.
                with autocast():
                    outputs, attn_maps = wrapper(input_data, input_type=input_type)
                    loss = criterion(outputs, targets)

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                epoch_loss += loss.item()
                progress_bar.set_postfix({"Loss": loss.item()})

        avg_train_loss = epoch_loss / len(train_loader)
        train_losses.append(avg_train_loss)

        f1 = None
        if val_loader:
            wrapper.eval()
            val_loss = 0.0
            all_preds, all_targets = [], []

            with torch.no_grad():
                for batch in val_loader:
                    input_data, targets = batch[0].to(device), batch[1].to(device)
                    with autocast():
                        outputs, _ = wrapper(input_data, input_type=input_type)
                        loss = criterion(outputs, targets)

                    val_loss += loss.item()

                    if task_type == "classification":
                        preds = torch.argmax(outputs, dim=1).cpu().numpy()
                        all_preds.extend(preds)
                        all_targets.extend(targets.cpu().numpy())

            avg_val_loss = val_loss / len(val_loader)
            val_losses.append(avg_val_loss)

            time_now = f"{time.time():.0f}"

            # Checkpoint whenever validation improves.
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                best_model_path = os.path.join(results_folder, f"{input_type}_epoch{epoch+1}_valLoss{avg_val_loss:.4f}_{time_now}.pth")
                torch.save(wrapper.state_dict(), best_model_path)
                print(f"Model saved at {best_model_path} with validation loss: {best_val_loss:.4f}")

            if task_type == "classification":
                f1 = f1_score(all_targets, all_preds, average="weighted")
                print(f"Epoch {epoch + 1}, Validation F1-Score: {f1:.4f}")
                f1_scores.append(f1)

        scheduler.step()

        # BUGFIX: log "-" placeholders so the CSV write no longer raises a
        # NameError when no val_loader was supplied.
        with open(log_file, mode='a', newline='') as file:
            writer = csv.writer(file)
            writer.writerow([task, input_type, epoch + 1, avg_train_loss,
                             avg_val_loss if avg_val_loss is not None else "-",
                             f1 if f1 is not None else "-",
                             scheduler.get_last_lr()[0], f"{time_now}"])

    plt.figure(figsize=(10, 6))
    plt.plot(range(1, epochs + 1), train_losses, label="Training Loss")
    if val_losses:
        # BUGFIX: use the actual number of validation points (and skip the
        # curve entirely without a val_loader) to avoid a length mismatch.
        plt.plot(range(1, len(val_losses) + 1), val_losses, label="Validation Loss", linestyle="--")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.title("Training and Validation Loss")
    plt.legend()
    plt.grid(True)

    plt.show()

    return wrapper, best_model_path, train_losses, val_losses, f1_scores if task_type == "classification" else 0, attn_maps