|
|
|
|
|
from __future__ import annotations |
|
|
|
import argparse |
|
import pathlib |
|
|
|
import gradio as gr |
|
|
|
from model import Model |
|
|
|
DESCRIPTION = '''# CBNetV2 |
|
|
|
This is an unofficial demo for [https://github.com/VDIGPKU/CBNetV2](https://github.com/VDIGPKU/CBNetV2).''' |
|
FOOTER = '<img id="visitor-badge" alt="visitor badge" src="https://visitor-badge.glitch.me/badge?page_id=hysts.cbnetv2" />' |
|
|
|
|
|
def parse_args() -> argparse.Namespace: |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument('--device', type=str, default='cpu') |
|
parser.add_argument('--theme', type=str) |
|
parser.add_argument('--share', action='store_true') |
|
parser.add_argument('--port', type=int) |
|
parser.add_argument('--disable-queue', |
|
dest='enable_queue', |
|
action='store_false') |
|
return parser.parse_args() |
|
|
|
|
|
def set_example_image(example: list) -> dict: |
|
return gr.Image.update(value=example[0]) |
|
|
|
|
|
def main(): |
|
args = parse_args() |
|
model = Model(args.device) |
|
|
|
with gr.Blocks(theme=args.theme, css='style.css') as demo: |
|
gr.Markdown(DESCRIPTION) |
|
|
|
with gr.Row(): |
|
with gr.Column(): |
|
with gr.Row(): |
|
input_image = gr.Image(label='Input Image', type='numpy') |
|
with gr.Row(): |
|
detector_name = gr.Dropdown(list(model.models.keys()), |
|
value=model.model_name, |
|
label='Detector') |
|
with gr.Row(): |
|
detect_button = gr.Button(value='Detect') |
|
detection_results = gr.Variable() |
|
with gr.Column(): |
|
with gr.Row(): |
|
detection_visualization = gr.Image( |
|
label='Detection Result', type='numpy') |
|
with gr.Row(): |
|
visualization_score_threshold = gr.Slider( |
|
0, |
|
1, |
|
step=0.05, |
|
value=0.3, |
|
label='Visualization Score Threshold') |
|
with gr.Row(): |
|
redraw_button = gr.Button(value='Redraw') |
|
|
|
with gr.Row(): |
|
paths = sorted(pathlib.Path('images').rglob('*.jpg')) |
|
example_images = gr.Dataset(components=[input_image], |
|
samples=[[path.as_posix()] |
|
for path in paths]) |
|
|
|
gr.Markdown(FOOTER) |
|
|
|
detector_name.change(fn=model.set_model_name, |
|
inputs=[detector_name], |
|
outputs=None) |
|
detect_button.click(fn=model.detect_and_visualize, |
|
inputs=[ |
|
input_image, |
|
visualization_score_threshold, |
|
], |
|
outputs=[ |
|
detection_results, |
|
detection_visualization, |
|
]) |
|
redraw_button.click(fn=model.visualize_detection_results, |
|
inputs=[ |
|
input_image, |
|
detection_results, |
|
visualization_score_threshold, |
|
], |
|
outputs=[detection_visualization]) |
|
example_images.click(fn=set_example_image, |
|
inputs=[example_images], |
|
outputs=[input_image]) |
|
|
|
demo.launch( |
|
enable_queue=args.enable_queue, |
|
server_port=args.port, |
|
share=args.share, |
|
) |
|
|
|
|
|
if __name__ == '__main__': |
|
main() |
|
|