"""Run a trained UNet on a single image and save the thresholded mask."""

import argparse

import numpy as np
import torch
from PIL import Image

from data import transform_img
from unet import UNet

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_model(weights_path, device):
    """Build a UNet, load saved weights into it, and prepare it for inference.

    Args:
        weights_path: Path to a file written by ``torch.save`` that holds a
            ``state_dict`` for ``UNet(in_channels=3, out_channels=1)``.
        device: Device the weights are mapped to and the model is moved to.

    Returns:
        The model on *device*, switched to eval mode.
    """
    model = UNet(in_channels=3, out_channels=1)
    # NOTE(security): torch.load unpickles the file, so only load checkpoints
    # from trusted sources. Recent PyTorch versions accept weights_only=True
    # for plain state_dicts.
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    # Inference mode for layers such as dropout/batch-norm, if the UNet has any.
    model.eval()
    return model


def preprocess_image(image_path):
    """Load an image as RGB and turn it into a batched tensor.

    Args:
        image_path: Path to the input image.

    Returns:
        The output of ``transform_img()`` applied to the RGB image, with a
        leading batch dimension of size 1 added. Shape and normalization depend
        on ``data.transform_img``; presumably ``(1, 3, H, W)`` — TODO confirm.
    """
    transform = transform_img()
    # Close the underlying file handle. convert() returns a new image that
    # stays valid after the source file is closed.
    with Image.open(image_path) as img:
        rgb = img.convert("RGB")
    return transform(rgb).unsqueeze(0)


def predict(model, image_tensor, device):
    """Run *model* on *image_tensor* and return per-pixel probabilities.

    The raw model output is passed through a sigmoid, so it is treated as
    logits. The batch dimension is then dropped, and the result is returned
    as a NumPy array on the CPU.

    Args:
        model: Model returned by ``load_model``.
        image_tensor: Batched input, as produced by ``preprocess_image``.
        device: Device to run the forward pass on.

    Returns:
        ``numpy.ndarray`` of probabilities in [0, 1]. For a single-channel
        model this is presumably ``(1, H, W)`` — TODO confirm.
    """
    with torch.no_grad():
        logits = model(image_tensor.to(device))
        probs = torch.sigmoid(logits)
    return probs.squeeze(0).cpu().numpy()


def save_output(mask, save_path, threshold=0.5):
    """Binarize *mask* and save it as an 8-bit grayscale image.

    Args:
        mask: Probability map, either ``(C, H, W)`` (the first channel is
            used) or already ``(H, W)``.
        save_path: Output path. Pillow infers the format from the extension.
            JPEG is lossy, so a ``.jpg`` mask will not be strictly 0/255;
            ``.png`` keeps it exact.
        threshold: Values strictly greater than this become 255; all others
            become 0.
    """
    mask = np.asarray(mask)
    if mask.ndim == 3:
        # Drop the leading channel axis (the model has out_channels=1).
        mask = mask[0]
    binary = (mask > threshold).astype(np.uint8) * 255
    Image.fromarray(binary).save(save_path)


def main(argv=None):
    """Command-line entry point.

    With no arguments, this reproduces the original hard-coded run.
    """
    parser = argparse.ArgumentParser(
        description="Predict a binary mask for one image with a trained UNet."
    )
    parser.add_argument(
        "--weights", default="unet_model.pth", help="path to the model state_dict"
    )
    parser.add_argument(
        "--image",
        default="DUTS-TE-Image/ILSVRC2012_test_00000003.jpg",
        help="input image path",
    )
    parser.add_argument(
        "--output",
        default="predicted_mask.jpg",
        help="where to write the mask (.png avoids JPEG artifacts)",
    )
    parser.add_argument(
        "--threshold", type=float, default=0.5, help="probability cut-off"
    )
    args = parser.parse_args(argv)

    model = load_model(args.weights, device)
    image_tensor = preprocess_image(args.image)
    mask = predict(model, image_tensor, device)
    save_output(mask, args.output, threshold=args.threshold)


if __name__ == "__main__":
    main()