File size: 1,234 Bytes
5c7e8ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e855cc
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import torch
import numpy as np
from PIL import Image
from unet import UNet
from data import transform_img 

# Default compute device: the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def load_model(weights_path, device, *, in_channels=3, out_channels=1):
    """Build a UNet, load saved weights into it, and prepare it for inference.

    Args:
        weights_path: Path to a checkpoint that ``torch.load`` returns as a
            state dict accepted by ``UNet.load_state_dict``.
        device: ``torch.device`` that the loaded tensors and the model are
            placed on.
        in_channels: Input channel count passed to ``UNet``. The default of 3
            matches the RGB images produced by ``preprocess_image``.
        out_channels: Output channel count passed to ``UNet``. The default
            is 1.

    Returns:
        The UNet, moved to ``device`` and switched to eval mode.
    """
    model = UNet(in_channels=in_channels, out_channels=out_channels)
    # NOTE(security): torch.load unpickles the file, which can execute arbitrary
    # code. Only load checkpoints from trusted sources. On PyTorch >= 1.13,
    # passing weights_only=True restricts loading to tensors and plain
    # containers.
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()  # disable dropout / use running batch-norm statistics
    return model

def preprocess_image(image_path):
    """Load an image from disk and turn it into a single-item model input batch.

    The image is forced to 3-channel RGB, passed through the pipeline returned
    by ``transform_img()``, and given a leading batch axis of size 1.

    Args:
        image_path: Path of the image file to read.

    Returns:
        The transformed image with a new leading axis. This assumes
        ``transform_img()`` yields a tensor; TODO confirm it is (C, H, W), which
        would make this (1, C, H, W).
    """
    transform = transform_img()
    # Close the underlying file promptly. convert() forces the pixel data to
    # load and returns an independent image, so it remains valid after the
    # `with` block exits.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    return transform(image).unsqueeze(0)

def predict(model, image_tensor, device):
    """Run ``model`` on a batch and return sigmoid probabilities as a NumPy array.

    Gradient tracking is disabled for the forward pass. The leading axis of the
    result is squeezed away (a no-op unless it has size 1), and the array is
    copied back to host memory.
    """
    with torch.no_grad():
        logits = model(image_tensor.to(device))
        probabilities = torch.sigmoid(logits)
    host_tensor = probabilities.squeeze(0).cpu()
    return host_tensor.numpy()
    
def save_output(mask, save_path, threshold=0.5):
    """Binarize a probability mask and write it to disk as an 8-bit image.

    Pixels with probability strictly above ``threshold`` become 255; all
    others become 0.

    Args:
        mask: Array of per-pixel probabilities. A 2-D (H, W) array is written
            as-is. For higher-rank input (for example the (1, H, W) output of
            ``predict`` for a single-channel model), the first entry along
            axis 0 is used.
        save_path: Destination path. PIL infers the file format from the
            extension; prefer a lossless format such as PNG for binary masks.
        threshold: Probability cutoff for the foreground.
    """
    binary = (np.asarray(mask) > threshold).astype(np.uint8) * 255
    # Only drop the leading axis when there is one to drop. Indexing a 2-D mask
    # would select a single row and produce a 1-pixel-high image.
    if binary.ndim > 2:
        binary = binary[0]
    Image.fromarray(binary).save(save_path)

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Run a trained UNet on one image and save the thresholded mask."
    )
    parser.add_argument(
        "--weights",
        default="unet_model.pth",
        help="path to the saved UNet state dict",
    )
    parser.add_argument(
        "--image",
        default="DUTS-TE-Image/ILSVRC2012_test_00000003.jpg",
        help="input image path",
    )
    # A binary 0/255 mask should be stored losslessly. JPEG compression would
    # introduce intermediate gray values along mask edges.
    parser.add_argument(
        "--output",
        default="predicted_mask.png",
        help="output mask path (use a lossless format such as PNG)",
    )
    args = parser.parse_args()

    model = load_model(args.weights, device)
    image_tensor = preprocess_image(args.image)
    mask = predict(model, image_tensor, device)
    # NOTE(review): the mask has whatever spatial size transform_img() produces.
    # If that transform resizes, the mask will not match the original image
    # dimensions. Confirm whether it should be resized back before saving.
    save_output(mask, args.output)