Sartc committed on
Commit
5c7e8ca
·
verified ·
1 Parent(s): b848dd9

Upload 5 files

Browse files
Files changed (5) hide show
  1. data.py +44 -0
  2. inference.py +37 -0
  3. main.py +36 -0
  4. train.py +88 -0
  5. unet.py +73 -0
data.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from PIL import Image
3
+ from torchvision import transforms
4
+ from torch.utils.data import Dataset
5
+
6
def find_mask_file(image_path, mask_dir, mask_extensions=('.png', '.jpg', '.jpeg')):
    """Locate the mask file that shares its base name with an image.

    Args:
        image_path: Path to the image; only its file name without the
            extension is used.
        mask_dir: Directory searched for the mask.
        mask_extensions: Candidate extensions, tried in order. The default
            is an immutable tuple so it cannot be mutated across calls.

    Returns:
        The path of the first existing ``<mask_dir>/<base_name><ext>``, or
        ``None`` when no candidate exists.
    """
    base_name = os.path.splitext(os.path.basename(image_path))[0]
    for ext in mask_extensions:
        mask_path = os.path.join(mask_dir, base_name + ext)
        if os.path.exists(mask_path):
            return mask_path
    return None
13
+
14
class SegmentationDataset(Dataset):
    """Dataset yielding ``(image, mask)`` pairs from two directories.

    Every entry of ``image_dir`` is treated as an image; its mask is looked
    up in ``mask_dir`` by base name via ``find_mask_file``.
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        """Record the directories and index the image file names.

        Args:
            image_dir: Directory containing the input images.
            mask_dir: Directory containing the masks.
            transform: Optional callable applied to both image and mask.
        """
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> sample mapping reproducible between runs and machines.
        self.image_filenames = sorted(os.listdir(image_dir))

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Load the image/mask pair at ``idx``.

        Returns:
            ``(image, mask)``: the image converted to RGB and the mask
            converted to single-channel "L", both passed through
            ``self.transform`` when one was given.

        Raises:
            FileNotFoundError: If no mask matches the image's base name.
        """
        img_path = os.path.join(self.image_dir, self.image_filenames[idx])
        mask_path = find_mask_file(img_path, self.mask_dir)
        if mask_path is None:
            # Fail with an actionable message instead of passing None on
            # to Image.open.
            raise FileNotFoundError(
                f"No mask found in {self.mask_dir!r} for image {img_path!r}"
            )

        # convert() returns a new, fully loaded image, so the underlying
        # file handles can be closed right away.
        with Image.open(img_path) as img_file:
            image = img_file.convert("RGB")
        with Image.open(mask_path) as mask_file:
            mask = mask_file.convert("L")

        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)

        return image, mask
35
+
36
def transform_img():
    """Return the preprocessing pipeline: resize to 128x128, then ToTensor."""
    steps = [
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)
42
+
43
# Running this module directly only prints a short label.
if __name__ == "__main__":
    print("Dataset class")
inference.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from PIL import Image
4
+ from unet import UNet
5
+ from data import transform_img
6
+
7
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
8
+
9
def load_model(weights_path, device):
    """Build a 3-channel-in / 1-channel-out UNet and load saved weights.

    Args:
        weights_path: Path to a state dict readable by ``torch.load``.
        device: Device the weights are mapped to and the model is moved to.

    Returns:
        The model on ``device``, switched to evaluation mode.
    """
    net = UNet(in_channels=3, out_channels=1)
    state = torch.load(weights_path, map_location=device)
    net.load_state_dict(state)
    net.to(device)
    net.eval()
    return net
15
+
16
def preprocess_image(image_path):
    """Load an image as RGB, apply ``transform_img`` and add a batch axis.

    Args:
        image_path: Path of the image file to open.

    Returns:
        The transformed image with a new leading dimension of size 1.
    """
    pipeline = transform_img()
    rgb = Image.open(image_path).convert("RGB")
    tensor = pipeline(rgb)
    return tensor.unsqueeze(0)
20
+
21
def predict(model, image_tensor, device):
    """Run the model without gradients and return sigmoid probabilities.

    Args:
        model: Callable applied to the input tensor.
        image_tensor: Input tensor; moved to ``device`` before the call.
        device: Device on which the forward pass runs.

    Returns:
        A NumPy array holding ``sigmoid(model(image_tensor))`` with the
        leading dimension squeezed away (when it has size 1).
    """
    with torch.no_grad():
        logits = model(image_tensor.to(device))
        probabilities = torch.sigmoid(logits)
        return probabilities.squeeze(0).cpu().numpy()
27
+
28
def save_output(mask, save_path):
    """Threshold a probability map at 0.5 and save it as a 0/255 image.

    Args:
        mask: NumPy array of probabilities; only ``mask[0]`` is written.
        save_path: Destination path handed to ``Image.save``.
    """
    binary = (mask > 0.5).astype(np.uint8)
    binary = binary * 255
    first_channel = binary[0]
    Image.fromarray(first_channel).save(save_path)
32
+
33
# Demo run. It is guarded so that importing this module (main.py does
# ``from inference import load_model, preprocess_image, predict``) no longer
# loads the weights, runs a prediction and writes a file as a side effect.
if __name__ == "__main__":
    weights_path = "unet_model.pth"
    model = load_model(weights_path, device)
    image_tensor = preprocess_image("DUTS-TE-Image/ILSVRC2012_test_00000003.jpg")
    mask = predict(model, image_tensor, device)
    save_output(mask, "predicted_mask.jpg")
main.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image, ImageDraw, ImageFont
2
+ import numpy as np
3
+ import torch
4
+ from inference import load_model, preprocess_image, predict
5
+
6
# One variable for the input path so the image that is displayed and the
# image that is segmented can never drift apart.
image_path = "DUTS-TR-Image/ILSVRC2012_test_00000645.jpg"
original_img = Image.open(image_path).convert("RGB")

# Draw the caption onto a copy; the subject is pasted back on top of it
# later, so the text ends up behind the subject.
background_with_text = original_img.copy()
draw = ImageDraw.Draw(background_with_text)
font_size = 50
font_path = "/usr/share/fonts/truetype/freefont/FreeSansBold.ttf"
try:
    font = ImageFont.truetype(font_path, font_size)
except OSError:
    # The font file is system-specific; fall back to Pillow's built-in
    # font instead of aborting when it is not installed.
    font = ImageFont.load_default()
text = "Hello, world!"
text_position = (50, 50)
text_color = (255, 255, 255)
draw.text(text_position, text, fill=text_color, font=font)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
weights_path = "unet_model.pth"
model = load_model(weights_path, device)
image_tensor = preprocess_image(image_path)
mask = predict(model, image_tensor, device)

print(mask.shape)

# Drop the leading axis, binarise at 0.5 and scale to 0/255 so the result
# can serve as an 8-bit alpha channel.
mask = mask.squeeze(0)
mask_binary = (mask > 0.5).astype(np.uint8) * 255
mask_img = Image.fromarray(mask_binary, mode="L")
# Nearest-neighbour keeps the mask strictly 0/255 when scaling it back to
# the original resolution.
mask_img = mask_img.resize(original_img.size, resample=Image.NEAREST)

original_rgba = original_img.convert("RGBA")

# Swap the opaque alpha channel for the predicted mask.
r, g, b, _ = original_rgba.split()
subject_img = Image.merge("RGBA", (r, g, b, mask_img))

# Paste the subject over the captioned background, using the subject's own
# alpha channel as the paste mask.
background_with_text.paste(subject_img, (0, 0), subject_img)
background_with_text.save("final_output.png")
train.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ from unet import UNet
5
+ from torch.utils.data import DataLoader
6
+ from data import SegmentationDataset, transform_img
7
+
8
# The same preprocessing is applied to the training and test data.
transform = transform_img()

train_dataset = SegmentationDataset("DUTS-TR-Image", "DUTS-TR-Mask", transform=transform)
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)

test_dataset = SegmentationDataset("DUTS-TE-Image", "DUTS-TE-Mask", transform=transform)
# Fix: the test loader previously wrapped train_dataset, so the reported
# "test" metrics were computed on training data. Shuffling is unnecessary
# for evaluation.
test_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet().to(device)
# The model outputs raw logits; BCEWithLogitsLoss applies the sigmoid itself.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
20
+
21
def evaluate_model(model, dataloader, criterion, device):
    """Compute the mean batch loss and pixel accuracy over a dataloader.

    Args:
        model: Network producing logits; switched to eval mode here.
        dataloader: Iterable of ``(images, masks)`` batches that supports
            ``len()``.
        criterion: Loss called as ``criterion(outputs, masks)``.
        device: Device the batches are moved to.

    Returns:
        ``(avg_loss, accuracy)``: the loss averaged over batches and the
        fraction of elements whose thresholded prediction matches the
        thresholded mask.
    """
    model.eval()
    total_loss = 0
    total_correct = 0
    total_pixels = 0

    with torch.no_grad():
        for images, masks in dataloader:
            images = images.to(device)
            masks = masks.to(device)

            outputs = model(images)

            loss = criterion(outputs, masks)
            total_loss += loss.item()

            preds = torch.sigmoid(outputs) > 0.5
            # Masks are floats and may hold values between 0 and 1 (e.g.
            # after resizing). Comparing them directly with boolean
            # predictions would count every such pixel as wrong, so they
            # are binarised with the same 0.5 threshold first.
            targets = masks > 0.5
            total_correct += (preds == targets).sum().item()
            total_pixels += torch.numel(preds)

    avg_loss = total_loss / len(dataloader)
    accuracy = total_correct / total_pixels
    return avg_loss, accuracy
45
+
46
num_epochs = 2

# Per-epoch history of losses and accuracies.
train_loss_lst = []
train_accuracy_lst = []
test_loss_lst = []
test_accuracy_lst = []

# NOTE(review): nothing in this script saves the trained weights, while
# inference.py loads "unet_model.pth" -- confirm where that file is written.
for epoch in range(num_epochs):
    print(f"Epoch: {epoch+1}")
    model.train()
    epoch_loss = 0
    # Fix: reset the accuracy counters every epoch. They used to be
    # initialised once before the loop, so the reported value was a running
    # average over all epochs so far rather than this epoch's accuracy.
    total_correct = 0
    total_pixels = 0

    for images, masks in train_dataloader:
        images = images.to(device)
        masks = masks.to(device)

        outputs = model(images)

        loss = criterion(outputs, masks)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        preds = torch.sigmoid(outputs) > 0.5
        # Binarise the float masks before comparing; values strictly
        # between 0 and 1 would otherwise never equal a boolean prediction.
        total_correct += (preds == (masks > 0.5)).sum().item()
        total_pixels += torch.numel(preds)

        epoch_loss += loss.item()

    train_accuracy = total_correct / total_pixels
    avg_train_loss = epoch_loss/len(train_dataloader)
    print(f"Train loss at {epoch+1} epoch: {avg_train_loss}")
    print(f"Train accuracy at {epoch+1} epoch: {train_accuracy}")
    test_loss, test_accuracy = evaluate_model(model, test_dataloader, criterion, device)
    print(f"Test loss at {epoch+1} epoch: {test_loss}")
    print(f"Test accuracy at {epoch+1} epoch: {test_accuracy}")
    train_loss_lst.append(avg_train_loss)
    test_loss_lst.append(test_loss)
    train_accuracy_lst.append(train_accuracy)
    test_accuracy_lst.append(test_accuracy)
unet.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
class ConvBlock(nn.Module):
    """Two 3x3 convolutions with padding 1, each followed by an in-place ReLU."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]
        # Same attribute name and layer order as before, so state-dict keys
        # ("conv.0.*", "conv.2.*") stay unchanged.
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both convolution/ReLU stages to ``x``."""
        features = self.conv(x)
        return features
16
+
17
class UpConv(nn.Module):
    """Transposed convolution with kernel size 2 and stride 2."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=2,
            stride=2,
        )

    def forward(self, x):
        """Return ``x`` passed through the transposed convolution."""
        upsampled = self.up(x)
        return upsampled
24
+
25
class UNet(nn.Module):
    """Four-level encoder/decoder with skip connections.

    Encoder widths are 64/128/256/512 with a 1024-channel bottleneck. Each
    decoder stage upsamples, concatenates the matching encoder output on the
    channel axis and applies a ConvBlock. A final 1x1 convolution maps the
    64 decoder channels to ``out_channels``; no activation follows it.
    """

    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()

        # Contracting path.
        self.encoder1 = ConvBlock(in_channels, 64)
        self.encoder2 = ConvBlock(64, 128)
        self.encoder3 = ConvBlock(128, 256)
        self.encoder4 = ConvBlock(256, 512)

        # One shared 2x2 max-pool is reused between all encoder stages.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottleneck = ConvBlock(512, 1024)

        # Expanding path: each decoder takes twice the upsampled channel
        # count because the skip tensor is concatenated first.
        self.upconv4 = UpConv(1024, 512)
        self.decoder4 = ConvBlock(1024, 512)
        self.upconv3 = UpConv(512, 256)
        self.decoder3 = ConvBlock(512, 256)
        self.upconv2 = UpConv(256, 128)
        self.decoder2 = ConvBlock(256, 128)
        self.upconv1 = UpConv(128, 64)
        self.decoder1 = ConvBlock(128, 64)

        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        """Return the output of the final 1x1 convolution for input ``x``."""
        skip1 = self.encoder1(x)
        skip2 = self.encoder2(self.pool(skip1))
        skip3 = self.encoder3(self.pool(skip2))
        skip4 = self.encoder4(self.pool(skip3))

        deepest = self.bottleneck(self.pool(skip4))

        # The skip tensor always comes first in the channel concatenation.
        merged = torch.cat((skip4, self.upconv4(deepest)), dim=1)
        up4 = self.decoder4(merged)

        merged = torch.cat((skip3, self.upconv3(up4)), dim=1)
        up3 = self.decoder3(merged)

        merged = torch.cat((skip2, self.upconv2(up3)), dim=1)
        up2 = self.decoder2(merged)

        merged = torch.cat((skip1, self.upconv1(up2)), dim=1)
        up1 = self.decoder1(merged)

        return self.final_conv(up1)