Upload 5 files
Browse files- data.py +44 -0
- inference.py +37 -0
- main.py +36 -0
- train.py +88 -0
- unet.py +73 -0
data.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from PIL import Image
|
3 |
+
from torchvision import transforms
|
4 |
+
from torch.utils.data import Dataset
|
5 |
+
|
6 |
+
def find_mask_file(image_path, mask_dir, mask_extensions=['.png', '.jpg', '.jpeg']):
|
7 |
+
base_name = os.path.splitext(os.path.basename(image_path))[0]
|
8 |
+
for ext in mask_extensions:
|
9 |
+
mask_path = os.path.join(mask_dir, base_name + ext)
|
10 |
+
if os.path.exists(mask_path):
|
11 |
+
return mask_path
|
12 |
+
return None
|
13 |
+
|
14 |
+
class SegmentationDataset(Dataset):
|
15 |
+
def __init__(self, image_dir, mask_dir, transform=None):
|
16 |
+
self.image_dir = image_dir
|
17 |
+
self.mask_dir = mask_dir
|
18 |
+
self.transform = transform
|
19 |
+
self.image_filenames = os.listdir(image_dir)
|
20 |
+
|
21 |
+
def __len__(self):
|
22 |
+
return len(self.image_filenames)
|
23 |
+
|
24 |
+
def __getitem__(self, idx):
|
25 |
+
img_path = os.path.join(self.image_dir, self.image_filenames[idx])
|
26 |
+
mask_path = find_mask_file(img_path, self.mask_dir)
|
27 |
+
image = Image.open(img_path).convert("RGB")
|
28 |
+
mask = Image.open(mask_path).convert("L")
|
29 |
+
|
30 |
+
if self.transform:
|
31 |
+
image = self.transform(image)
|
32 |
+
mask = self.transform(mask)
|
33 |
+
|
34 |
+
return image, mask
|
35 |
+
|
36 |
+
def transform_img():
|
37 |
+
transform = transforms.Compose([
|
38 |
+
transforms.Resize((128, 128)),
|
39 |
+
transforms.ToTensor()
|
40 |
+
])
|
41 |
+
return transform
|
42 |
+
|
43 |
+
if __name__ == "__main__":
|
44 |
+
print("Dataset class")
|
inference.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from PIL import Image
|
4 |
+
from unet import UNet
|
5 |
+
from data import transform_img
|
6 |
+
|
7 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
8 |
+
|
9 |
+
def load_model(weights_path, device):
|
10 |
+
model = UNet(in_channels=3, out_channels=1)
|
11 |
+
model.load_state_dict(torch.load(weights_path, map_location=device))
|
12 |
+
model.to(device)
|
13 |
+
model.eval()
|
14 |
+
return model
|
15 |
+
|
16 |
+
def preprocess_image(image_path):
|
17 |
+
transform = transform_img()
|
18 |
+
image = Image.open(image_path).convert("RGB")
|
19 |
+
return transform(image).unsqueeze(0)
|
20 |
+
|
21 |
+
def predict(model, image_tensor, device):
|
22 |
+
with torch.no_grad():
|
23 |
+
image_tensor = image_tensor.to(device)
|
24 |
+
output = model(image_tensor)
|
25 |
+
output = torch.sigmoid(output)
|
26 |
+
return output.squeeze(0).cpu().numpy()
|
27 |
+
|
28 |
+
def save_output(mask, save_path):
|
29 |
+
mask = (mask > 0.5).astype(np.uint8)*255
|
30 |
+
mask_image = Image.fromarray(mask[0])
|
31 |
+
mask_image.save(save_path)
|
32 |
+
|
33 |
+
weights_path = "unet_model.pth"
|
34 |
+
model = load_model(weights_path, device)
|
35 |
+
image_tensor = preprocess_image("DUTS-TE-Image/ILSVRC2012_test_00000003.jpg")
|
36 |
+
mask = predict(model, image_tensor, device)
|
37 |
+
save_output(mask, "predicted_mask.jpg")
|
main.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image, ImageDraw, ImageFont
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from inference import load_model, preprocess_image, predict
|
5 |
+
|
6 |
+
original_img = Image.open("DUTS-TR-Image/ILSVRC2012_test_00000645.jpg").convert("RGB")
|
7 |
+
|
8 |
+
background_with_text = original_img.copy()
|
9 |
+
draw = ImageDraw.Draw(background_with_text)
|
10 |
+
font_size = 50
|
11 |
+
font = ImageFont.truetype("/usr/share/fonts/truetype/freefont/FreeSansBold.ttf", font_size)
|
12 |
+
text = "Hello, world!"
|
13 |
+
text_position = (50, 50)
|
14 |
+
text_color = (255, 255, 255)
|
15 |
+
draw.text(text_position, text, fill=text_color, font=font)
|
16 |
+
|
17 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
18 |
+
weights_path = "unet_model.pth"
|
19 |
+
model = load_model(weights_path, device)
|
20 |
+
image_tensor = preprocess_image("DUTS-TR-Image/ILSVRC2012_test_00000645.jpg")
|
21 |
+
mask = predict(model, image_tensor, device)
|
22 |
+
|
23 |
+
print(mask.shape)
|
24 |
+
|
25 |
+
mask = mask.squeeze(0)
|
26 |
+
mask_binary = (mask > 0.5).astype(np.uint8) * 255
|
27 |
+
mask_img = Image.fromarray(mask_binary, mode="L")
|
28 |
+
mask_img = mask_img.resize(original_img.size, resample=Image.NEAREST)
|
29 |
+
|
30 |
+
original_rgba = original_img.convert("RGBA")
|
31 |
+
|
32 |
+
r, g, b, _ = original_rgba.split()
|
33 |
+
subject_img = Image.merge("RGBA", (r, g, b, mask_img))
|
34 |
+
|
35 |
+
background_with_text.paste(subject_img, (0, 0), subject_img)
|
36 |
+
background_with_text.save("final_output.png")
|
train.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.optim as optim
|
4 |
+
from unet import UNet
|
5 |
+
from torch.utils.data import DataLoader
|
6 |
+
from data import SegmentationDataset, transform_img
|
7 |
+
|
8 |
+
transform = transform_img()
|
9 |
+
|
10 |
+
train_dataset = SegmentationDataset("DUTS-TR-Image", "DUTS-TR-Mask", transform=transform)
|
11 |
+
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
|
12 |
+
|
13 |
+
test_dataset = SegmentationDataset("DUTS-TE-Image", "DUTS-TE-Mask", transform=transform)
|
14 |
+
test_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
|
15 |
+
|
16 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
17 |
+
model = UNet().to(device)
|
18 |
+
criterion = nn.BCEWithLogitsLoss()
|
19 |
+
optimizer = optim.Adam(model.parameters(), lr=1e-4)
|
20 |
+
|
21 |
+
def evaluate_model(model, dataloader, criterion, device):
|
22 |
+
model.eval()
|
23 |
+
total_loss = 0
|
24 |
+
total_correct = 0
|
25 |
+
total_pixels = 0
|
26 |
+
|
27 |
+
with torch.no_grad():
|
28 |
+
for images, masks in dataloader:
|
29 |
+
|
30 |
+
images = images.to(device)
|
31 |
+
masks = masks.to(device)
|
32 |
+
|
33 |
+
outputs = model(images)
|
34 |
+
|
35 |
+
loss = criterion(outputs, masks)
|
36 |
+
total_loss += loss.item()
|
37 |
+
|
38 |
+
preds = torch.sigmoid(outputs) > 0.5
|
39 |
+
total_correct += (preds==masks).sum().item()
|
40 |
+
total_pixels += torch.numel(preds)
|
41 |
+
|
42 |
+
avg_loss = total_loss / len(dataloader)
|
43 |
+
accuracy = total_correct / total_pixels
|
44 |
+
return avg_loss, accuracy
|
45 |
+
|
46 |
+
num_epochs = 2
|
47 |
+
total_correct = 0
|
48 |
+
total_pixels = 0
|
49 |
+
|
50 |
+
train_loss_lst = []
|
51 |
+
train_accuracy_lst = []
|
52 |
+
test_loss_lst = []
|
53 |
+
test_accuracy_lst = []
|
54 |
+
|
55 |
+
for epoch in range(num_epochs):
|
56 |
+
print(f"Epoch: {epoch+1}")
|
57 |
+
model.train()
|
58 |
+
epoch_loss = 0
|
59 |
+
|
60 |
+
for images, masks in train_dataloader:
|
61 |
+
|
62 |
+
images = images.to(device)
|
63 |
+
masks = masks.to(device)
|
64 |
+
|
65 |
+
outputs = model(images)
|
66 |
+
|
67 |
+
loss = criterion(outputs, masks)
|
68 |
+
optimizer.zero_grad()
|
69 |
+
loss.backward()
|
70 |
+
optimizer.step()
|
71 |
+
|
72 |
+
preds = torch.sigmoid(outputs) > 0.5
|
73 |
+
total_correct += (preds==masks).sum().item()
|
74 |
+
total_pixels += torch.numel(preds)
|
75 |
+
|
76 |
+
epoch_loss += loss.item()
|
77 |
+
|
78 |
+
train_accuracy = total_correct / total_pixels
|
79 |
+
avg_train_loss = epoch_loss/len(train_dataloader)
|
80 |
+
print(f"Train loss at {epoch+1} epoch: {avg_train_loss}")
|
81 |
+
print(f"Train accuracy at {epoch+1} epoch: {train_accuracy}")
|
82 |
+
test_loss, test_accuracy = evaluate_model(model, test_dataloader, criterion, device)
|
83 |
+
print(f"Test loss at {epoch+1} epoch: {test_loss}")
|
84 |
+
print(f"Test accuracy at {epoch+1} epoch: {test_accuracy}")
|
85 |
+
train_loss_lst.append(avg_train_loss)
|
86 |
+
test_loss_lst.append(test_loss)
|
87 |
+
train_accuracy_lst.append(train_accuracy)
|
88 |
+
test_accuracy_lst.append(test_accuracy)
|
unet.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
class ConvBlock(nn.Module):
|
5 |
+
def __init__(self, in_channels, out_channels):
|
6 |
+
super(ConvBlock, self).__init__()
|
7 |
+
self.conv = nn.Sequential(
|
8 |
+
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
|
9 |
+
nn.ReLU(inplace=True),
|
10 |
+
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
|
11 |
+
nn.ReLU(inplace=True)
|
12 |
+
)
|
13 |
+
|
14 |
+
def forward(self, x):
|
15 |
+
return self.conv(x)
|
16 |
+
|
17 |
+
class UpConv(nn.Module):
|
18 |
+
def __init__(self, in_channels, out_channels):
|
19 |
+
super(UpConv, self).__init__()
|
20 |
+
self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
|
21 |
+
|
22 |
+
def forward(self, x):
|
23 |
+
return self.up(x)
|
24 |
+
|
25 |
+
class UNet(nn.Module):
|
26 |
+
def __init__(self, in_channels=3, out_channels=1):
|
27 |
+
super(UNet, self).__init__()
|
28 |
+
|
29 |
+
self.encoder1 = ConvBlock(in_channels, 64)
|
30 |
+
self.encoder2 = ConvBlock(64, 128)
|
31 |
+
self.encoder3 = ConvBlock(128, 256)
|
32 |
+
self.encoder4 = ConvBlock(256, 512)
|
33 |
+
|
34 |
+
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
|
35 |
+
|
36 |
+
self.bottleneck = ConvBlock(512, 1024)
|
37 |
+
|
38 |
+
self.upconv4 = UpConv(1024, 512)
|
39 |
+
self.decoder4 = ConvBlock(1024, 512)
|
40 |
+
self.upconv3 = UpConv(512, 256)
|
41 |
+
self.decoder3 = ConvBlock(512, 256)
|
42 |
+
self.upconv2 = UpConv(256, 128)
|
43 |
+
self.decoder2 = ConvBlock(256, 128)
|
44 |
+
self.upconv1 = UpConv(128, 64)
|
45 |
+
self.decoder1 = ConvBlock(128, 64)
|
46 |
+
|
47 |
+
self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)
|
48 |
+
|
49 |
+
def forward(self, x):
|
50 |
+
enc1 = self.encoder1(x)
|
51 |
+
enc2 = self.encoder2(self.pool(enc1))
|
52 |
+
enc3 = self.encoder3(self.pool(enc2))
|
53 |
+
enc4 = self.encoder4(self.pool(enc3))
|
54 |
+
|
55 |
+
bottleneck = self.bottleneck(self.pool(enc4))
|
56 |
+
|
57 |
+
dec4 = self.upconv4(bottleneck)
|
58 |
+
dec4 = torch.cat((enc4, dec4), dim=1)
|
59 |
+
dec4 = self.decoder4(dec4)
|
60 |
+
|
61 |
+
dec3 = self.upconv3(dec4)
|
62 |
+
dec3 = torch.cat((enc3, dec3), dim=1)
|
63 |
+
dec3 = self.decoder3(dec3)
|
64 |
+
|
65 |
+
dec2 = self.upconv2(dec3)
|
66 |
+
dec2 = torch.cat((enc2, dec2), dim=1)
|
67 |
+
dec2 = self.decoder2(dec2)
|
68 |
+
|
69 |
+
dec1 = self.upconv1(dec2)
|
70 |
+
dec1 = torch.cat((enc1, dec1), dim=1)
|
71 |
+
dec1 = self.decoder1(dec1)
|
72 |
+
|
73 |
+
return self.final_conv(dec1)
|