|
import json
import os
from pathlib import Path

import numpy as np
import pandas as pd
import streamlit as st
import torch
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
|
|
|
class VideoRetrieval:
    """Rank movie clips against a free-text query using pre-computed features.

    The query is embedded with a sentence-transformer model, compared (cosine
    similarity) against each loaded feature matrix, and the per-feature scores
    are combined as a weighted sum to pick the best-matching clips from the
    metadata table.
    """

    # Default contribution of each feature matrix to the combined score.
    DEFAULT_WEIGHTS = {
        'visual_features': 0.4,
        'scene_features': 0.3,
        'object_features': 0.3,
    }

    def __init__(self):
        """Load the text encoder, then the feature matrices and clip metadata."""
        self.text_model = SentenceTransformer('all-MiniLM-L6-v2')
        self.load_data()

    def load_data(self):
        """Populate ``self.features`` and ``self.clips_df`` from disk.

        ``self.features`` maps a feature-type name to the array read from the
        matching ``.npy`` file; ``self.clips_df`` is the table read from
        ``clips_metadata.csv``. All paths are relative to the working directory.

        NOTE(review): row ``i`` of every feature matrix is assumed to describe
        row ``i`` of ``clips_df`` (``retrieve_clips`` indexes both with the same
        position) - confirm the files are exported in the same order.
        """
        self.features = {
            'visual_features': np.load('path_to_visual_features.npy'),
            'scene_features': np.load('path_to_scene_features.npy'),
            'object_features': np.load('path_to_object_features.npy'),
        }
        self.clips_df = pd.read_csv('clips_metadata.csv')

    def encode_query(self, query_text):
        """Encode the text query into an embedding with the text model."""
        return self.text_model.encode(query_text)

    def compute_similarity(self, query_embedding, feature_type='visual_features'):
        """Return the cosine similarity of the query to every row of one feature matrix.

        Args:
            query_embedding: Query vector; it is reshaped to a single row.
            feature_type: Key into ``self.features`` selecting the matrix.

        Returns:
            1-D array with one similarity per row of the selected matrix.

        NOTE(review): ``cosine_similarity`` needs the feature matrix to have as
        many columns as the query embedding has entries - confirm the stored
        features live in the same space as the text embeddings.
        """
        similarities = cosine_similarity(
            query_embedding.reshape(1, -1),
            self.features[feature_type],
        )
        return similarities[0]

    def retrieve_clips(self, query_text, top_k=3, weights=None):
        """Retrieve the top-k most relevant clips for a text query.

        Args:
            query_text: Natural-language description of the scene.
            top_k: Maximum number of clips to return. Values <= 0 yield an
                empty list.
            weights: Optional mapping of feature-type name (a key of
                ``self.features``) to its weight in the combined score.
                Defaults to ``DEFAULT_WEIGHTS``.

        Returns:
            List of dicts ordered from best to worst match, each with the keys
            ``clip_id``, ``movie_title``, ``description``, ``timestamp`` and
            ``similarity_score``.

        Raises:
            ValueError: If ``weights`` is given but empty.
        """
        # ``scores[-0:]`` would select *every* clip and a negative value a
        # wrong-sized tail, so handle non-positive requests up front.
        if top_k <= 0:
            return []

        if weights is None:
            weights = self.DEFAULT_WEIGHTS
        if not weights:
            raise ValueError("weights must name at least one feature type")

        query_embedding = self.encode_query(query_text)

        # Weighted sum of the per-feature similarity vectors.
        combined_similarities = sum(
            self.compute_similarity(query_embedding, feat_type) * weight
            for feat_type, weight in weights.items()
        )

        # argsort is ascending: keep the last top_k entries, best first.
        top_indices = np.argsort(combined_similarities)[-top_k:][::-1]

        results = []
        for idx in top_indices:
            row = self.clips_df.iloc[idx]  # single positional lookup per clip
            results.append({
                'clip_id': row['clip_id'],
                'movie_title': row['movie_title'],
                'description': row['description'],
                'timestamp': row['timestamp'],
                'similarity_score': combined_similarities[idx],
            })
        return results
|
|
|
|
|
def main():
    """Render the Streamlit page: query inputs, ranked results and the sidebar."""
    st.title("Movie Scene Retrieval System")
    # Literals are written without leading indentation so Markdown cannot
    # interpret the text as an indented code block.
    st.write(
        "Search for movie scenes using natural language descriptions.\n"
        "The system will retrieve the most relevant 2-3 minute clips based on your query."
    )

    # Build the retrieval system once per session and reuse it on reruns,
    # since constructing it loads the model and the feature files.
    if "retrieval_system" not in st.session_state:
        st.session_state.retrieval_system = VideoRetrieval()
    retrieval_system = st.session_state.retrieval_system

    query = st.text_input(
        "Enter your scene description:",
        "A dramatic confrontation between two characters in a dark room",
    )
    num_results = st.slider("Number of results to show:", min_value=1, max_value=5, value=3)

    if st.button("Search"):
        if not query.strip():
            # Nothing meaningful to embed; ask for input instead of searching.
            st.warning("Please enter a scene description to search for.")
        else:
            with st.spinner("Searching for relevant clips..."):
                results = retrieval_system.retrieve_clips(query, top_k=num_results)

            for i, result in enumerate(results, 1):
                st.subheader(f"Result {i}: {result['movie_title']}")
                col1, col2 = st.columns([2, 1])

                with col1:
                    st.write("**Scene Description:**")
                    st.write(result['description'])
                    st.write(f"**Timestamp:** {result['timestamp']}")

                with col2:
                    st.write("**Similarity Score:**")
                    # st.progress only accepts floats in [0.0, 1.0]; a weighted
                    # cosine score can be negative, so clamp before drawing.
                    score = float(result['similarity_score'])
                    st.progress(min(max(score, 0.0), 1.0))

                st.write("---")

    with st.sidebar:
        st.header("About")
        st.write(
            "This system uses pre-computed visual features from several expert models to retrieve\n"
            "relevant movie clips based on natural language descriptions. Features include:\n"
            "\n"
            "- Visual scene understanding\n"
            "- Character interaction analysis\n"
            "- Object detection\n"
            "- Action recognition"
        )

        st.header("Feature Weights")
        st.write("Current weights used for similarity computation:")
        st.write("- Visual Features: 40%")
        st.write("- Scene Features: 30%")
        st.write("- Object Features: 30%")
|
|
|
# Entry point when the file is executed as a script.
if __name__ == "__main__":
    main()
|
|
|
|
|
'''
Dependencies, pinned versions (pip requirements format):

streamlit==1.22.0
pandas==1.5.3
numpy==1.23.5
sentence-transformers==2.2.2
scikit-learn==1.2.2
torch==2.0.0

Same dependencies, unpinned:

streamlit
pandas
numpy
sentence-transformers
scikit-learn
torch
'''