File size: 2,505 Bytes
7febe9c
b416927
 
 
 
 
 
 
7febe9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import os
import subprocess
import sys
import warnings


# Install the pinned runtime dependencies before they are imported below.
# Invoke pip through the running interpreter (``sys.executable -m pip``) so the
# packages land in the environment actually executing this script; a bare
# ``pip3`` on PATH may belong to a different Python. No shell is involved.
# As with the original os.system calls, a failed install is not fatal here:
# a genuinely missing package surfaces at the import that needs it.
for _requirement in ('torch==1.13.1', 'torchvision==0.14.1', 'opencv-python'):
    subprocess.run([sys.executable, '-m', 'pip', 'install', _requirement])


from glob import glob
import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torch
from torchvision import transforms
import gradio as gr

from models.GCoNet import GCoNet


# Device used for the model and its inputs. Fixed to the CPU; change it to
# 'cuda' to run inference on a GPU.
device = 'cpu'


class ImagePreprocessor:
    """Turn an input image into a normalised tensor for the model.

    The image is resized to ``size``, converted to a float tensor by
    ``transforms.ToTensor`` and normalised channel-wise with ``mean`` / ``std``.
    The defaults are the ImageNet statistics, listed in RGB channel order.
    """

    def __init__(
        self,
        size=(256, 256),
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ) -> None:
        """Build the transform pipeline.

        Args:
            size: Target size handed to ``transforms.Resize`` (a ``(h, w)``
                pair, default ``(256, 256)``).
            mean: Per-channel means handed to ``transforms.Normalize``.
            std: Per-channel standard deviations handed to
                ``transforms.Normalize``.
        """
        self.transform_image = transforms.Compose([
            transforms.Resize(size),
            transforms.ToTensor(),
            transforms.Normalize(list(mean), list(std)),
        ])

    def proc(self, image):
        """Return ``image`` resized, converted to a tensor and normalised.

        ``pred_maps`` calls this with PIL images.
        """
        return self.transform_image(image)


model = GCoNet(bb_pretrained=False).to(device)
# NOTE: despite its name, ``state_dict`` holds the *path* of the checkpoint file.
state_dict = './ultimate_duts_cocoseg (The best one).pth'
if os.path.exists(state_dict):
    # torch.load unpickles the file, so only ever point this at a trusted
    # checkpoint.
    gconet_dict = torch.load(state_dict, map_location=device)
    model.load_state_dict(gconet_dict)
else:
    # Keep starting the app, but say so loudly: without the checkpoint the model
    # keeps its freshly initialised weights (and, presumably, an un-pretrained
    # backbone given bb_pretrained=False), so its maps are unlikely to be useful.
    warnings.warn(
        f'Checkpoint {state_dict!r} not found; GCoNet will run with untrained weights.'
    )
model.eval()


def pred_maps(dr):
    """Predict co-saliency maps for every readable image in the directory ``dr``.

    All images found in the folder form one group and go through ``model`` in a
    single forward pass.

    Args:
        dr: Path of a directory containing the images of one group.

    Returns:
        A list with one ``uint8`` RGB array per image (in sorted file-name
        order): the input image and its predicted map, rendered as a gray RGB
        image, placed side by side.

    Raises:
        ValueError: If ``dr`` contains no file that OpenCV can decode.
    """
    # glob() returns entries in arbitrary order; sort for reproducible output.
    image_paths = sorted(glob(os.path.join(dr, '*')))
    # cv2.imread returns None for anything it cannot decode (sub-directories,
    # non-image files). Skip those instead of crashing on ``None.shape``.
    images_bgr = [cv2.imread(image_path) for image_path in image_paths]
    # OpenCV decodes to BGR. Convert to RGB so the PIL images, the RGB-ordered
    # mean/std normalisation and the model all see the expected channel order.
    images_rgb = [
        cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        for image in images_bgr
        if image is not None
    ]
    if not images_rgb:
        raise ValueError(f'No readable images found in {dr!r}')

    image_shapes = [image.shape[:2] for image in images_rgb]
    images = [Image.fromarray(image) for image in images_rgb]

    image_preprocessor = ImagePreprocessor()
    images_proc = torch.stack([image_preprocessor.proc(image) for image in images])

    with torch.no_grad():
        scaled_preds_tensor = model(images_proc.to(device))[-1]

    image_preds = []
    for image_rgb, image_shape, pred_tensor in zip(images_rgb, image_shapes, scaled_preds_tensor):
        # Resize the prediction back to the original image size.
        # ``.cpu()`` is a no-op for tensors that already live on the CPU.
        pred = torch.nn.functional.interpolate(
            pred_tensor.cpu().unsqueeze(0),
            size=image_shape,
            mode='bilinear',
            align_corners=True,
        ).squeeze().numpy()
        # Clip before the uint8 cast so values outside [0, 1] saturate instead
        # of wrapping around.
        mask = (np.clip(pred, 0, 1) * 255).astype(np.uint8)
        image_preds.append(np.hstack([image_rgb, cv2.cvtColor(mask, cv2.COLOR_GRAY2RGB)]))
    return image_preds

# Number of image slots shown in the UI. Gradio expects the handler to return
# one value per declared output component.
NUM_OUTPUT_IMAGES = 5


def _pred_maps_for_ui(dr):
    """Run ``pred_maps`` and fit its results to the fixed number of UI slots.

    ``pred_maps`` returns one image per file in ``dr``. Surplus results are
    dropped and missing ones are filled with ``None`` (an empty image slot), so
    folders with other than ``NUM_OUTPUT_IMAGES`` images still display.
    """
    results = list(pred_maps(dr))[:NUM_OUTPUT_IMAGES]
    results += [None] * (NUM_OUTPUT_IMAGES - len(results))
    return results


demo = gr.Interface(
    fn=_pred_maps_for_ui,
    inputs='text',
    outputs=['image'] * NUM_OUTPUT_IMAGES,
    css=".output_image, .input_image {height: 300px !important}",
)

if __name__ == '__main__':
    demo.launch(debug=True)