File size: 5,736 Bytes
ef1c94f
79d80e3
c6145cf
84805b3
c639a10
 
72851a0
c6145cf
 
 
 
ef1c94f
84805b3
 
79d80e3
 
ef1c94f
79d80e3
 
 
 
 
ef1c94f
 
c6145cf
84805b3
c6145cf
 
5bff35b
84805b3
5bff35b
84805b3
 
c6145cf
 
24860f2
 
 
 
 
c6145cf
24860f2
 
84805b3
c6145cf
84805b3
 
 
 
 
 
24860f2
ef1c94f
 
24860f2
ef1c94f
 
 
5bff35b
24860f2
79d80e3
c6145cf
84805b3
c6145cf
79d80e3
84805b3
ef1c94f
84805b3
5bff35b
84805b3
79d80e3
84805b3
 
 
 
79d80e3
 
 
c639a10
 
 
9a5c94c
c639a10
79d80e3
 
5bff35b
c639a10
79d80e3
 
 
 
 
 
c6145cf
b23d98b
79d80e3
c639a10
84805b3
 
79d80e3
 
84805b3
 
 
 
c639a10
84805b3
 
 
ec35ab9
84805b3
c639a10
 
 
 
 
 
 
 
 
 
 
 
ef1c94f
79d80e3
 
84805b3
 
79d80e3
ef1c94f
 
79d80e3
 
c639a10
79d80e3
ef1c94f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import gradio as gr 
from run_on_video.run import MomentDETRPredictor
import torch 
from lbhd.infer import lbhd_predict
import os
import subprocess
from utils.export_utils import trim_video

# Markdown blurb rendered under the page title (paper + author links).
DESCRIPTION = """
_This Space demonstrates model [QVHighlights: Detecting Moments and Highlights in Videos via Natural Language Queries](https://arxiv.org/abs/2107.09609), NeurIPS 2021, by [Jie Lei](http://www.cs.unc.edu/~jielei/), [Tamara L. Berg](http://tamaraberg.com/), [Mohit Bansal](http://www.cs.unc.edu/~mbansal/)_
"""

# Prefer GPU when available; falls back to CPU inference.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Pretrained MomentDETR checkpoint and the CLIP backbone used for video/text features.
ckpt_path = "run_on_video/moment_detr_ckpt/model_best.ckpt"
clip_model_name_or_path = "ViT-B/32"

# Single predictor instance, built once at startup and shared by all requests.
moment_detr_predictor = MomentDETRPredictor(
    ckpt_path=ckpt_path,
    clip_model_name_or_path=clip_model_name_or_path,
    device=device
)


def display_prediction(result):
    """Render a (start, end, score) window as a Markdown heading line.

    `result` is an indexable triple: start time (s), end time (s), score.
    """
    start, end, score = result[0], result[1], result[2]
    return f'### Start time:  {start:.2f},   End time:  {end:.2f},  Score:  {score:.2f}'

with gr.Blocks(theme=gr.themes.Default()) as demo:
    # Per-session state holding the last submission's results so the
    # "Top 10" radio can switch clips without re-running inference.
    output_videos = gr.State(None)        # list of 10 MomentDETR clip paths
    output_lbhd_videos = gr.State(None)   # list of up to 10 LBHD clip paths
    moment_prediction = gr.State(None)    # MomentDETR (start, end, score) windows
    our_prediction = gr.State(None)       # LBHD (start, end, score) windows

    gr.HTML("""<h2 align="center"> 🎞️  Highlight Detection with MomentDETR </h2>""")
    gr.Markdown(DESCRIPTION)
    with gr.Column():
        # Top row: input video next to the two models' result players.
        with gr.Row():
            with gr.Blocks():
                with gr.Column():
                    gr.HTML("""<h3 align="center"> Input Video </h3>""")
                    input_video = gr.Video(label="Please input mp4", height=400)
            with gr.Blocks():
                with gr.Column(): 
                    gr.HTML("""<h3 align="center"> MomentDETR Result </h3>""")
                    playable_video = gr.Video(height=400)
                    display_score = gr.Markdown("### Start time, End time, Score")
            with gr.Blocks():
                with gr.Column(): 
                    gr.HTML("""<h3 align="center"> Ours Result </h3>""")
                    our_result_video = gr.Video(height=400)
                    display_clip_score = gr.Markdown("### Start time, End time, Score")
        # Bottom row: query text + submit on the left, rank selector on the right.
        with gr.Row():
            with gr.Column():
                retrieval_text = gr.Textbox(
                    label="Query text", 
                    placeholder="What should be highlighted?",
                    visible=True
                )
                submit = gr.Button("Submit")
            with gr.Column():
                # 1-based rank of the prediction to display (reset to 1 on submit).
                radio_button = gr.Radio(
                    choices=[i+1 for i in range(10)], 
                    label="Top 10", 
                    value=1
                )

        def update_video_player(radio_value, output_videos, output_lbhd_videos, moment_prediction, our_prediction):
            """Switch the two result players/scores to the rank picked on the radio.

            radio_value is 1-based; the state lists are the cached results of the
            last submit. Returns a dict mapping output components to new values,
            or a list of four Nones when nothing has been submitted yet.
            """
            if output_videos is None or moment_prediction is None:
                return [None, None, None, None]
            idx = radio_value - 1
            # The LBHD list can be shorter than 10 — or empty. Clamp the index;
            # the original computed min(idx, -1) on an empty list and raised
            # IndexError when indexing it.
            if output_lbhd_videos:
                lbhd_idx = min(idx, len(output_lbhd_videos) - 1)
                lbhd_video = output_lbhd_videos[lbhd_idx]
                lbhd_score = display_prediction(our_prediction[lbhd_idx])
            else:
                lbhd_video = None
                lbhd_score = "### Start time, End time, Score"
            return {
                playable_video: output_videos[idx],
                our_result_video: lbhd_video,
                display_score: display_prediction(moment_prediction[idx]),
                display_clip_score: lbhd_score
            }
                
    def submit_video(input_video, retrieval_text):
        """Run MomentDETR and LBHD on the uploaded video and trim the top clips.

        Returns the 10 values wired to [input_video, output_videos,
        output_lbhd_videos, moment_prediction, our_prediction, playable_video,
        our_result_video, display_score, display_clip_score, radio_button].
        """
        # Guard first: the original called os.path.splitext(input_video) before
        # this check, so a missing upload crashed with TypeError instead of
        # clearing the outputs.
        if input_video is None:
            return [None, None, None, None, None, None, None, None, None, 1]
        if retrieval_text is None:
            retrieval_text = ''

        # .mov uploads are transcoded to a scaled mp4 so downstream tools can
        # read them. splitext-based rename also handles uppercase ".MOV"
        # (the original str.replace(".mov", ...) missed it).
        ext = os.path.splitext(input_video)[-1].lower()
        if ext == ".mov":
            output_file = os.path.splitext(input_video)[0] + ".mp4"
            subprocess.call(['ffmpeg', '-i', input_video, "-vf", "scale=320:-2", output_file])
        # Path actually used for playback/trimming from here on.
        source_video = output_file if ext == ".mov" else input_video

        print(f'== video path: {input_video}')
        print(f'== retrieval_text: {retrieval_text}')

        predictions, video_frames = moment_detr_predictor.localize_moment(
            video_path=input_video, 
            query_list=[retrieval_text]
        )
        # MomentDETR returns (start, end, score) windows, best-first.
        predictions = predictions[0]['pred_relevant_windows']
        print(f'== Moment prediction: {predictions}')
        output_files = [trim_video(
            video_path=source_video, 
            start=predictions[i][0], 
            end=predictions[i][1],
            output_file=f'{i}.mp4'
        ) for i in range(10)]

        lbhd_predictions = lbhd_predict(input_video)
        print(f'== lbhd_predictions: {lbhd_predictions}')
        # LBHD may return fewer than 10 windows (possibly none).
        output_files_lbhd = [trim_video(
            video_path=source_video, 
            start=lbhd_predictions[i][0], 
            end=lbhd_predictions[i][1],
            output_file=f'{i}_lbhd.mp4'
        ) for i in range(min(10, len(lbhd_predictions)))]

        return [
            source_video,
            output_files,
            output_files_lbhd,
            predictions,
            lbhd_predictions,
            output_files[0],
            # Empty LBHD results previously raised IndexError here.
            output_files_lbhd[0] if output_files_lbhd else None,
            display_prediction(predictions[0]),
            display_prediction(lbhd_predictions[0]) if lbhd_predictions else "### Start time, End time, Score",
            1
        ]

    # Changing the rank selector swaps the displayed clips/scores from the
    # cached per-session state — no model re-run.
    radio_button.change(
        fn=update_video_player, 
        inputs=[radio_button, output_videos, output_lbhd_videos, moment_prediction, our_prediction],
        outputs=[playable_video, our_result_video, display_score, display_clip_score]
    )

    # Submit runs both models, fills the state lists, shows the top-1 clips,
    # and resets the radio to rank 1.
    submit.click(
        fn=submit_video, 
        inputs=[input_video, retrieval_text], 
        outputs=[input_video, output_videos, output_lbhd_videos, moment_prediction, our_prediction, playable_video, our_result_video, display_score, display_clip_score, radio_button]
    )

demo.launch()