|
import torch |
|
import torch.nn as nn |
|
import torch.optim as optim |
|
from unet import UNet |
|
from torch.utils.data import DataLoader |
|
from data import SegmentationDataset, transform_img |
|
|
|
# The same preprocessing pipeline (built by the project's transform_img helper)
# is applied to both the training and the test split.
transform = transform_img()

# Training split: reshuffled every epoch.
train_dataset = SegmentationDataset("DUTS-TR-Image", "DUTS-TR-Mask", transform=transform)
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)

# Held-out split used for evaluation. It must wrap test_dataset (not
# train_dataset), otherwise the "test" metrics would be measured on training
# data. Shuffling is disabled so evaluation is deterministic across epochs.
test_dataset = SegmentationDataset("DUTS-TE-Image", "DUTS-TE-Mask", transform=transform)
test_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=False)
|
|
|
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Segmentation network, moved onto the selected device before the optimizer
# captures its parameters.
model = UNet()
model = model.to(device)

# The loss takes raw logits (the sigmoid is applied inside the loss).
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)
|
|
|
def evaluate_model(model, dataloader, criterion, device):
    """Evaluate a segmentation model on every batch yielded by ``dataloader``.

    The model is switched to eval mode and run under ``torch.no_grad()``.
    Predicted masks are obtained by thresholding ``sigmoid(logits)`` at 0.5.

    Args:
        model: Network mapping an image batch to per-pixel logits.
        dataloader: Iterable of ``(images, masks)`` batches (``len()`` is not
            required).
        criterion: Loss called as ``criterion(outputs, masks)``; its value for
            each batch is weighted by that batch's size.
        device: Device each batch is moved to before the forward pass.

    Returns:
        ``(avg_loss, accuracy)``. ``avg_loss`` is the batch-size-weighted mean of
        the per-batch losses. ``accuracy`` is the fraction of elements where the
        thresholded prediction equals the mask. Both are ``nan`` if the
        dataloader yields no batches.

    Note:
        Leaves ``model`` in eval mode; call ``model.train()`` before resuming
        training.
    """
    model.eval()
    total_loss = 0.0
    total_samples = 0
    total_correct = 0
    total_pixels = 0

    with torch.no_grad():
        for images, masks in dataloader:
            images = images.to(device)
            masks = masks.to(device)
            outputs = model(images)

            batch_size = images.size(0)
            # Weight by batch size so a smaller final batch does not skew the mean.
            total_loss += criterion(outputs, masks).item() * batch_size
            total_samples += batch_size

            preds = torch.sigmoid(outputs) > 0.5
            # NOTE(review): assumes masks hold exact 0/1 values; any fractional
            # mask value can never compare equal to a boolean prediction.
            total_correct += (preds == masks).sum().item()
            total_pixels += preds.numel()

    if total_samples == 0:
        nan = float("nan")
        return nan, nan
    return total_loss / total_samples, total_correct / total_pixels
|
|
|
num_epochs = 2

# Per-epoch history of train/test loss and accuracy.
train_loss_lst = []
train_accuracy_lst = []
test_loss_lst = []
test_accuracy_lst = []

for epoch in range(num_epochs):
    print(f"Epoch: {epoch+1}")
    # evaluate_model switches the model to eval mode, so re-enable training mode
    # at the start of every epoch.
    model.train()

    # Reset every epoch so the reported train loss/accuracy cover only this
    # epoch instead of accumulating across all previous epochs.
    epoch_loss = 0.0
    epoch_samples = 0
    total_correct = 0
    total_pixels = 0

    for images, masks in train_dataloader:
        images = images.to(device)
        masks = masks.to(device)

        outputs = model(images)
        loss = criterion(outputs, masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Metrics need no gradients; detach so no autograd graph is built here.
        preds = torch.sigmoid(outputs.detach()) > 0.5
        total_correct += (preds == masks).sum().item()
        total_pixels += preds.numel()

        # Weight by batch size so a smaller final batch does not skew the mean.
        batch_size = images.size(0)
        epoch_loss += loss.item() * batch_size
        epoch_samples += batch_size

    train_accuracy = total_correct / total_pixels
    avg_train_loss = epoch_loss / epoch_samples
    print(f"Train loss at {epoch+1} epoch: {avg_train_loss}")
    print(f"Train accuracy at {epoch+1} epoch: {train_accuracy}")

    test_loss, test_accuracy = evaluate_model(model, test_dataloader, criterion, device)
    print(f"Test loss at {epoch+1} epoch: {test_loss}")
    print(f"Test accuracy at {epoch+1} epoch: {test_accuracy}")

    train_loss_lst.append(avg_train_loss)
    test_loss_lst.append(test_loss)
    train_accuracy_lst.append(train_accuracy)
    test_accuracy_lst.append(test_accuracy)