# Copyright 2021 Tencent
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import functools
import os
import random
import warnings

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch

from model import SASNet

warnings.filterwarnings('ignore')
# define the GPU id to be used
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# Passed to SASNet as ``log_para``; the summed density map is divided by the
# same value to obtain the head count (both were the literal 1000 before).
LOG_PARA = 1000
# Checkpoint loaded into SASNet on first use.
MODEL_PATH = "SHHA.pth"
# Per-channel mean/std used by the original Normalize transform.
_MEAN = (0.485, 0.456, 0.406)
_STD = (0.229, 0.224, 0.225)


@functools.lru_cache(maxsize=1)
def _load_model():
    """Build SASNet, load its weights once, and return ``(model, device)``.

    Cached so repeated button clicks reuse the same model instead of
    re-reading the checkpoint from disk on every request.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # batch_size / block_size values are kept exactly as in the original demo.
    model = SASNet(batch_size=4, log_para=LOG_PARA, block_size=32).to(device)
    # map_location lets a GPU-saved checkpoint load on a CPU-only host.
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    model.eval()
    print('successfully load model from', MODEL_PATH)
    return model, device


def _preprocess(img, device):
    """Turn a PIL image into a normalised ``(1, 3, H, W)`` float tensor.

    Reproduces torchvision's ``ToTensor()`` followed by
    ``Normalize(mean=_MEAN, std=_STD)``: RGB values are scaled to [0, 1],
    moved to channel-first layout, then standardised per channel. A leading
    batch axis is added because the model output is indexed per image.
    """
    # New array from the division, so it is writable for torch.from_numpy.
    rgb = np.asarray(img.convert('RGB'), dtype=np.float32) / 255.0  # (H, W, 3)
    tensor = torch.from_numpy(rgb).permute(2, 0, 1)  # (3, H, W)
    mean = torch.tensor(_MEAN, dtype=torch.float32).view(3, 1, 1)
    std = torch.tensor(_STD, dtype=torch.float32).view(3, 1, 1)
    tensor = (tensor - mean) / std
    return tensor.unsqueeze(0).to(device)


def predict(img):
    """Estimate the crowd count of one image and plot its density map.

    Args:
        img: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user clicked without providing an image.

    Returns:
        ``(count, fig)`` where ``count`` is the estimated number of people
        (a Python float, accepted by ``gr.Label``) and ``fig`` is a
        matplotlib figure rendering the predicted density map.

    Raises:
        gr.Error: if no image was supplied.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")

    model, device = _load_model()
    with torch.no_grad():
        pred_map = model(_preprocess(img, device)).cpu().numpy()

    # The batch holds exactly one image, so take element 0.
    den_map = np.squeeze(pred_map[0])
    pred_cnt = float(np.sum(den_map)) / LOG_PARA

    fig = plt.figure(frameon=False)
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
    fig.add_axes(ax)
    ax.imshow(den_map, aspect='auto')
    # Drop the figure from pyplot's global registry so figures do not pile up
    # across requests; the returned Figure object remains renderable.
    plt.close(fig)
    return pred_cnt, fig


with gr.Blocks() as demo:
    gr.Markdown("""
# Crowd Counting based on SASNet

We implemented an image crowd counting model with VGG16 following the paper of Song et al. (2021).

## References

Song, Q., Wang, C., Wang, Y., Tai, Y., Wang, C., Li, J., … Ma, J. (2021). To Choose or to Fuse? Scale Selection for Crowd Counting. The Thirty-Fifth AAAI Conference on Artificial Intelligence (AAAI-21).
""")
    image_button = gr.Button("Count the Crowd!")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil")
            gr.Examples(["IMG_1.jpg", "IMG_2.jpg", "IMG_3.jpg"], image_input)
        with gr.Column():
            text_output = gr.Label()
            image_output = gr.Plot()
    image_button.click(
        predict, inputs=image_input, outputs=[text_output, image_output]
    )

demo.launch()