File size: 3,726 Bytes
0fb1163
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0bb375d
0fb1163
 
 
 
 
 
 
 
 
 
 
 
88226d6
fbe1110
88226d6
0fb1163
 
 
 
 
 
 
fbe1110
0fb1163
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import torch
import gradio as gr
from transformers import AutoProcessor, AutoModel
from utils import (
    convert_frames_to_gif,
    download_youtube_video,
    get_num_total_frames,
    sample_frames_from_video_file,
)

# Stride, in frames, between consecutive sampled frames when the clip is long
# enough; predict() shrinks it for clips that are too short.
FRAME_SAMPLING_RATE = 4

# Checkpoint loaded at startup and preselected in the model dropdown.
DEFAULT_MODEL = "microsoft/xclip-base-patch16-zero-shot"

# Checkpoints offered in the model dropdown (currently just the default).
VALID_ZEROSHOT_VIDEOCLASSIFICATION_MODELS = [DEFAULT_MODEL]

# Module-level processor/model pair read by predict(); select_model() rebinds
# both globals when the dropdown changes.
processor, model = (
    AutoProcessor.from_pretrained(DEFAULT_MODEL),
    AutoModel.from_pretrained(DEFAULT_MODEL),
)

def select_model(model_name):
    """Replace the module-level ``processor`` and ``model`` with ``model_name``.

    Wired to the model dropdown's ``change`` event. Both objects are loaded
    before either global is rebound, so a failed download or load leaves the
    previous (matching) processor/model pair in place instead of pairing a
    new processor with the old model.

    Args:
        model_name: A checkpoint identifier accepted by
            ``AutoProcessor.from_pretrained`` / ``AutoModel.from_pretrained``.
    """
    global processor, model
    new_processor = AutoProcessor.from_pretrained(model_name)
    new_model = AutoModel.from_pretrained(model_name)
    # Swap only after both loads succeeded.
    processor, model = new_processor, new_model


def predict(youtube_url_or_file_path, labels_text):
    """Score a video against comma-separated free-text labels.

    Args:
        youtube_url_or_file_path: A string starting with ``"http"`` (downloaded
            via ``download_youtube_video``) or a local video path, e.g. from
            the upload widget. ``None``/empty is rejected.
        labels_text: Comma-separated candidate labels. Whitespace around each
            label is stripped, empty entries are dropped, and duplicates are
            removed (first occurrence kept).

    Returns:
        ``(label_to_prob, gif_path)``: a dict mapping each label to its softmax
        probability over the given labels, and the path of a GIF built from
        the frames that were fed to the model.

    Raises:
        gr.Error: If no video is given or no non-empty label is supplied.
    """
    if not youtube_url_or_file_path:
        raise gr.Error("Please provide a YouTube URL or upload a video file.")

    # Validate labels before doing any expensive download/decoding work.
    # dict.fromkeys dedupes while preserving order; duplicates would otherwise
    # split the softmax mass and then collapse in the returned dict.
    labels = list(
        dict.fromkeys(
            part.strip() for part in (labels_text or "").split(",") if part.strip()
        )
    )
    if not labels:
        raise gr.Error("Please enter at least one label (comma-separated).")

    if youtube_url_or_file_path.startswith("http"):
        video_path = download_youtube_video(youtube_url_or_file_path)
    else:
        video_path = youtube_url_or_file_path

    num_total_frames = get_num_total_frames(video_path)
    num_model_input_frames = model.config.vision_config.num_frames
    # Use the fixed stride when the clip is long enough; otherwise shrink it
    # to about num_total_frames / num_model_input_frames. Clamp to 1 so clips
    # shorter than num_model_input_frames never yield a stride of 0.
    if num_total_frames < FRAME_SAMPLING_RATE * num_model_input_frames:
        frame_sampling_rate = max(1, num_total_frames // num_model_input_frames)
    else:
        frame_sampling_rate = FRAME_SAMPLING_RATE

    frames = sample_frames_from_video_file(
        video_path, num_model_input_frames, frame_sampling_rate
    )
    # NOTE(review): fixed output path is shared by every request; concurrent
    # predictions could overwrite each other's preview GIF.
    gif_path = convert_frames_to_gif(frames, save_path="video.gif")

    inputs = processor(
        text=labels, videos=list(frames), return_tensors="pt", padding=True
    )
    with torch.no_grad():
        outputs = model(**inputs)

    # [0] selects the first video in the batch; softmax normalizes its scores
    # across the provided labels only.
    probs = outputs.logits_per_video[0].softmax(dim=-1).cpu().numpy()
    label_to_prob = {label: float(prob) for label, prob in zip(labels, probs)}

    return label_to_prob, gif_path


# Layout: model picker plus two input tabs (YouTube URL / local upload) on the
# left, the GIF preview of the sampled frames in the middle, predictions right.
with gr.Blocks() as app:
    gr.Markdown(
        "# **<p align='center'> PROTOG - VIOLENCE DETECTION MODULE</p>**"
    )

    with gr.Row():
        with gr.Column():
            model_names_dropdown = gr.Dropdown(
                choices=VALID_ZEROSHOT_VIDEOCLASSIFICATION_MODELS,
                label="Model:",
                show_label=True,
                value=DEFAULT_MODEL,
            )
            # Reloads the global processor/model used by predict().
            model_names_dropdown.change(fn=select_model, inputs=model_names_dropdown)
            with gr.Tab(label="Youtube URL"):
                gr.Markdown(
                    "### **Enter Youtube URL**"
                )
                youtube_url = gr.Textbox(label="Youtube URL:", show_label=True)
                youtube_url_labels_text = gr.Textbox(
                    label="Labels Text:", show_label=True
                )
                youtube_url_predict_btn = gr.Button(value="Predict")
            with gr.Tab(label="Local File"):
                gr.Markdown(
                    "### **Video Upload**"
                )
                video_file = gr.Video(label="Video File:", show_label=True)
                local_video_labels_text = gr.Textbox(
                    label="Labels Text:", show_label=True
                )
                local_video_predict_btn = gr.Button(value="Predict")
        with gr.Column():
            video_gif = gr.Image(
                label="Input Clip",
                show_label=True,
            )
        with gr.Column():
            predictions = gr.Label(label="Predictions:", show_label=True)

    # predict returns (label_to_prob, gif_path), matching the outputs order.
    youtube_url_predict_btn.click(
        predict,
        inputs=[youtube_url, youtube_url_labels_text],
        outputs=[predictions, video_gif],
    )
    local_video_predict_btn.click(
        predict,
        inputs=[video_file, local_video_labels_text],
        outputs=[predictions, video_gif],
    )

# Launch only when run as a script, so importing this module (e.g. from tests
# or another app) does not start a server.
if __name__ == "__main__":
    app.launch()