File size: 3,775 Bytes
3fe1151
 
 
 
 
 
 
 
 
7fd17d1
3fe1151
0c0d56d
 
 
 
 
3fe1151
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c75812b
0c0d56d
3fe1151
 
 
 
 
 
 
 
 
 
 
 
 
8830156
 
 
 
 
 
 
 
 
 
 
 
3fe1151
 
 
 
 
 
 
0c0d56d
3fe1151
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
#!/usr/bin/env python

from __future__ import annotations

import argparse
import pathlib

import gradio as gr

from model import Model

# Markdown rendered as the first element of the demo page (see main()).
DESCRIPTION = '''# CBNetV2

This is an unofficial demo for [https://github.com/VDIGPKU/CBNetV2](https://github.com/VDIGPKU/CBNetV2).'''
# Raw HTML passed to gr.Markdown at the bottom of the page, below the example
# gallery; it embeds a visitor-badge image hosted by a third-party service.
# NOTE(review): depends on visitor-badge.glitch.me being reachable — confirm
# the service is still up, otherwise the footer renders a broken image.
FOOTER = '<img id="visitor-badge" alt="visitor badge" src="https://visitor-badge.glitch.me/badge?page_id=hysts.cbnetv2" />'


def parse_args() -> argparse.Namespace:
    """Read the demo's command-line options from ``sys.argv``.

    Recognised options:
        --device         device string given to ``Model`` (default ``'cpu'``).
        --theme          optional theme name forwarded to ``gr.Blocks``.
        --share          flag; forwarded as ``share`` to ``demo.launch``.
        --port           optional int; forwarded as ``server_port``.
        --disable-queue  flag; stores ``False`` into ``enable_queue``
                         (which is ``True`` when the flag is absent).

    Returns:
        The parsed ``argparse.Namespace``.
    """
    cli = argparse.ArgumentParser()
    add = cli.add_argument
    add('--device', default='cpu', type=str)
    add('--theme', type=str)
    add('--share', action='store_true')
    add('--port', type=int)
    # The destination name is inverted relative to the flag so callers can
    # read ``args.enable_queue`` directly.
    add('--disable-queue', action='store_false', dest='enable_queue')
    options = cli.parse_args()
    return options


def set_example_image(example: list) -> dict:
    """Load a clicked gallery sample into the input image component.

    Args:
        example: the sample row delivered by the ``gr.Dataset`` click event.
            The dataset is built with a single component, so element 0 is the
            image path string.

    Returns:
        The ``gr.Image.update`` dict that sets the image's value to that path.
    """
    image_path = example[0]
    image_update = gr.Image.update(value=image_path)
    return image_update


def _build_interface(model: Model, theme: str | None) -> gr.Blocks:
    """Assemble the CBNetV2 demo UI around ``model`` and connect its events.

    Args:
        model: detector wrapper. The keys of ``model.models`` populate the
            detector dropdown, and ``model.model_name`` is the initial choice.
        theme: optional theme forwarded to ``gr.Blocks``.

    Returns:
        The constructed ``gr.Blocks`` app; it has not been launched yet.
    """
    with gr.Blocks(theme=theme, css='style.css') as demo:
        gr.Markdown(DESCRIPTION)

        with gr.Row():
            # Left column: input image, detector picker, Detect button.
            with gr.Column():
                with gr.Row():
                    input_image = gr.Image(label='Input Image', type='numpy')
                with gr.Row():
                    detector_name = gr.Dropdown(
                        list(model.models.keys()),
                        value=model.model_name,
                        label='Detector',
                    )
                with gr.Row():
                    detect_button = gr.Button(value='Detect')
                    # Holds the first output of detect_and_visualize so the
                    # Redraw handler can pass it to visualize_detection_results.
                    detection_results = gr.Variable()
            # Right column: rendered result, score threshold, Redraw button.
            with gr.Column():
                with gr.Row():
                    detection_visualization = gr.Image(
                        label='Detection Result', type='numpy')
                with gr.Row():
                    visualization_score_threshold = gr.Slider(
                        0, 1,
                        step=0.05,
                        value=0.3,
                        label='Visualization Score Threshold',
                    )
                with gr.Row():
                    redraw_button = gr.Button(value='Redraw')

        with gr.Row():
            # Every *.jpg found recursively under ./images (relative to the
            # working directory), sorted, one single-column sample row each.
            example_paths = sorted(pathlib.Path('images').rglob('*.jpg'))
            example_images = gr.Dataset(
                components=[input_image],
                samples=[[p.as_posix()] for p in example_paths],
            )

        gr.Markdown(FOOTER)

        # Dropdown change: forward the selected name to model.set_model_name.
        detector_name.change(fn=model.set_model_name,
                             inputs=[detector_name],
                             outputs=None)
        # Detect: produce both the raw results and their visualization.
        detect_button.click(
            fn=model.detect_and_visualize,
            inputs=[input_image, visualization_score_threshold],
            outputs=[detection_results, detection_visualization],
        )
        # Redraw: render the stored results again using the current threshold.
        redraw_button.click(
            fn=model.visualize_detection_results,
            inputs=[input_image, detection_results,
                    visualization_score_threshold],
            outputs=[detection_visualization],
        )
        # Gallery click: copy the chosen sample into the input image.
        example_images.click(fn=set_example_image,
                             inputs=[example_images],
                             outputs=[input_image])

    return demo


def main():
    """Run the demo: parse CLI options, build the UI, and launch the server."""
    args = parse_args()
    model = Model(args.device)
    demo = _build_interface(model, args.theme)
    demo.launch(
        enable_queue=args.enable_queue,
        server_port=args.port,
        share=args.share,
    )


if __name__ == '__main__':
    main()