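"""Streamlit demo for Youtube CLIFS: search the frames of a YouTube video with a
free-text query by comparing CLIP text and image embeddings."""
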
import ffmpeg
import torch
import youtube_dl

import numpy as np
import streamlit as st

from sentence_transformers import SentenceTransformer, util, models
from PIL import Image

@st.cache(allow_output_mutation=True, max_entries=1)
def get_model():
    # Wrap the CLIP encoder from sentence-transformers so text and images
    # are mapped into the same embedding space.
    clip = models.CLIPModel()
    model = SentenceTransformer(modules=[clip]).to(dtype=torch.float32, device=torch.device('cpu'))
    return model


def get_embedding(model, query, video):
    # Encode the text query:
    text_emb = model.encode(query, device='cpu')

    # Encode every sampled frame as an image:
    images = [Image.fromarray(img) for img in video]
    img_embs = model.encode(images, device='cpu')

    return text_emb, img_embs

def find_frames(url, model, desc, top_k, text):
    text.text("Processing video...")
    # Probe the stream for the frame dimensions needed to reshape the raw pixels.
    probe = ffmpeg.probe(url)
    video_stream = next((stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None)
    width = int(video_stream['width'])
    height = int(video_stream['height'])
    # Decode the first 60 seconds to raw RGB24 frames piped into memory.
    out, _ = (
        ffmpeg
        .input(url, t=60)
        .output('pipe:', format='rawvideo', pix_fmt='rgb24')
        .run(capture_stdout=True)
    )
    # Reshape the byte stream into (frames, height, width, 3) and keep every 10th frame.
    video = (
        np
        .frombuffer(out, np.uint8)
        .reshape([-1, height, width, 3])
    )[::10]

    # Rank frames by cosine similarity between the text and image embeddings.
    txt_embd, img_embds = get_embedding(model, desc, video)
    cos_scores = np.array(util.cos_sim(txt_embd, img_embds))
    ids = np.argsort(cos_scores)[0][-top_k:]

    # Display the top-k matching frames.
    imgs = [Image.fromarray(video[i]) for i in ids]
    text.empty()
    st.image(imgs)

def main_page(model):
    st.title("Introducing Youtube CLIFS")

def clifs_page(model):
    st.title("CLIFS")

    st.sidebar.markdown("### Controls:")
    top_k = st.sidebar.slider(
        "Top K",
        min_value=1,
        max_value=5,
        step=1,
    )
    desc = st.sidebar.text_input(
        "Search Description",
        value="Two white puppies",
        help="Text description of what you want to find in the video",
    )
    url = st.sidebar.text_input(
        "Youtube Video URL",
        value='https://youtu.be/I3AaW9ZevIU',
        help="Youtube video you'd like to search through",
    )

    submit_button = st.sidebar.button("Search")
    if submit_button:
        text = st.text("Downloading video...")
        # Resolve a direct 360p MP4 stream URL instead of downloading the file.
        ydl_opts = {"format": "mp4[height=360]"}
        with youtube_dl.YoutubeDL(ydl_opts) as ydl:
            info_dict = ydl.extract_info(url, download=False)
            video_url = info_dict.get("url", None)
            find_frames(video_url, model, desc, top_k, text)
            # ydl.download([url])

PAGES = {
    "Home": main_page,
    "CLIFS": clifs_page
}



def run():
    st.set_page_config(page_title="Youtube CLIFS")
    # Load the CLIP model once; st.cache keeps it across reruns.
    model = get_model()

    st.sidebar.title('Navigation')
    selection = st.sidebar.radio("Go to", list(PAGES.keys()))

    # Render the selected page.
    PAGES[selection](model)


if __name__ == "__main__":
    run()
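

# To try this locally (assuming the file is saved as app.py):
#   pip install streamlit ffmpeg-python youtube_dl sentence-transformers torch pillow
#   streamlit run app.py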