File size: 3,214 Bytes
7fcd26a 83a3d89 653a242 7fcd26a 83a3d89 7fcd26a 653a242 7fcd26a 653a242 7fcd26a 653a242 7fcd26a 653a242 7fcd26a 83a3d89 7fcd26a 83a3d89 7fcd26a 83a3d89 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
import torch
import torch.nn as nn
import torch.optim as optim
from resnet_model import ResNet50
from data_utils import get_train_transform, get_test_transform, get_data_loaders
from train_test import train, test
from utils import save_checkpoint, load_checkpoint, plot_training_curves, plot_misclassified_samples
from torchsummary import summary
from torch.optim.lr_scheduler import OneCycleLR
def main():
    """Train ResNet50, evaluate and checkpoint after every epoch, then plot the run.

    The function builds the model (wrapped in ``DataParallel`` and moved to
    the selected device), resumes from ``checkpoint.pth`` when that file
    exists, trains for ``num_epochs`` epochs, and finally plots the training
    curves and the misclassified samples reported by the last evaluation.
    """
    # Initialize model, loss function, and optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ResNet50()
    model = torch.nn.DataParallel(model)
    # Move the wrapped model to the selected device before it is used.
    model = model.to(device)
    summary(model, input_size=(3, 224, 224))
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9, weight_decay=1e-4)

    # Load data
    train_transform = get_train_transform()
    test_transform = get_test_transform()
    trainloader, testloader = get_data_loaders(train_transform, test_transform)

    # Load checkpoint if it exists
    checkpoint_path = "checkpoint.pth"
    try:
        # NOTE(review): start_epoch is assumed to be the last *completed*
        # epoch stored by save_checkpoint -- confirm against utils.load_checkpoint.
        model, optimizer, start_epoch, _ = load_checkpoint(model, optimizer, checkpoint_path)
    except FileNotFoundError:
        print("No checkpoint found, starting from scratch.")
        # No epoch has been completed yet, so the first epoch trained is 1.
        start_epoch = 0

    # Store results for plotting
    results = []
    learning_rates = []
    # Pre-bind so the plotting call below never sees unbound names.
    misclassified_images, misclassified_labels, misclassified_preds = [], [], []

    # Set One-Cycle LR scheduler
    num_epochs = 10
    steps_per_epoch = len(trainloader)
    lr_max = 1e-2
    # NOTE(review): the scheduler is neither stepped in this function nor
    # passed to train(); unless it is stepped elsewhere the schedule never
    # advances -- confirm where scheduler.step() is expected to be called.
    scheduler = OneCycleLR(optimizer, max_lr=lr_max, epochs=num_epochs, steps_per_epoch=steps_per_epoch)

    # Training loop: run exactly num_epochs epochs after the last completed one.
    first_epoch = start_epoch + 1
    for epoch in range(first_epoch, first_epoch + num_epochs):
        train_accuracy1, train_accuracy5, train_loss = train(
            model, device, trainloader, optimizer, criterion, epoch
        )
        (
            test_accuracy1,
            test_accuracy5,
            test_loss,
            misclassified_images,
            misclassified_labels,
            misclassified_preds,
        ) = test(model, device, testloader, criterion)
        print(f'Epoch {epoch} | Train Top-1 Acc: {train_accuracy1:.2f} | Test Top-1 Acc: {test_accuracy1:.2f}')

        # Append results for this epoch
        results.append((epoch, train_accuracy1, train_accuracy5, test_accuracy1, test_accuracy5, train_loss, test_loss))
        # Record the first parameter group's learning rate so the LR curve
        # has one point per epoch, aligned with the other series.
        learning_rates.append(optimizer.param_groups[0]["lr"])

        # Save checkpoint
        save_checkpoint(model, optimizer, epoch, test_loss, checkpoint_path)

    # Extract results for plotting
    epochs = [r[0] for r in results]
    train_acc1 = [r[1] for r in results]
    train_acc5 = [r[2] for r in results]
    test_acc1 = [r[3] for r in results]
    test_acc5 = [r[4] for r in results]
    train_losses = [r[5] for r in results]
    test_losses = [r[6] for r in results]

    # Plot training curves
    plot_training_curves(epochs, train_acc1, test_acc1, train_acc5, test_acc5, train_losses, test_losses, learning_rates)

    # Plot misclassified samples (only those from the final evaluation are kept)
    # TODO: the class list below is a placeholder -- replace with actual class names.
    plot_misclassified_samples(misclassified_images, misclassified_labels, misclassified_preds, classes=['class1', 'class2', ...])  # Replace with actual class names
if __name__ == '__main__':