import torch import torch.nn as nn import torch.optim as optim from unet import UNet from torch.utils.data import DataLoader from data import SegmentationDataset, transform_img transform = transform_img() train_dataset = SegmentationDataset("DUTS-TR-Image", "DUTS-TR-Mask", transform=transform) train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True) test_dataset = SegmentationDataset("DUTS-TE-Image", "DUTS-TE-Mask", transform=transform) test_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = UNet().to(device) criterion = nn.BCEWithLogitsLoss() optimizer = optim.Adam(model.parameters(), lr=1e-4) def evaluate_model(model, dataloader, criterion, device): model.eval() total_loss = 0 total_correct = 0 total_pixels = 0 with torch.no_grad(): for images, masks in dataloader: images = images.to(device) masks = masks.to(device) outputs = model(images) loss = criterion(outputs, masks) total_loss += loss.item() preds = torch.sigmoid(outputs) > 0.5 total_correct += (preds==masks).sum().item() total_pixels += torch.numel(preds) avg_loss = total_loss / len(dataloader) accuracy = total_correct / total_pixels return avg_loss, accuracy num_epochs = 2 total_correct = 0 total_pixels = 0 train_loss_lst = [] train_accuracy_lst = [] test_loss_lst = [] test_accuracy_lst = [] for epoch in range(num_epochs): print(f"Epoch: {epoch+1}") model.train() epoch_loss = 0 for images, masks in train_dataloader: images = images.to(device) masks = masks.to(device) outputs = model(images) loss = criterion(outputs, masks) optimizer.zero_grad() loss.backward() optimizer.step() preds = torch.sigmoid(outputs) > 0.5 total_correct += (preds==masks).sum().item() total_pixels += torch.numel(preds) epoch_loss += loss.item() train_accuracy = total_correct / total_pixels avg_train_loss = epoch_loss/len(train_dataloader) print(f"Train loss at {epoch+1} epoch: {avg_train_loss}") print(f"Train accuracy at {epoch+1} epoch: {train_accuracy}") test_loss, test_accuracy = evaluate_model(model, test_dataloader, criterion, device) print(f"Test loss at {epoch+1} epoch: {test_loss}") print(f"Test accuracy at {epoch+1} epoch: {test_accuracy}") train_loss_lst.append(avg_train_loss) test_loss_lst.append(test_loss) train_accuracy_lst.append(train_accuracy) test_accuracy_lst.append(test_accuracy)