Ubuntu
committed on
Commit
·
2e9c13e
1
Parent(s):
2959565
Lower the batch size and the accumulation argument. Changed misclassified samples to be displayed for the last epoch only.
Browse files- resnet_execute.py +15 -13
resnet_execute.py
CHANGED
@@ -35,7 +35,7 @@ trainset = datasets.ImageFolder(root='/mnt/imagenet/ILSVRC/Data/CLS-LOC/train',
|
|
35 |
trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=16, pin_memory=True)
|
36 |
|
37 |
testset = datasets.ImageFolder(root='/mnt/imagenet/ILSVRC/Data/CLS-LOC/val', transform=lambda img: test_transform(image=np.array(img))['image'])
|
38 |
-
testloader = DataLoader(testset, batch_size=
|
39 |
|
40 |
# Initialize model, loss function, and optimizer
|
41 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
@@ -49,7 +49,7 @@ optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e
|
|
49 |
# Training function
|
50 |
from torch.amp import autocast
|
51 |
|
52 |
-
def train(model, device, train_loader, optimizer, criterion, epoch, accumulation_steps=
|
53 |
model.train()
|
54 |
running_loss = 0.0
|
55 |
correct1 = 0
|
@@ -135,7 +135,7 @@ if __name__ == '__main__':
|
|
135 |
results = []
|
136 |
learning_rates = []
|
137 |
|
138 |
-
for epoch in range(1,
|
139 |
train_accuracy1, train_accuracy5, train_loss = train(model, device, trainloader, optimizer, criterion, epoch)
|
140 |
test_accuracy1, test_accuracy5, test_loss, misclassified_images, misclassified_labels, misclassified_preds = test(model, device, testloader, criterion)
|
141 |
print(f'Epoch {epoch} | Train Top-1 Acc: {train_accuracy1:.2f} | Train Top-5 Acc: {train_accuracy5:.2f} | Test Top-1 Acc: {test_accuracy1:.2f} | Test Top-5 Acc: {test_accuracy5:.2f}')
|
@@ -155,6 +155,18 @@ if __name__ == '__main__':
|
|
155 |
print("Early stopping triggered. Training terminated.")
|
156 |
break
|
157 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
158 |
# Print the Top-1 accuracy results in a tab-separated format
|
159 |
print("\nEpoch\tTrain Top-1 Accuracy\tTest Top-1 Accuracy")
|
160 |
for epoch, train_acc1, test_acc1, *_ in results:
|
@@ -203,13 +215,3 @@ if __name__ == '__main__':
|
|
203 |
|
204 |
plt.tight_layout()
|
205 |
plt.show()
|
206 |
-
|
207 |
-
# Display some misclassified samples
|
208 |
-
if misclassified_images:
|
209 |
-
print("\nDisplaying some misclassified samples from the last epoch:")
|
210 |
-
misclassified_grid = make_grid(misclassified_images[:16], nrow=4, normalize=True, scale_each=True)
|
211 |
-
plt.figure(figsize=(8, 8))
|
212 |
-
plt.imshow(misclassified_grid.permute(1, 2, 0))
|
213 |
-
plt.title("Misclassified Samples")
|
214 |
-
plt.axis('off')
|
215 |
-
plt.show()
|
|
|
35 |
trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=16, pin_memory=True)
|
36 |
|
37 |
testset = datasets.ImageFolder(root='/mnt/imagenet/ILSVRC/Data/CLS-LOC/val', transform=lambda img: test_transform(image=np.array(img))['image'])
|
38 |
+
testloader = DataLoader(testset, batch_size=500, shuffle=False, num_workers=16, pin_memory=True)
|
39 |
|
40 |
# Initialize model, loss function, and optimizer
|
41 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
49 |
# Training function
|
50 |
from torch.amp import autocast
|
51 |
|
52 |
+
def train(model, device, train_loader, optimizer, criterion, epoch, accumulation_steps=2):
|
53 |
model.train()
|
54 |
running_loss = 0.0
|
55 |
correct1 = 0
|
|
|
135 |
results = []
|
136 |
learning_rates = []
|
137 |
|
138 |
+
for epoch in range(1, 26): # 25 epochs
|
139 |
train_accuracy1, train_accuracy5, train_loss = train(model, device, trainloader, optimizer, criterion, epoch)
|
140 |
test_accuracy1, test_accuracy5, test_loss, misclassified_images, misclassified_labels, misclassified_preds = test(model, device, testloader, criterion)
|
141 |
print(f'Epoch {epoch} | Train Top-1 Acc: {train_accuracy1:.2f} | Train Top-5 Acc: {train_accuracy5:.2f} | Test Top-1 Acc: {test_accuracy1:.2f} | Test Top-5 Acc: {test_accuracy5:.2f}')
|
|
|
155 |
print("Early stopping triggered. Training terminated.")
|
156 |
break
|
157 |
|
158 |
+
# Only process misclassified samples after the last epoch
|
159 |
+
if epoch == 25:
|
160 |
+
# Display or process misclassified samples
|
161 |
+
if misclassified_images:
|
162 |
+
print("\nDisplaying some misclassified samples from the last epoch:")
|
163 |
+
misclassified_grid = make_grid(misclassified_images[:16], nrow=4, normalize=True, scale_each=True)
|
164 |
+
plt.figure(figsize=(8, 8))
|
165 |
+
plt.imshow(misclassified_grid.permute(1, 2, 0))
|
166 |
+
plt.title("Misclassified Samples")
|
167 |
+
plt.axis('off')
|
168 |
+
plt.show()
|
169 |
+
|
170 |
# Print the Top-1 accuracy results in a tab-separated format
|
171 |
print("\nEpoch\tTrain Top-1 Accuracy\tTest Top-1 Accuracy")
|
172 |
for epoch, train_acc1, test_acc1, *_ in results:
|
|
|
215 |
|
216 |
plt.tight_layout()
|
217 |
plt.show()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|