File size: 3,457 Bytes
021b099
fd2744e
021b099
 
 
fd2744e
021b099
 
 
 
 
fd2744e
 
52636a1
fd2744e
52636a1
 
021b099
 
52636a1
 
021b099
 
 
 
 
52636a1
021b099
 
 
52636a1
 
c40e192
 
 
 
 
 
52636a1
c40e192
 
 
52636a1
 
c40e192
 
 
 
 
 
52636a1
c40e192
 
 
 
 
 
 
52636a1
 
 
 
c40e192
52636a1
 
c40e192
52636a1
c40e192
fd2744e
 
52636a1
 
 
 
 
 
 
fd2744e
 
 
 
 
 
 
 
52636a1
fd2744e
 
 
 
52636a1
fd2744e
 
 
 
 
c40e192
fd2744e
c40e192
 
52636a1
c40e192
 
 
 
 
 
 
 
 
 
 
52636a1
c40e192
 
 
 
52636a1
c40e192
 
fd2744e
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import ffmpeg
import torch
import youtube_dl

import numpy as np
import streamlit as st

from sentence_transformers import SentenceTransformer, util, models
from clip import CLIPModel
from PIL import Image

@st.cache(allow_output_mutation=True, max_entries=1)
def get_model():
    """Load the text and image encoders used for frame search.

    The result is memoised by ``st.cache`` (keeping at most one entry), so
    script reruns reuse the same model objects instead of reloading them.
    ``allow_output_mutation=True`` tells Streamlit not to check whether the
    cached models were mutated between runs.

    Returns:
        ``(txt_model, vis_model)``: a ``SentenceTransformer`` for the
        'clip-ViT-B-32-multilingual-v1' text model, and a
        ``SentenceTransformer`` wrapping the local ``CLIPModel`` module (presumably
        the matching CLIP image encoder -- confirm). Both are float32 on CPU.
    """
    cpu = torch.device('cpu')

    text_encoder = SentenceTransformer('clip-ViT-B-32-multilingual-v1')
    text_encoder = text_encoder.to(dtype=torch.float32, device=cpu)

    image_encoder = SentenceTransformer(modules=[CLIPModel()])
    image_encoder = image_encoder.to(dtype=torch.float32, device=cpu)

    return text_encoder, image_encoder


def get_embedding(txt_model, vis_model, query, video):
    """Embed a text query and a sequence of video frames on the CPU.

    Args:
        txt_model: Object with an ``encode`` method used for ``query``.
        vis_model: Object with an ``encode`` method used for the frames.
        query: Text to embed.
        video: Iterable of frames; each one is converted with
            ``PIL.Image.fromarray`` (``find_frames`` passes RGB uint8 arrays
            of shape (height, width, 3)).

    Returns:
        ``(text_emb, img_embs)``: ``txt_model.encode(query)`` and
        ``vis_model.encode`` applied to the list of PIL images.
    """
    query_embedding = txt_model.encode(query, device='cpu')

    # The visual encoder consumes PIL images, not raw arrays.
    pil_frames = [Image.fromarray(frame) for frame in video]
    frame_embeddings = vis_model.encode(pil_frames, device='cpu')

    return query_embedding, frame_embeddings

def find_frames(url, txt_model, vis_model, desc, seconds, top_k, frame_stride=10):
    """Find and display the video frames that best match a text description.

    The first ``seconds`` of the media at ``url`` are decoded by ffmpeg into
    raw RGB frames. Every ``frame_stride``-th frame is embedded and ranked by
    cosine similarity against the embedding of ``desc``. The ``top_k`` best
    frames are rendered with ``st.image``, best match first.

    Args:
        url: Media URL (or path) that ffmpeg can read.
        txt_model: Text encoder passed to ``get_embedding``.
        vis_model: Image encoder passed to ``get_embedding``.
        desc: Text description to search for.
        seconds: Maximum number of seconds of video to decode.
        top_k: Number of best-matching frames to show (``<= 0`` shows none).
        frame_stride: Keep one frame out of every ``frame_stride`` decoded
            frames (default 10, the previously hard-coded value).
    """
    status = st.text("Downloading video...")
    try:
        probe = ffmpeg.probe(url)
        video_stream = next(
            (stream for stream in probe['streams'] if stream['codec_type'] == 'video'),
            None,
        )
        if video_stream is None:
            st.error("No video stream found at the given URL.")
            return
        width = int(video_stream['width'])
        height = int(video_stream['height'])
        out, _ = (
            ffmpeg
            .input(url, t=seconds)
            .output('pipe:', format='rawvideo', pix_fmt='rgb24')
            .run(capture_stdout=True)
        )

        status.text("Processing video...")
        # rgb24 => 3 bytes per pixel, so the buffer splits into
        # (n_frames, height, width, 3). NOTE(review): if ffmpeg auto-rotates
        # the output, the probed width/height may be swapped -- confirm.
        video = (
            np
            .frombuffer(out, np.uint8)
            .reshape([-1, height, width, 3])
        )[::frame_stride]
        if len(video) == 0:
            st.warning("No frames could be decoded from the video.")
            return

        txt_embd, img_embds = get_embedding(txt_model, vis_model, desc, video)
        # cos_sim yields a (1, n_frames) matrix for a single query; take row 0.
        scores = np.array(util.cos_sim(txt_embd, img_embds))[0]
        # Descending order so the best match comes first. Slicing with
        # [:top_k] also keeps top_k == 0 from selecting every frame, which the
        # old [-top_k:] slice did.
        ids = np.argsort(scores)[::-1][:top_k]

        imgs = [Image.fromarray(video[i]) for i in ids]
    finally:
        # Always clear the progress message, including on errors/early exits.
        status.empty()
    st.image(imgs)

# Markdown rendered by the "Home" page. It is read at module level, so it is
# re-read every time the script executes; the path is relative to the current
# working directory. An explicit UTF-8 encoding avoids locale-dependent
# decoding of non-ASCII markdown.
with open("HOME.md", "r", encoding="utf-8") as f:
    HOME_PAGE = f.read()

def main_page(txt_model, vis_model):
    """Render the landing page: a title plus the ``HOME_PAGE`` markdown.

    ``txt_model`` and ``vis_model`` are not used here. They are accepted so
    that every entry of ``PAGES`` can be called with the same two arguments.
    """
    heading = "Introducing Youtube CLIFS"
    st.title(heading)
    st.markdown(HOME_PAGE)

def clifs_page(txt_model, vis_model):
    """Render the CLIFS search page and run a search when requested.

    The sidebar collects:
    - how many seconds of video to scan,
    - how many results to show,
    - the text description,
    - the YouTube URL.

    When "Search" is pressed, youtube_dl resolves (without downloading) a
    direct media URL for the "mp4[height=360]" format, and ``find_frames``
    performs the search. Problems resolving the video are reported in the
    page instead of raising.

    Args:
        txt_model: Text encoder forwarded to ``find_frames``.
        vis_model: Image encoder forwarded to ``find_frames``.
    """
    st.title("CLIFS")

    st.sidebar.markdown("### Controls:")
    seconds = st.sidebar.slider(
        "How many seconds of video to consider?",
        min_value=10,
        max_value=120,
        value=60,
        step=1,
    )
    top_k = st.sidebar.slider(
        "Top K",
        min_value=1,
        max_value=5,
        step=1,
    )
    desc = st.sidebar.text_input(
        "Search Description",
        value="Pancake in the shape of an otter",  # Spanish: "panqueque en forma de nutria"
        help="Text description of what you want to find in the video",
    )
    url = st.sidebar.text_input(
        "Youtube Video URL",
        value='https://youtu.be/xUv6XgPwGaQ',
        help="Youtube video you'd like to search through",
    )

    submit_button = st.sidebar.button("Search")
    if not submit_button:
        return

    if not desc.strip():
        st.warning("Please enter a search description.")
        return

    ydl_opts = {"format": "mp4[height=360]"}
    try:
        with youtube_dl.YoutubeDL(ydl_opts) as ydl:
            # download=False only resolves metadata, including the media URL.
            info_dict = ydl.extract_info(url, download=False)
    except youtube_dl.utils.DownloadError as err:
        st.error(f"Could not retrieve the video: {err}")
        return

    # "url" can be missing, e.g. for playlist URLs (which return entries
    # instead), so guard before handing it to ffmpeg.
    video_url = info_dict.get("url") if info_dict else None
    if not video_url:
        st.error("No direct video URL was found for this video.")
        return
    find_frames(video_url, txt_model, vis_model, desc, seconds, top_k)

# Sidebar navigation: label shown in the "Go to" radio -> page renderer.
# Each renderer is called as renderer(txt_model, vis_model).
PAGES = dict(
    Home=main_page,
    CLIFS=clifs_page,
)



def run():
    """App entry point: configure the page, load models, render the chosen page.

    ``st.set_page_config`` is called before any other Streamlit command,
    because Streamlit expects page configuration to come first.
    """
    st.set_page_config(page_title="Youtube CLIFS")
    # get_model is wrapped in st.cache, so reruns reuse the loaded models.
    txt_model, vis_model = get_model()

    st.sidebar.title('Navigation')
    selection = st.sidebar.radio("Go to", list(PAGES.keys()))

    # Page renderers draw directly into the app and return nothing, so the
    # call result is not kept.
    PAGES[selection](txt_model, vis_model)


if __name__ == "__main__":
    run()