import os
import random
import warnings

import numpy as np
import torch
import torchvision.transforms as standard_transforms  # needed by the preprocessing below
import matplotlib.pyplot as plt
import gradio as gr

from model import SASNet

warnings.filterwarnings('ignore')

# Run inference on the first visible GPU; the script assumes a CUDA device.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

def predict(img):
    """Run SASNet inference on a single PIL image: return the estimated
    crowd count and a matplotlib figure of the predicted density map."""
    # Hyperparameters follow the SHHA (ShanghaiTech Part A) configuration;
    # the pretrained checkpoint SHHA.pth must sit in the working directory.
    model = SASNet(batch_size=4, log_para=1000, block_size=32).cuda()
    model_path = "SHHA.pth"

    model.load_state_dict(torch.load(model_path))
    print('successfully loaded model from', model_path)

    model.eval()
    with torch.no_grad():
        # Preprocess: force RGB, normalize with ImageNet statistics, and add
        # a batch dimension so the tensor shape is (1, 3, H, W).
        img = img.convert('RGB')
        transform = standard_transforms.Compose([
            standard_transforms.ToTensor(),
            standard_transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                          std=[0.229, 0.224, 0.225]),
        ])
        img = transform(img).unsqueeze(0)

        img = img.cuda()
        pred_map = model(img)

    pred_map = pred_map.data.cpu().numpy()
    # The network predicts density scaled by log_para (1000), so the head
    # count is the sum of the density map divided by that factor.
    pred_cnt = np.sum(pred_map[0]) / 1000

    # Render the density map as a borderless figure for the gr.Plot output.
    den_map = np.squeeze(pred_map[0])
    fig = plt.figure(frameon=False)
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
    fig.add_axes(ax)
    ax.imshow(den_map, aspect='auto')
    return pred_cnt, fig
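
# Portability note (an assumption, not part of the original script): on a
# machine without CUDA the .cuda() calls above will raise. A minimal fallback
# sketch using standard PyTorch APIs would be:
#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#     model = SASNet(batch_size=4, log_para=1000, block_size=32).to(device)
#     model.load_state_dict(torch.load(model_path, map_location=device))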


with gr.Blocks() as demo:
    gr.Markdown("""
# Crowd Counting based on SASNet

We implemented an image crowd counting model with a VGG16 backbone, following Song et al. (2021).

## References
Song, Q., Wang, C., Wang, Y., Tai, Y., Wang, C., Li, J., … Ma, J. (2021). To Choose or to Fuse? Scale Selection for Crowd Counting. The Thirty-Fifth AAAI Conference on Artificial Intelligence (AAAI-21).
""")
    image_button = gr.Button("Count the Crowd!")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil")
            gr.Examples(["IMG_1.jpg", "IMG_2.jpg", "IMG_3.jpg"], image_input)
        with gr.Column():
            text_output = gr.Label()
            image_output = gr.Plot()

    image_button.click(predict, inputs=image_input, outputs=[text_output, image_output])

demo.launch()
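# Optional Gradio settings (not used by the original demo): calling demo.queue()
# before launch() serializes GPU-bound requests, and demo.launch(share=True)
# exposes a temporary public URL when running locally.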