File size: 2,927 Bytes
994ce72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
494c823
994ce72
 
 
 
 
deb6583
 
994ce72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
deb6583
494c823
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
994ce72
 
494c823
994ce72
494c823
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# Copyright 2021 Tencent

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import functools
import os
import random
import warnings

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.figure import Figure

from model import SASNet

# Expose only GPU 0 to this process (only effective if set before CUDA
# is initialised).
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# Suppress every warning category process-wide.
warnings.filterwarnings('ignore')

# Density-map scale factor passed to SASNet as ``log_para``; the predicted
# head count is the density-map sum divided by this same value.
LOG_PARA = 1000
MODEL_PATH = "SHHA.pth"
# ImageNet channel statistics used to normalise the RGB input.
_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
# Prefer the GPU, but still work on CPU-only hosts.
_DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


@functools.lru_cache(maxsize=1)
def _load_model():
    """Build SASNet, load the trained weights and cache it for later calls."""
    # NOTE(review): batch_size/block_size kept from the original call; their
    # meaning is defined in model.SASNet — confirm batch_size=4 is compatible
    # with the single-image batches fed below.
    model = SASNet(batch_size=4, log_para=LOG_PARA, block_size=32).to(_DEVICE)
    # map_location lets a GPU-saved checkpoint load onto the current device.
    model.load_state_dict(torch.load(MODEL_PATH, map_location=_DEVICE))
    model.eval()
    print('successfully load model from', MODEL_PATH)
    return model


def _preprocess(img):
    """Convert a PIL image to a normalised float tensor of shape (1, 3, H, W).

    Equivalent to torchvision ``ToTensor()`` followed by ``Normalize`` with the
    ImageNet mean/std, plus a leading batch dimension of size 1.
    """
    arr = np.asarray(img.convert('RGB'), dtype=np.float32) / 255.0  # (H, W, 3)
    arr = (arr - _MEAN) / _STD
    chw = np.ascontiguousarray(arr.transpose(2, 0, 1))  # (3, H, W)
    return torch.from_numpy(chw).unsqueeze(0)


def predict(img):
    """Count the people in ``img`` and render the predicted density map.

    Args:
        img: PIL image from the Gradio image input, or ``None`` when nothing
            has been uploaded.

    Returns:
        ``(count, figure)``: the estimated head count as a Python float
        (density-map sum divided by ``LOG_PARA``) and a matplotlib ``Figure``
        showing the density map.

    Raises:
        ValueError: if no image was provided.
    """
    if img is None:
        raise ValueError("Please upload an image first.")

    model = _load_model()
    with torch.no_grad():
        pred_map = model(_preprocess(img).to(_DEVICE)).cpu().numpy()

    # Assumes the model output is batch-first; index 0 is our single image.
    pred_cnt = float(np.sum(pred_map[0])) / LOG_PARA
    den_map = np.squeeze(pred_map[0])

    # A standalone Figure (not pyplot) is not tracked by pyplot's global
    # figure registry, so repeated requests do not leak open figures.
    fig = Figure(frameon=False)
    ax = fig.add_axes([0., 0., 1., 1.])
    ax.set_axis_off()
    ax.imshow(den_map, aspect='auto')
    return (pred_cnt, fig)

# Gradio UI: description, a trigger button, an input column and an output
# column; the app is launched when this module runs.
with gr.Blocks() as demo:
    gr.Markdown("""
    # Crowd Counting based on SASNet

    We implemented an image crowd counting model with VGG16 following the paper of Song et al. (2021).

    ## References
    Song, Q., Wang, C., Wang, Y., Tai, Y., Wang, C., Li, J., … Ma, J. (2021). To Choose or to Fuse? Scale Selection for Crowd Counting. The Thirty-Fifth AAAI Conference on Artificial Intelligence (AAAI-21).
    """)
    image_button = gr.Button("Count the Crowd!")
    with gr.Row():
        # Left column: the uploaded image plus bundled example images.
        with gr.Column():
            image_input = gr.Image(type="pil")
            gr.Examples(["IMG_1.jpg", "IMG_2.jpg", "IMG_3.jpg"], image_input)
        # Right column: the predicted count and the density-map plot.
        with gr.Column():
            text_output = gr.Label()
            image_output = gr.Plot()

    # predict returns (count, figure), mapped onto the outputs in list order.
    image_button.click(predict, inputs=image_input, outputs=[text_output, image_output])

demo.launch()