# tbi / train.py
# Uploaded by Sartc ("Upload 5 files", commit 5c7e8ca), 2.76 kB.
import torch
import torch.nn as nn
import torch.optim as optim
from unet import UNet
from torch.utils.data import DataLoader
from data import SegmentationDataset, transform_img
# --- Data -------------------------------------------------------------------
# transform_img() comes from the local ``data`` module; the same transform
# instance is used for both the training and the test split.
transform = transform_img()
train_dataset = SegmentationDataset("DUTS-TR-Image", "DUTS-TR-Mask", transform=transform)
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_dataset = SegmentationDataset("DUTS-TE-Image", "DUTS-TE-Mask", transform=transform)
# The test loader must wrap the held-out DUTS-TE split, not the training set,
# otherwise the "test" metrics just re-measure training data. Batch order does
# not affect evaluation, so there is no need to shuffle.
test_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=False)

# --- Model and optimisation -------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet().to(device)
# BCEWithLogitsLoss applies the sigmoid itself, so the model is expected to
# emit raw logits (evaluation code thresholds sigmoid(outputs) at 0.5).
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
def evaluate_model(model, dataloader, criterion, device):
    """Evaluate ``model`` on all batches of ``dataloader`` without gradients.

    The model is switched to eval mode and left there; the training loop
    calls ``model.train()`` again at the start of each epoch.

    Args:
        model: network whose outputs are treated as logits.
        dataloader: any iterable yielding ``(images, masks)`` tensor batches.
        criterion: loss callable, ``criterion(outputs, masks)`` -> scalar tensor.
        device: device each batch is moved to before the forward pass.

    Returns:
        ``(avg_loss, accuracy)``.
        ``avg_loss`` is the loss averaged over samples: each batch is
        weighted by its size, so a short final batch is not over-weighted.
        ``accuracy`` is the fraction of output elements whose thresholded
        prediction equals the mask value.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    total_samples = 0
    total_correct = 0
    total_pixels = 0
    with torch.no_grad():
        for images, masks in dataloader:
            images = images.to(device)
            masks = masks.to(device)
            outputs = model(images)
            batch_size = images.size(0)
            # Assumes a mean-reduced criterion (nn.BCEWithLogitsLoss' default):
            # rescale to a per-batch total so the final division averages over
            # samples instead of over batches.
            total_loss += criterion(outputs, masks).item() * batch_size
            total_samples += batch_size
            # sigmoid(logit) > 0.5 converts logits into binary predictions.
            # NOTE(review): exact equality assumes masks contain only 0/1 —
            # confirm the mask transform does not produce interpolated values.
            preds = torch.sigmoid(outputs) > 0.5
            total_correct += (preds == masks).sum().item()
            total_pixels += preds.numel()
    if total_samples == 0:
        raise ValueError("evaluate_model: dataloader yielded no batches")
    avg_loss = total_loss / total_samples
    accuracy = total_correct / total_pixels
    return avg_loss, accuracy
num_epochs = 2
# Per-epoch metric history, one entry appended per epoch.
train_loss_lst = []
train_accuracy_lst = []
test_loss_lst = []
test_accuracy_lst = []
for epoch in range(num_epochs):
    print(f"Epoch: {epoch+1}")
    model.train()
    # Reset every counter at the start of each epoch. Previously the accuracy
    # counters lived outside this loop, so "train accuracy at epoch N" was a
    # running total over all epochs so far.
    epoch_loss = 0.0
    epoch_samples = 0
    total_correct = 0
    total_pixels = 0
    for images, masks in train_dataloader:
        images = images.to(device)
        masks = masks.to(device)
        outputs = model(images)
        loss = criterion(outputs, masks)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Metrics only: detach so the sigmoid is not recorded by autograd.
        preds = torch.sigmoid(outputs.detach()) > 0.5
        total_correct += (preds == masks).sum().item()
        total_pixels += preds.numel()
        # Weight each batch's (mean-reduced) loss by its size so the epoch
        # average is over samples and a short final batch is not over-weighted.
        batch_size = images.size(0)
        epoch_loss += loss.item() * batch_size
        epoch_samples += batch_size
    train_accuracy = total_correct / total_pixels
    avg_train_loss = epoch_loss / epoch_samples
    print(f"Train loss at {epoch+1} epoch: {avg_train_loss}")
    print(f"Train accuracy at {epoch+1} epoch: {train_accuracy}")
    test_loss, test_accuracy = evaluate_model(model, test_dataloader, criterion, device)
    print(f"Test loss at {epoch+1} epoch: {test_loss}")
    print(f"Test accuracy at {epoch+1} epoch: {test_accuracy}")
    train_loss_lst.append(avg_train_loss)
    test_loss_lst.append(test_loss)
    train_accuracy_lst.append(train_accuracy)
    test_accuracy_lst.append(test_accuracy)