# Scraped from a Hugging Face file view: uploaded by leuschnm,
# commit 994ce72 ("add model"), file size 2.93 kB.
# Copyright 2021 Tencent
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import os
import numpy as np
import torch
from model import SASNet
import warnings
import random
import matplotlib.pyplot as plt
import gradio as gr
# Silence every warning process-wide to keep the demo console quiet.
# NOTE(review): this also hides deprecation/runtime warnings from torch and
# gradio; consider narrowing the filter by category or module.
warnings.filterwarnings('ignore')
# Default to GPU 0, but respect a device selection the caller already made in
# the environment. CUDA reads this variable when it is first initialised, so it
# must be set before any CUDA call.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0')
# Scale factor the density map is trained with: the model is built with
# log_para=LOG_PARA, and the summed map is divided by it to recover a count.
LOG_PARA = 1000
MODEL_PATH = "SHHA.pth"
# ImageNet channel statistics (same values as the original Normalize call).
_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
# Fall back to CPU so the demo still runs on hosts without a GPU.
_DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Lazily built model, shared by all requests.
_MODEL = None


def _get_model():
    """Build SASNet and load its weights once, then reuse it across calls."""
    global _MODEL
    if _MODEL is None:
        model = SASNet(batch_size=4, log_para=LOG_PARA, block_size=32).to(_DEVICE)
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        model.load_state_dict(torch.load(MODEL_PATH, map_location=_DEVICE))
        model.eval()
        print('successfully load model from', MODEL_PATH)
        _MODEL = model
    return _MODEL


def _preprocess(img):
    """Turn a PIL image into a normalized (1, 3, H, W) float32 tensor.

    Equivalent to torchvision's ToTensor() followed by
    Normalize(_MEAN, _STD), plus a leading batch dimension.
    """
    rgb = np.asarray(img.convert('RGB'), dtype=np.float32) / 255.0  # (H, W, 3)
    rgb = (rgb - _MEAN) / _STD  # per-channel broadcast over the last axis
    chw = np.ascontiguousarray(rgb.transpose(2, 0, 1))
    return torch.from_numpy(chw).unsqueeze(0).to(_DEVICE)


def _render_density(den_map):
    """Draw *den_map* on a borderless figure whose axes fill the canvas."""
    # plt.Figure is not registered with pyplot's figure manager. Figures from
    # repeated requests are therefore garbage-collected instead of piling up.
    fig = plt.Figure(frameon=False)
    ax = fig.add_axes([0., 0., 1., 1.])
    ax.set_axis_off()
    ax.imshow(den_map, aspect='auto')
    return fig


def predict(img):
    """Count the people in *img* and render the predicted density map.

    Args:
        img: PIL image from the ``gr.Image(type="pil")`` input, or None when
            the user clicked the button without providing an image.

    Returns:
        ``(count, figure)``: the summed density map divided by LOG_PARA, and a
        matplotlib figure showing the density map.

    Raises:
        ValueError: if *img* is None.
    """
    if img is None:
        raise ValueError('Please provide an image to count.')
    model = _get_model()
    with torch.no_grad():
        pred_map = model(_preprocess(img)).cpu().numpy()
    # Index 0 selects the single image in the batch. Assumes SASNet returns a
    # batch-first density map (e.g. (1, 1, H', W')); TODO confirm.
    pred_cnt = np.sum(pred_map[0]) / LOG_PARA
    den_map = np.squeeze(pred_map[0])
    return (pred_cnt, _render_density(den_map))
# Gradio UI: the button, an image input with examples, and the count/density
# outputs, all wired to predict().
with gr.Blocks() as demo:
    gr.Markdown("""
# Crowd Counting based on SASNet
We implemented an image crowd counting model with VGG16 following the paper of Song et al. (2021).
## References
Song, Q., Wang, C., Wang, Y., Tai, Y., Wang, C., Li, J., … Ma, J. (2021). To Choose or to Fuse? Scale Selection for Crowd Counting. The Thirty-Fifth AAAI Conference on Artificial Intelligence (AAAI-21).
""")
    image_button = gr.Button("Count the Crowd!")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil")
            gr.Examples(["IMG_1.jpg", "IMG_2.jpg", "IMG_3.jpg"], image_input)
        with gr.Column():
            text_output = gr.Label()
            image_output = gr.Plot()
    # predict returns (count, figure), mapped onto (Label, Plot) in order.
    image_button.click(predict, inputs=image_input, outputs=[text_output, image_output])

# Start the server only when run as a script, not when the module is imported.
if __name__ == "__main__":
    demo.launch()