| import ffmpeg | |
| import torch | |
| import youtube_dl | |
| import numpy as np | |
| import streamlit as st | |
| from sentence_transformers import SentenceTransformer, util, models | |
| from clip import CLIPModel | |
| from PIL import Image | |
def get_model():
    """Build the embedding model used for both text and video frames.

    The CLIP module is wrapped in a SentenceTransformer pipeline and moved to
    the CPU in float32.

    Returns:
        The SentenceTransformer instance.
    """
    clip_module = CLIPModel()
    wrapped = SentenceTransformer(modules=[clip_module])
    return wrapped.to(dtype=torch.float32, device=torch.device('cpu'))
def get_embedding(model, query, video):
    """Encode a text query and every frame of a video with the same model.

    Args:
        model: Object exposing ``encode(inputs, device=...)``.
        query: Text to encode.
        video: Iterable of frames, each convertible via ``Image.fromarray``.

    Returns:
        Tuple ``(text_embedding, frame_embeddings)`` as produced by
        ``model.encode`` (both encoded on the CPU).
    """
    text_vector = model.encode(query, device='cpu')
    # The model encodes PIL images, so convert each raw frame first.
    frames = [Image.fromarray(frame) for frame in video]
    frame_vectors = model.encode(frames, device='cpu')
    return text_vector, frame_vectors
def find_frames(url, model, desc, top_k, text, duration=60, frame_step=10):
    """Display the ``top_k`` video frames that best match a text description.

    The first ``duration`` seconds of the stream at ``url`` are decoded by
    ffmpeg into raw RGB frames. Every ``frame_step``-th frame is kept, and
    each kept frame is scored against ``desc`` by cosine similarity of their
    embeddings. The best matches are rendered with ``st.image``, best first.

    Args:
        url: Video URL or path understood by ffmpeg.
        model: Embedding model forwarded to ``get_embedding``.
        desc: Text description to search for.
        top_k: Number of frames to show (values <= 0 show nothing).
        text: Streamlit element used for the progress message; it is
            cleared once results are displayed.
        duration: Seconds of video to decode (default 60).
        frame_step: Keep one frame out of every ``frame_step`` (default 10).

    Raises:
        ValueError: If the probed input has no video stream, or if no frames
            were decoded. ffmpeg errors propagate unchanged.
    """
    text.text("Processing video...")
    probe = ffmpeg.probe(url)
    video_stream = next(
        (s for s in probe['streams'] if s['codec_type'] == 'video'), None
    )
    if video_stream is None:
        raise ValueError(f"No video stream found in {url!r}")
    width = int(video_stream['width'])
    height = int(video_stream['height'])

    out, _ = (
        ffmpeg
        .input(url, t=duration)
        .output('pipe:', format='rawvideo', pix_fmt='rgb24')
        .run(capture_stdout=True)
    )
    # rgb24 is 3 bytes per pixel, so the raw buffer splits into
    # (frames, height, width, 3). Subsampling keeps encoding cheap.
    video = (
        np.frombuffer(out, np.uint8)
        .reshape([-1, height, width, 3])
    )[::frame_step]
    if len(video) == 0:
        raise ValueError("No frames were decoded from the video")

    txt_embd, img_embds = get_embedding(model, desc, video)
    # A single query yields one row of scores: take it as a 1-D array.
    cos_scores = np.array(util.cos_sim(txt_embd, img_embds))[0]
    # argsort is ascending: reverse it so the best match comes first, then
    # take a prefix. (A `[-top_k:]` slice would return every frame when
    # top_k == 0, and it lists the best match last.)
    ids = np.argsort(cos_scores)[::-1][:max(top_k, 0)]
    imgs = [Image.fromarray(video[i]) for i in ids]
    text.empty()
    st.image(imgs)
def main_page(model):
    """Render the landing page.

    ``model`` is accepted so every page shares the same call signature;
    this page does not use it.
    """
    heading = "Introducing Youtube CLIFS"
    st.title(heading)
def clifs_page(model):
    """Render the CLIFS search page: sidebar controls plus frame search.

    When "Search" is pressed, youtube_dl resolves the direct media URL of the
    ``mp4[height=360]`` format for the given video. Nothing is downloaded to
    disk. The URL is then handed to ``find_frames``. Download or format
    problems are shown as UI errors instead of crashing further down in
    ffmpeg.

    Args:
        model: Embedding model forwarded to ``find_frames``.
    """
    st.title("CLIFS")
    st.sidebar.markdown("### Controls:")
    top_k = st.sidebar.slider(
        "Top K",
        min_value=1,
        max_value=5,
        step=1,
    )
    desc = st.sidebar.text_input(
        "Search Description",
        value="Two white puppies",
        help="Text description of what you want to find in the video",
    )
    url = st.sidebar.text_input(
        "Youtube Video URL",
        value='https://youtu.be/I3AaW9ZevIU',
        help="Youtube video you'd like to search through",
    )
    submit_button = st.sidebar.button("Search")
    if not submit_button:
        return
    if not desc.strip():
        st.warning("Please enter a search description.")
        return

    text = st.text("Downloading video...")
    ydl_opts = {"format": "mp4[height=360]"}
    try:
        # Only metadata is needed; the stream itself is read by ffmpeg.
        with youtube_dl.YoutubeDL(ydl_opts) as ydl:
            info_dict = ydl.extract_info(url, download=False)
    except youtube_dl.utils.DownloadError as err:
        text.empty()
        st.error(f"Could not fetch video: {err}")
        return

    # A missing "url" (e.g. a playlist, or the format is unavailable) would
    # otherwise reach ffmpeg.probe as None.
    video_url = (info_dict or {}).get("url")
    if not video_url:
        text.empty()
        st.error("No 360p mp4 stream is available for this video.")
        return
    find_frames(video_url, model, desc, top_k, text)
# Sidebar navigation: radio-button label -> page render function taking the model.
PAGES = {
    "Home": main_page,
    "CLIFS": clifs_page,
}
def run():
    """Configure the page, build the model and render the selected page."""
    # set_page_config is issued before any other st.* call in this app.
    st.set_page_config(page_title="Youtube CLIFS")
    # NOTE(review): the model is presumably rebuilt on every Streamlit rerun
    # (each widget interaction); consider Streamlit's caching API for it.
    model = get_model()
    st.sidebar.title('Navigation')
    selection = st.sidebar.radio("Go to", list(PAGES))
    # Page functions render into the app and return nothing.
    PAGES[selection](model)
if __name__ == "__main__":
    # Start the app only when this file is executed directly, not on import.
    run()