File size: 4,290 Bytes
e3876f0
b3660df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2fce180
 
 
b3660df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56a1b99
 
b3660df
 
 
 
 
 
 
 
 
 
 
a777a7d
b3660df
81071ed
b3660df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81071ed
d598fd0
81071ed
a74ea5d
b3660df
81071ed
3a9a338
b3660df
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
### app.py
# User interface for the demo.
###

import os, pdb
import pandas as pd
import gradio as gr
from gradio_rich_textbox import RichTextbox

from demo import VideoMIRModel


def load_v2t_samples(data_root):
    """Load the video-to-text example clips listed in the selection CSV.

    The sample ids are taken from the first column of
    ``meta/ek100_mir/sel_v2t.csv``. That file has no header row, and the path
    is resolved relative to the current working directory, not *data_root*.

    Args:
        data_root: Directory under which ``video/gif/<sid>.gif`` is expected.

    Returns:
        ``(sample_videos, idx2sid)``: the GIF path of each sample in CSV
        order, and a dict mapping each position to its sample id.
    """
    sample_ids = pd.read_csv("meta/ek100_mir/sel_v2t.csv", header=None)[0].values
    idx2sid = dict(enumerate(sample_ids))
    sample_videos = [f'{data_root}/video/gif/{sid}.gif' for sid in sample_ids]
    return sample_videos, idx2sid

def load_t2v_samples(data_root):
    """Return the text-to-video example queries and their sample ids.

    The queries are hard-coded. *data_root* is accepted for symmetry with
    ``load_v2t_samples`` but is not used.

    Returns:
        ``(sample_text, idx2sid)``: ``sample_text[i]`` is the i-th example
        query and ``idx2sid[i]`` is the sample id paired with that query.
    """
    # Each query is stored next to its sample id so the positional index
    # used by the UI can never point a query at the wrong id. Previously the
    # text list dropped 'stir vegetables into salmon' while the id map still
    # had three entries. 'rinse cutting board' (index 1) therefore resolved
    # to 1730, which in the original three-item list belonged to the removed
    # salmon query, instead of 1276.
    # The 'stir vegetables into salmon' example (sid 1730) is intentionally
    # left out of the demo.
    samples = [
        ('cut the sausage', 2119),
        ('rinse cutting board', 1276),
    ]
    sample_text = [text for text, _ in samples]
    idx2sid = {i: sid for i, (_, sid) in enumerate(samples)}
    return sample_text, idx2sid

def format_pred(pred, gt):
    """Render predictions as colour-tagged markup for a RichTextbox.

    Items of *pred* that also appear in *gt* are wrapped in a green
    ``[color=...]`` tag; all others are wrapped in a red one. The wrapped
    items are joined with ``', '``, so an empty *pred* yields ``''``.
    """
    hit = '[color=green]{}[/color]'
    miss = '[color=red]{}[/color]'
    return ', '.join((hit if item in gt else miss).format(item) for item in pred)

def main():
    """Build and launch the Gradio demo comparing LaViLa and Ego-VPA.

    Loads both retrieval models from their config files and loads the
    example queries for each direction. It then serves a two-tab
    ``gr.Blocks`` UI (video-to-text and text-to-video retrieval) through
    ``demo.launch``.
    """
    lavila = VideoMIRModel("configs/ek100_mir/zeroshot.yml")
    egovpa = VideoMIRModel("configs/ek100_mir/egovpa.yml")
    v2t_samples, idx2sid_v2t = load_v2t_samples('data/ek100_mir')
    t2v_samples, idx2sid_t2v = load_t2v_samples('data/ek100_mir')
    # Debug output: the list of example GIF paths.
    print(v2t_samples)

    def predict_v2t(idx):
        """Run video-to-text retrieval for example *idx* with both models.

        Returns the ground-truth actions and each model's predictions
        formatted by ``format_pred`` (green = in ground truth, red = not).
        """
        # NOTE(review): gr.Number without `precision` presumably delivers a
        # float here. 1.0 still compares equal to 1 and matches int dict keys.
        # NOTE(review): example 1 is silently redirected to example 2. The
        # reason is not visible in this file (a bad sample?), so confirm.
        if idx == 1:
            idx = 2
        sid = idx2sid_v2t[idx]
        zeroshot_action, gt_action = lavila.predict_v2t(idx, sid)
        # gt_action is overwritten by the second call. Presumably both models
        # return the same ground truth; verify in VideoMIRModel.
        egovpa_action, gt_action = egovpa.predict_v2t(idx, sid)
        zeroshot_action = format_pred(zeroshot_action, gt_action)
        egovpa_action = format_pred(egovpa_action, gt_action)
    
        return gt_action, zeroshot_action, egovpa_action

    def predict_t2v(idx):
        """Run text-to-video retrieval for example *idx* with Ego-VPA only.

        Each item returned by the model (presumably a sample id) is turned
        into a GIF path for display in a ``gr.Gallery``.
        """
        sid = idx2sid_t2v[idx]
        # gt_video is returned by the model but not used by the UI.
        egovpa_video, gt_video = egovpa.predict_t2v(idx, sid)
        egovpa_video = [f'data/ek100_mir/video/gif/{x}.gif' for x in egovpa_video]
    
        return egovpa_video

    with gr.Blocks() as demo:
        # Tab 1: a video query; both models' text predictions side by side.
        with gr.Tab("Video-to-text retrieval"):
            gr.Markdown(
                """
                # Ego-VPA Demo
                Choose a sample video and click predict to view the text queried by the selected video
                (<span style="color:green">correct</span>/<span style="color:red">incorrect</span>).
                """
            )

            with gr.Row():        
                with gr.Column():
                    video = gr.Image(label="video query", height='300px', interactive=False)
                with gr.Column():
                    # Hidden field: gr.Examples fills it with the example
                    # index, and the Predict button passes it to predict_v2t.
                    idx = gr.Number(label="Idx", visible=False)
                    # Ground truth is returned by predict_v2t but kept hidden.
                    label = RichTextbox(label="Ground Truth", visible=False)
                    zeroshot = RichTextbox(label="LaViLa (zero-shot) prediction")
                    ours = RichTextbox(label="Ego-VPA prediction")
            btn = gr.Button("Predict", variant="primary")
            btn.click(predict_v2t, inputs=[idx], outputs=[label, zeroshot, ours])
            # Each example row sets both the hidden index and the preview image.
            gr.Examples(examples=[[i, x] for i, x in enumerate(v2t_samples)], inputs=[idx, video])

        # Tab 2: a text query; Ego-VPA's retrieved videos only. `idx`, `btn`,
        # and `ours` are rebound below. The first tab's click handler was
        # already registered with its own components.
        with gr.Tab("Text-to-video retrieval"):
            gr.Markdown(
                """
                # Ego-VPA Demo
                Choose a sample narration and click predict to view the video queried by the selected text.
                """
            )

            with gr.Row():        
                with gr.Column():
                    text = gr.Text(label="text query")
                with gr.Column():
                    idx = gr.Number(label="Idx", visible=False)
                    # The LaViLa zero-shot output is disabled for this direction.
                    #zeroshot = gr.Textbox(label="LaViLa (zero-shot) prediction")
                    #zeroshot = gr.Gallery(label="LaViLa (zero-shot) prediction", columns=[3], rows=[1], object_fit="contain", height="auto")
                    #ours = gr.Textbox(label="Ego-VPA prediction")
                    ours = gr.Gallery(label="Ego-VPA prediction", columns=[1], rows=[1], object_fit="contain", height="auto")
            btn = gr.Button("Predict", variant="primary")
            btn.click(predict_t2v, inputs=[idx], outputs=[ours])
            gr.Examples(examples=[[i, x] for i, x in enumerate(t2v_samples)], inputs=[idx, text])



    # share=True asks Gradio to also create a public share link.
    demo.launch(share=True)


if __name__ == "__main__":
    # Build and launch the demo only when this file is executed directly.
    main()