import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from resnet_model import ResNet50
from tqdm import tqdm
from torchvision import datasets
from checkpoint import save_checkpoint, load_checkpoint
# Preprocessing shared by the training and validation datasets.
# NOTE(review): the 256/224 sizes follow the common ImageNet evaluation recipe;
# confirm they match the input size ResNet50 (from resnet_model) expects, and
# consider a separate augmented transform for the training split.
transform = transforms.Compose([
    # ImageFolder yields PIL images of varying size; resize and crop them to a
    # fixed shape so the DataLoader can stack them into a batch.
    transforms.Resize(256),
    transforms.CenterCrop(224),
    # Normalize only accepts tensors, so convert from PIL first.
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Datasets and loaders: shuffled mini-batches for training, large fixed-order
# batches for evaluation.
trainset = datasets.ImageFolder(root='/mnt/imagenet/ILSVRC/Data/CLS-LOC/train', transform=transform)
trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=16, pin_memory=True)
testset = datasets.ImageFolder(root='/mnt/imagenet/ILSVRC/Data/CLS-LOC/val', transform=transform)
testloader = DataLoader(testset, batch_size=1000, shuffle=False, num_workers=16, pin_memory=True)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Wrap the network in DataParallel before moving it to the target device.
model = ResNet50()
model = torch.nn.DataParallel(model)
model = model.to(device)

# Loss and optimizer used by the training loop.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
from torch.amp import autocast
from tqdm import tqdm
def train(model, device, train_loader, optimizer, criterion, epoch, accumulation_steps=4):
    """Run one training epoch with autocast and gradient accumulation.

    Args:
        model: Network to optimise; it is switched to training mode here.
        device: Device (``torch.device`` or device string) batches are moved to.
        train_loader: Sized iterable yielding ``(inputs, targets)`` batches.
        optimizer: Optimizer stepped once every ``accumulation_steps`` batches
            and once more on the final batch of the epoch.
        criterion: Loss callable applied to ``(outputs, targets)``.
        epoch: Epoch number, used only in the progress-bar text.
        accumulation_steps: Number of batches whose gradients are accumulated
            before each optimizer step.

    Returns:
        Training accuracy over the epoch as a percentage, or ``0.0`` when the
        loader yields no samples.
    """
    model.train()

    # Autocast is only enabled on CUDA; on any other device the forward pass
    # runs in full precision.
    use_amp = torch.device(device).type == 'cuda'
    num_batches = len(train_loader)

    running_loss = 0.0
    correct = 0
    total = 0

    # Start the epoch with clean gradients so nothing leaks in from earlier work.
    optimizer.zero_grad()

    pbar = tqdm(train_loader)
    for batch_idx, (inputs, targets) in enumerate(pbar):
        inputs, targets = inputs.to(device), targets.to(device)

        with autocast(device_type='cuda', enabled=use_amp):
            outputs = model(inputs)
            # Scale the loss so the accumulated gradient matches the average
            # over ``accumulation_steps`` batches.
            loss = criterion(outputs, targets) / accumulation_steps

        # NOTE(review): no GradScaler is used; with float16 autocast small
        # gradients can underflow -- consider adding torch.amp.GradScaler.
        loss.backward()

        # Apply the accumulated gradients every ``accumulation_steps`` batches,
        # and on the last batch so a partial group is not dropped.
        if (batch_idx + 1) % accumulation_steps == 0 or (batch_idx + 1) == num_batches:
            optimizer.step()
            optimizer.zero_grad()

        # Undo the accumulation scaling so the displayed loss is per batch.
        running_loss += loss.item() * accumulation_steps
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        pbar.set_description(desc=f'Epoch {epoch} | Loss: {running_loss / (batch_idx + 1):.4f} | Accuracy: {100. * correct / total:.2f}%')

    # Return only after the whole loader has been consumed.
    return 100. * correct / total if total else 0.0
def test(model, device, test_loader, criterion):
    """Evaluate ``model`` on ``test_loader`` and print the loss and accuracy.

    The model is put in evaluation mode for the duration of the call and its
    previous training/eval mode is restored afterwards.

    Args:
        model: Network to evaluate.
        device: Device the batches are moved to.
        test_loader: Sized iterable yielding ``(inputs, targets)`` batches.
        criterion: Loss callable applied to ``(outputs, targets)``.

    Returns:
        Tuple ``(accuracy, loss)``: accuracy as a percentage over all samples,
        and the mean of the per-batch losses.

    Raises:
        ZeroDivisionError: If ``test_loader`` yields no samples.
    """
    test_loss = 0
    correct = 0
    total = 0

    # Evaluate with Dropout/BatchNorm in inference behaviour, then hand the
    # model back in whatever mode the caller had it in.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                test_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
    finally:
        model.train(was_training)

    test_accuracy = 100.*correct/total
    print(f'Test Loss: {test_loss/len(test_loader):.4f}, Accuracy: {test_accuracy:.2f}%')
    return test_accuracy, test_loss/len(test_loader)


# The name matches pytest's default ``test*`` pattern; mark the function so it
# is not collected as a test case if this module is imported by a test file.
test.__test__ = False
if __name__ == '__main__':
    # Script entry point: resume from a checkpoint if one exists, then train
    # for up to five epochs with early stopping on the test loss.
    checkpoint_path = "checkpoint.pth"
    best_loss = float('inf')
    patience = 5
    patience_counter = 0

    try:
        # NOTE(review): save_checkpoint below is given the test loss, so the
        # third value returned here may be a loss rather than an accuracy --
        # confirm against checkpoint.py before using it to seed best_loss.
        model, optimizer, best_test_accuracy = load_checkpoint(model, optimizer, checkpoint_path)
    except FileNotFoundError:
        print("No checkpoint found, starting from scratch.")

    for epoch in range(1, 6):
        train_accuracy = train(model, device, trainloader, optimizer, criterion, epoch)
        test_accuracy, test_loss = test(model, device, testloader, criterion)
        print(f'Epoch {epoch} | Train Accuracy: {train_accuracy:.2f}% | Test Accuracy: {test_accuracy:.2f}%')

        if test_loss < best_loss:
            # New best: remember it, reset the patience window and persist it.
            best_loss = test_loss
            patience_counter = 0
            save_checkpoint(model, optimizer, epoch, test_loss, checkpoint_path)
        else:
            # Only epochs without improvement count against the patience budget.
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping triggered. Training terminated.")
                break