|
import torch |
|
import numpy as np |
|
from PIL import Image |
|
from unet import UNet |
|
from data import transform_img |
|
|
|
# Default compute device for this script: the first CUDA device when PyTorch
# reports one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
|
|
|
def load_model(weights_path, device):
    """Construct the UNet, restore its saved weights and ready it for inference.

    Args:
        weights_path: Path to a checkpoint readable by ``torch.load`` whose
            contents are a ``state_dict`` matching
            ``UNet(in_channels=3, out_channels=1)``.
        device: Device passed as ``map_location`` when loading and that the
            network is moved to afterwards.

    Returns:
        The restored UNet, placed on ``device`` and switched to eval mode.
    """
    net = UNet(in_channels=3, out_channels=1)
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints (newer PyTorch versions also accept weights_only=True).
    checkpoint = torch.load(weights_path, map_location=device)
    net.load_state_dict(checkpoint)
    # nn.Module.to() / .eval() act on the module in place.
    net.to(device)
    net.eval()
    return net
|
|
|
def preprocess_image(image_path):
    """Read an image file and turn it into a batch of one for the model.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (a path or file object).

    Returns:
        The result of applying ``transform_img()`` to the RGB-converted image,
        with a leading dimension of size 1 added via ``unsqueeze(0)``
        (presumably a ``(1, C, H, W)`` tensor -- confirm against
        ``data.transform_img``).
    """
    transform = transform_img()
    # Open the file with a context manager so its handle is closed
    # deterministically. convert() returns a new, fully loaded image, so
    # ``rgb_image`` remains usable after the file is closed.
    with Image.open(image_path) as source:
        rgb_image = source.convert("RGB")
    return transform(rgb_image).unsqueeze(0)
|
|
|
def predict(model, image_tensor, device):
    """Run ``model`` on ``image_tensor`` and return sigmoid probabilities.

    Args:
        model: Callable (e.g. an ``nn.Module``) applied to the input tensor.
        image_tensor: Input tensor; it is moved to ``device`` before the call.
        device: Target device for the input (should match the model's device).

    Returns:
        A ``numpy.ndarray`` holding ``sigmoid(model(input))``, with dimension 0
        squeezed (dropped when it has size 1), copied to CPU memory.
    """
    # Inference only: disable autograd tracking for the forward pass.
    with torch.no_grad():
        logits = model(image_tensor.to(device))
        probabilities = torch.sigmoid(logits)
    return probabilities.squeeze(0).cpu().numpy()
|
|
|
def save_output(mask, save_path, threshold=0.5):
    """Threshold a probability mask and write it as an 8-bit 0/255 image.

    Args:
        mask: Array of per-pixel scores. Either 2-D ``(H, W)``, or with a
            leading channel axis such as ``(1, H, W)`` (presumably what
            ``predict`` returns for a single-channel model -- confirm). When
            the array has more than two dimensions, only index 0 of the
            leading axis is saved.
        save_path: Destination file; PIL picks the format from the extension.
            Prefer a lossless format (e.g. ``.png``): JPEG compression
            introduces values other than 0 and 255 into a binary mask.
        threshold: Pixels strictly greater than this become 255, others 0.

    Raises:
        ValueError: If ``mask`` has fewer than two dimensions.
    """
    scores = np.asarray(mask)
    if scores.ndim < 2:
        raise ValueError(f"mask must be at least 2-D, got shape {scores.shape}")
    # A 2-D mask is already an image plane; otherwise keep the original
    # behaviour of saving the first entry along the leading axis.
    plane = scores if scores.ndim == 2 else scores[0]
    binary = (plane > threshold).astype(np.uint8) * 255
    Image.fromarray(binary).save(save_path)
|
|
|
if __name__ == "__main__":
    # Command-line entry point: segment a single image and write the mask.
    # Every option defaults to the script's previous hard-coded value, so
    # running it without arguments behaves exactly as before.
    import argparse

    parser = argparse.ArgumentParser(
        description="Run UNet inference on one image and save the binary mask."
    )
    parser.add_argument(
        "--weights",
        default="unet_model.pth",
        help="path to the saved UNet state_dict",
    )
    parser.add_argument(
        "--image",
        default="DUTS-TE-Image/ILSVRC2012_test_00000003.jpg",
        help="input image to segment",
    )
    # NOTE(review): the default keeps the original .jpg name, but JPEG is
    # lossy and will blur a 0/255 mask; pass e.g. --output mask.png instead.
    parser.add_argument(
        "--output",
        default="predicted_mask.jpg",
        help="where to write the predicted mask",
    )
    args = parser.parse_args()

    weights_path = args.weights
    model = load_model(weights_path, device)
    image_tensor = preprocess_image(args.image)
    mask = predict(model, image_tensor, device)
    save_output(mask, args.output)
|
|