File size: 2,762 Bytes
5c7e8ca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
import torch
import torch.nn as nn
import torch.optim as optim
from unet import UNet
from torch.utils.data import DataLoader
from data import SegmentationDataset, transform_img
# Data pipeline: DUTS training and test splits share the same image/mask transform.
transform = transform_img()
train_dataset = SegmentationDataset("DUTS-TR-Image", "DUTS-TR-Mask", transform=transform)
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_dataset = SegmentationDataset("DUTS-TE-Image", "DUTS-TE-Mask", transform=transform)
# Evaluate on the held-out DUTS-TE split (not the training set); batch order
# does not affect evaluation metrics, so there is no need to shuffle.
test_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet().to(device)
# BCEWithLogitsLoss applies the sigmoid internally, so the model's outputs are
# treated as raw logits (evaluation below thresholds sigmoid(outputs) at 0.5).
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
def evaluate_model(model, dataloader, criterion, device):
    """Score ``model`` on every batch of ``dataloader`` with gradients disabled.

    Switches the model to eval mode, pushes each ``(images, masks)`` batch
    through it, and accumulates the criterion loss together with a pixel-wise
    binary accuracy obtained by thresholding ``sigmoid(outputs)`` at 0.5.

    Args:
        model: network called as ``model(images)``; its outputs are passed to
            both ``criterion`` and ``torch.sigmoid``.
        dataloader: iterable of ``(images, masks)`` pairs that also supports
            ``len()``.
        criterion: loss function called as ``criterion(outputs, masks)``.
        device: device each batch is moved to before the forward pass.

    Returns:
        ``(avg_loss, accuracy)``: ``avg_loss`` is the unweighted mean of the
        per-batch losses; ``accuracy`` is the fraction of predicted elements
        equal to the corresponding mask element.

    Raises:
        ZeroDivisionError: if ``dataloader`` yields no batches.

    Note:
        NOTE(review): the equality test assumes masks hold 0/1 values shaped
        like the outputs — confirm against SegmentationDataset.
        The model is left in eval mode; callers must call ``model.train()``
        before resuming training.
    """
    model.eval()
    loss_sum = 0
    matches = 0
    element_count = 0
    with torch.no_grad():
        for batch_images, batch_masks in dataloader:
            batch_images = batch_images.to(device)
            batch_masks = batch_masks.to(device)
            logits = model(batch_images)
            loss_sum += criterion(logits, batch_masks).item()
            predicted = torch.sigmoid(logits) > 0.5
            matches += (predicted == batch_masks).sum().item()
            element_count += predicted.numel()
    return loss_sum / len(dataloader), matches / element_count
num_epochs = 2
# Per-epoch metric history; one entry is appended to each list per epoch.
train_loss_lst = []
train_accuracy_lst = []
test_loss_lst = []
test_accuracy_lst = []
for epoch in range(num_epochs):
    print(f"Epoch: {epoch+1}")
    # evaluate_model() switches the model to eval mode, so re-enable training mode each epoch.
    model.train()
    epoch_loss = 0
    # Reset the pixel counters every epoch so the reported train accuracy
    # covers only this epoch instead of accumulating across all epochs.
    total_correct = 0
    total_pixels = 0
    for images, masks in train_dataloader:
        images = images.to(device)
        masks = masks.to(device)
        outputs = model(images)
        loss = criterion(outputs, masks)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Metrics only: detach so thresholding is not recorded in the autograd graph.
        preds = torch.sigmoid(outputs.detach()) > 0.5
        total_correct += (preds == masks).sum().item()
        total_pixels += torch.numel(preds)
        epoch_loss += loss.item()
    train_accuracy = total_correct / total_pixels
    # Unweighted mean of per-batch losses over the epoch.
    avg_train_loss = epoch_loss / len(train_dataloader)
    print(f"Train loss at {epoch+1} epoch: {avg_train_loss}")
    print(f"Train accuracy at {epoch+1} epoch: {train_accuracy}")
    test_loss, test_accuracy = evaluate_model(model, test_dataloader, criterion, device)
    print(f"Test loss at {epoch+1} epoch: {test_loss}")
    print(f"Test accuracy at {epoch+1} epoch: {test_accuracy}")
    train_loss_lst.append(avg_train_loss)
    test_loss_lst.append(test_loss)
    train_accuracy_lst.append(train_accuracy)
    test_accuracy_lst.append(test_accuracy)