import torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid


def save_checkpoint(model, optimizer, epoch, loss, path):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, path)
    print(f"Checkpoint saved at epoch {epoch}")


def load_checkpoint(model, optimizer, checkpoint_path):
    # Use map_location to load the checkpoint on CPU if CUDA is not available
    map_location = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    return model, optimizer, start_epoch, loss
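
# Example round trip (a minimal sketch; `net`, `opt`, and 'checkpoint.pth' are
# hypothetical names, not part of this module):
#
#   import torch.nn as nn
#   import torch.optim as optim
#
#   net = nn.Linear(10, 2)
#   opt = optim.SGD(net.parameters(), lr=0.1)
#   save_checkpoint(net, opt, epoch=5, loss=0.42, path='checkpoint.pth')
#   net, opt, start_epoch, loss = load_checkpoint(net, opt, 'checkpoint.pth')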

def plot_training_curves(epochs, train_acc1, test_acc1, train_acc5, test_acc5, train_losses, test_losses, learning_rates):
    plt.figure(figsize=(12, 8))

    plt.subplot(2, 2, 1)
    plt.plot(epochs, train_acc1, label='Train Top-1 Acc')
    plt.plot(epochs, test_acc1, label='Test Top-1 Acc')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.title('Top-1 Accuracy')

    plt.subplot(2, 2, 2)
    plt.plot(epochs, train_acc5, label='Train Top-5 Acc')
    plt.plot(epochs, test_acc5, label='Test Top-5 Acc')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.title('Top-5 Accuracy')

    plt.subplot(2, 2, 3)
    plt.plot(epochs, train_losses, label='Train Loss')
    plt.plot(epochs, test_losses, label='Test Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Loss')

    plt.subplot(2, 2, 4)
    plt.plot(epochs, learning_rates, label='Learning Rate')
    plt.xlabel('Epoch')
    plt.ylabel('Learning Rate')
    plt.legend()
    plt.title('Learning Rate')

    plt.tight_layout()
    plt.show()
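
# Example usage (a minimal sketch; the eight history lists are assumed to be
# collected by the caller's training loop, one entry per epoch, equal lengths):
#
#   epochs = list(range(1, len(train_losses) + 1))
#   plot_training_curves(epochs, train_acc1, test_acc1, train_acc5, test_acc5,
#                        train_losses, test_losses, learning_rates)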

def plot_misclassified_samples(misclassified_images, misclassified_labels, misclassified_preds, classes):
    if misclassified_images:
        print("\nDisplaying some misclassified samples:")
        # Report the predicted vs. true class for each displayed sample
        for pred, label in zip(misclassified_preds[:16], misclassified_labels[:16]):
            print(f"Predicted: {classes[int(pred)]}, Actual: {classes[int(label)]}")
        misclassified_grid = make_grid(misclassified_images[:16], nrow=4, normalize=True, scale_each=True)
        plt.figure(figsize=(8, 8))
        plt.imshow(misclassified_grid.permute(1, 2, 0))
        plt.title("Misclassified Samples")
        plt.axis('off')
        plt.show()
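
# Building the inputs for plot_misclassified_samples (a minimal sketch; `model`,
# `test_loader`, `device`, and `classes` are assumed to exist in the caller):
#
#   misclassified_images, misclassified_labels, misclassified_preds = [], [], []
#   model.eval()
#   with torch.no_grad():
#       for inputs, labels in test_loader:
#           inputs, labels = inputs.to(device), labels.to(device)
#           preds = model(inputs).argmax(dim=1)
#           mask = preds != labels
#           misclassified_images.extend(inputs[mask].cpu())
#           misclassified_labels.extend(labels[mask].cpu().tolist())
#           misclassified_preds.extend(preds[mask].cpu().tolist())
#   plot_misclassified_samples(misclassified_images, misclassified_labels,
#                              misclassified_preds, classes)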

def find_lr(model, criterion, optimizer, train_loader, num_epochs=1, start_lr=1e-7, end_lr=10, lr_multiplier=1.1):
    """
    Find a good learning rate with an LR range test (LR Finder).

    Args:
    - model: The model to train
    - criterion: Loss function (e.g., CrossEntropyLoss)
    - optimizer: Optimizer (e.g., SGD)
    - train_loader: DataLoader for training data
    - num_epochs: Number of epochs to run the LR Finder (typically 1-2)
    - start_lr: Starting learning rate for the sweep
    - end_lr: Maximum learning rate; the sweep stops once it is exceeded
    - lr_multiplier: Factor by which the learning rate is increased every batch

    Displays a plot of loss vs. learning rate.
    """
    device = next(model.parameters()).device  # train on whatever device the model lives on
    lrs = []
    losses = []
    lr = start_lr
    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0
        batch_count = 0
        for inputs, labels in train_loader:
            if lr > end_lr:  # stop the sweep once the maximum learning rate is exceeded
                break
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.param_groups[0]['lr'] = lr  # set the learning rate for this batch
            # Forward and backward pass
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            batch_count += 1
            lrs.append(lr)
            losses.append(loss.item())
            # Increase the learning rate for the next batch
            lr *= lr_multiplier
        if batch_count > 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Avg Loss: {epoch_loss / batch_count:.4f}")
        if lr > end_lr:
            break
    # Plot loss vs. learning rate on a log scale
    plt.plot(lrs, losses)
    plt.xscale('log')
    plt.xlabel('Learning Rate')
    plt.ylabel('Loss')
    plt.title('Learning Rate Finder')
    plt.show()
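
# Example LR range test (a minimal sketch; `model` and `train_loader` are assumed
# to exist — sweep a throwaway copy, since the rising learning rate wrecks the weights):
#
#   import copy
#   import torch.nn as nn
#   import torch.optim as optim
#
#   probe = copy.deepcopy(model)
#   criterion = nn.CrossEntropyLoss()
#   optimizer = optim.SGD(probe.parameters(), lr=1e-7)
#   find_lr(probe, criterion, optimizer, train_loader,
#           start_lr=1e-7, end_lr=10, lr_multiplier=1.1)
#   # A common heuristic: pick a rate about an order of magnitude below
#   # the point where the loss is lowest (just before it blows up).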