Ubuntu commited on
Commit
2e9c13e
·
1 Parent(s): 2959565

lower batch size and accumulation argument. Changed misclassified samples to be for last epoch only

Browse files
Files changed (1) hide show
  1. resnet_execute.py +15 -13
resnet_execute.py CHANGED
@@ -35,7 +35,7 @@ trainset = datasets.ImageFolder(root='/mnt/imagenet/ILSVRC/Data/CLS-LOC/train',
35
  trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=16, pin_memory=True)
36
 
37
  testset = datasets.ImageFolder(root='/mnt/imagenet/ILSVRC/Data/CLS-LOC/val', transform=lambda img: test_transform(image=np.array(img))['image'])
38
- testloader = DataLoader(testset, batch_size=1000, shuffle=False, num_workers=16, pin_memory=True)
39
 
40
  # Initialize model, loss function, and optimizer
41
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -49,7 +49,7 @@ optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e
49
  # Training function
50
  from torch.amp import autocast
51
 
52
- def train(model, device, train_loader, optimizer, criterion, epoch, accumulation_steps=4):
53
  model.train()
54
  running_loss = 0.0
55
  correct1 = 0
@@ -135,7 +135,7 @@ if __name__ == '__main__':
135
  results = []
136
  learning_rates = []
137
 
138
- for epoch in range(1, 6): # 20 epochs
139
  train_accuracy1, train_accuracy5, train_loss = train(model, device, trainloader, optimizer, criterion, epoch)
140
  test_accuracy1, test_accuracy5, test_loss, misclassified_images, misclassified_labels, misclassified_preds = test(model, device, testloader, criterion)
141
  print(f'Epoch {epoch} | Train Top-1 Acc: {train_accuracy1:.2f} | Train Top-5 Acc: {train_accuracy5:.2f} | Test Top-1 Acc: {test_accuracy1:.2f} | Test Top-5 Acc: {test_accuracy5:.2f}')
@@ -155,6 +155,18 @@ if __name__ == '__main__':
155
  print("Early stopping triggered. Training terminated.")
156
  break
157
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  # Print the Top-1 accuracy results in a tab-separated format
159
  print("\nEpoch\tTrain Top-1 Accuracy\tTest Top-1 Accuracy")
160
  for epoch, train_acc1, test_acc1, *_ in results:
@@ -203,13 +215,3 @@ if __name__ == '__main__':
203
 
204
  plt.tight_layout()
205
  plt.show()
206
-
207
- # Display some misclassified samples
208
- if misclassified_images:
209
- print("\nDisplaying some misclassified samples from the last epoch:")
210
- misclassified_grid = make_grid(misclassified_images[:16], nrow=4, normalize=True, scale_each=True)
211
- plt.figure(figsize=(8, 8))
212
- plt.imshow(misclassified_grid.permute(1, 2, 0))
213
- plt.title("Misclassified Samples")
214
- plt.axis('off')
215
- plt.show()
 
35
  trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=16, pin_memory=True)
36
 
37
  testset = datasets.ImageFolder(root='/mnt/imagenet/ILSVRC/Data/CLS-LOC/val', transform=lambda img: test_transform(image=np.array(img))['image'])
38
+ testloader = DataLoader(testset, batch_size=500, shuffle=False, num_workers=16, pin_memory=True)
39
 
40
  # Initialize model, loss function, and optimizer
41
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
49
  # Training function
50
  from torch.amp import autocast
51
 
52
+ def train(model, device, train_loader, optimizer, criterion, epoch, accumulation_steps=2):
53
  model.train()
54
  running_loss = 0.0
55
  correct1 = 0
 
135
  results = []
136
  learning_rates = []
137
 
138
+ for epoch in range(1, 26): # 20 epochs
139
  train_accuracy1, train_accuracy5, train_loss = train(model, device, trainloader, optimizer, criterion, epoch)
140
  test_accuracy1, test_accuracy5, test_loss, misclassified_images, misclassified_labels, misclassified_preds = test(model, device, testloader, criterion)
141
  print(f'Epoch {epoch} | Train Top-1 Acc: {train_accuracy1:.2f} | Train Top-5 Acc: {train_accuracy5:.2f} | Test Top-1 Acc: {test_accuracy1:.2f} | Test Top-5 Acc: {test_accuracy5:.2f}')
 
155
  print("Early stopping triggered. Training terminated.")
156
  break
157
 
158
+ # Only process misclassified samples after the last epoch
159
+ if epoch == 25:
160
+ # Display or process misclassified samples
161
+ if misclassified_images:
162
+ print("\nDisplaying some misclassified samples from the last epoch:")
163
+ misclassified_grid = make_grid(misclassified_images[:16], nrow=4, normalize=True, scale_each=True)
164
+ plt.figure(figsize=(8, 8))
165
+ plt.imshow(misclassified_grid.permute(1, 2, 0))
166
+ plt.title("Misclassified Samples")
167
+ plt.axis('off')
168
+ plt.show()
169
+
170
  # Print the Top-1 accuracy results in a tab-separated format
171
  print("\nEpoch\tTrain Top-1 Accuracy\tTest Top-1 Accuracy")
172
  for epoch, train_acc1, test_acc1, *_ in results:
 
215
 
216
  plt.tight_layout()
217
  plt.show()