import gradio as gr
import torch
from moviepy.video.io.ffmpeg_tools import ffmpeg_extract_subclip

from run_on_video.run import MomentDETRPredictor

DESCRIPTION = """
_This Space demonstrates the model from [QVHighlights: Detecting Moments and Highlights in Videos via Natural Language Queries](https://arxiv.org/abs/2107.09609) (NeurIPS 2021) by [Jie Lei](http://www.cs.unc.edu/~jielei/), [Tamara L. Berg](http://tamaraberg.com/), and [Mohit Bansal](http://www.cs.unc.edu/~mbansal/)._
"""

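# Checkpoint of the pretrained Moment-DETR model and the CLIP backbone used for video/text features.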
ckpt_path = "run_on_video/moment_detr_ckpt/model_best.ckpt"
clip_model_name_or_path = "ViT-B/32"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

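# Load the predictor once at startup so every request reuses the same weights.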
moment_detr_predictor = MomentDETRPredictor(
    ckpt_path=ckpt_path,
    clip_model_name_or_path=clip_model_name_or_path,
    device=device
)

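# Cut the [start, end] window (in seconds) out of the source video with ffmpeg.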
def trim_video(video_path, start, end, output_file='result.mp4'):
    ffmpeg_extract_subclip(video_path, start, end, targetname=output_file)
    return output_file

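# Format one prediction [start, end, score] as a Markdown line.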
def display_prediction(result):
    return f'### Moment start time: {result[0]}, end time: {result[1]}, score: {result[2]}'

with gr.Blocks(theme=gr.themes.Default()) as demo:
    output_videos = gr.State(None)
    moment_prediction = gr.State(None)
    gr.HTML("""<h2 align="center"> 🎞️  Highlight Detection with MomentDETR </h2>""")
    gr.Markdown(DESCRIPTION)
    with gr.Column():
        with gr.Row():
            with gr.Group():
                with gr.Column():
                    gr.HTML("""<h3 align="center"> Input Video </h3>""")
                    input_video = gr.Video(label="Input video (mp4)", height=400)
            with gr.Group():
                with gr.Column():
                    gr.HTML("""<h3 align="center"> Highlight Videos </h3>""")
                    playable_video = gr.Video(height=400)
        with gr.Row():
            with gr.Column():
                retrieval_text = gr.Textbox(
                    label="Query text", 
                    placeholder="What should be highlighted?",
                    visible=True
                )
                submit = gr.Button("Submit")
            with gr.Column():
                radio_button = gr.Radio(
                    choices=list(range(1, 11)),
                    label="Moments",
                    value=1
                )
                display_score = gr.Markdown("### Moment Score: ")

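        # Swap the displayed clip and score when a different moment is selected.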
        def update_video_player(radio_value, output_videos, moment_prediction):
            if output_videos is None or moment_prediction is None:
                return {playable_video: None, display_score: None}
            # Guard against selecting a moment beyond the number of trimmed clips.
            if radio_value > len(output_videos):
                return {playable_video: None, display_score: None}
            return {
                playable_video: output_videos[radio_value - 1],
                display_score: display_prediction(moment_prediction[radio_value - 1])
            }
                
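    # Run Moment-DETR on the uploaded video and trim one clip per predicted window.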
    def submit_video(input_video, retrieval_text):
        print(f'== video path: {input_video}')
        print(f'== retrieval_text: {retrieval_text}')
        if input_video is None:
            return {
                output_videos: None,
                moment_prediction: None,
                playable_video: None,
                display_score: None,
                radio_button: 1
            }
        if retrieval_text is None:
            retrieval_text = ''
        predictions, video_frames = moment_detr_predictor.localize_moment(
            video_path=input_video,
            query_list=[retrieval_text]
        )
        predictions = predictions[0]['pred_relevant_windows']
        pred_windows = [[pred[0], pred[1]] for pred in predictions]
        # Trim at most 10 clips, matching the radio choices, even if fewer windows are predicted.
        output_files = [
            trim_video(
                video_path=input_video,
                start=pred_windows[i][0],
                end=pred_windows[i][1],
                output_file=f'{i}.mp4'
            )
            for i in range(min(10, len(pred_windows)))
        ]

        return {
            output_videos: output_files,
            moment_prediction: predictions,
            playable_video: output_files[0],
            display_score: display_prediction(predictions[0]),
            radio_button: 1
        }

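    # Selecting a different moment re-renders the player and the score line.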
    radio_button.change(
        fn=update_video_player, 
        inputs=[radio_button, output_videos, moment_prediction],
        outputs=[playable_video, display_score]
    )

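    # Submitting runs localization and resets the UI to the top-ranked moment.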
    submit.click(
        fn=submit_video, 
        inputs=[input_video, retrieval_text], 
        outputs=[output_videos, moment_prediction, playable_video, display_score, radio_button]
    )

demo.launch()