import streamlit as st
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import os
import re
from datetime import datetime
import requests
from datasets import load_dataset
from urllib.parse import quote

# ---------------------------------------------------------------------------
# Session state initialisation — these keys must exist before any widget
# reads them, and st.session_state survives Streamlit reruns.
# ---------------------------------------------------------------------------
if 'search_history' not in st.session_state:
    st.session_state['search_history'] = []
if 'search_columns' not in st.session_state:
    st.session_state['search_columns'] = []
if 'initial_search_done' not in st.session_state:
    st.session_state['initial_search_done'] = False
if 'dataset' not in st.session_state:
    st.session_state['dataset'] = None


class DatasetSearcher:
    """Load a Hugging Face dataset and offer hybrid semantic + keyword search.

    The raw dataset is cached in ``st.session_state['dataset']`` so it is
    only downloaded once per browser session; embeddings are recomputed on
    each instantiation.
    """

    def __init__(self, dataset_id="tomg-group-umd/cinepile"):
        """Initialise the searcher.

        Args:
            dataset_id: Hugging Face dataset identifier to load.

        Requires the ``DATASET_KEY`` environment variable to hold a valid
        Hugging Face access token; stops the app with an error otherwise.
        """
        self.dataset_id = dataset_id
        # Small, fast sentence-embedding model (384-dimensional vectors).
        self.text_model = SentenceTransformer('all-MiniLM-L6-v2')
        self.token = os.environ.get('DATASET_KEY')
        if not self.token:
            st.error("Please set the DATASET_KEY environment variable with your Hugging Face token.")
            st.stop()
        self.load_dataset()

    def load_dataset(self):
        """Load dataset using the datasets library.

        Populates ``self.dataset``, ``self.df`` (first split as a DataFrame),
        ``self.columns`` and ``self.text_columns``, then builds embeddings.
        Any failure stops the app with an error message.
        """
        try:
            if st.session_state['dataset'] is None:
                with st.spinner("Loading dataset..."):
                    st.session_state['dataset'] = load_dataset(
                        self.dataset_id,
                        token=self.token,
                        streaming=False
                    )
            self.dataset = st.session_state['dataset']

            # Convert first split to DataFrame for easier processing.
            first_split = next(iter(self.dataset.values()))
            self.df = pd.DataFrame(first_split)

            # Searchable text columns: object dtype, excluding columns whose
            # names suggest they hold embeddings/vectors rather than text.
            self.columns = list(self.df.columns)
            self.text_columns = [
                col for col in self.columns
                if self.df[col].dtype == 'object'
                and not any(term in col.lower() for term in ['embed', 'vector', 'encoding'])
            ]

            # Expose the searchable columns to the UI selectbox.
            st.session_state['search_columns'] = self.text_columns

            # Prepare text embeddings for semantic search.
            self.prepare_features()

        except Exception as e:
            st.error(f"Error loading dataset: {str(e)}")
            st.error("Please check your authentication token and internet connection.")
            st.stop()

    def prepare_features(self):
        """Prepare text embeddings for semantic search.

        Encodes the concatenation of all text columns row by row, in batches,
        into ``self.text_embeddings`` (shape: n_rows x embedding_dim).
        """
        try:
            # Combine text columns for embedding. astype(str) guards against
            # non-string objects occasionally present in object-dtype columns,
            # which would otherwise break ' '.join.
            combined_text = (
                self.df[self.text_columns]
                .fillna('')
                .astype(str)
                .agg(' '.join, axis=1)
            )

            # Encode in batches to keep memory bounded.
            batch_size = 32
            all_embeddings = []
            with st.spinner("Preparing search features..."):
                for i in range(0, len(combined_text), batch_size):
                    batch = combined_text[i:i + batch_size].tolist()
                    embeddings = self.text_model.encode(batch)
                    all_embeddings.append(embeddings)
            self.text_embeddings = np.vstack(all_embeddings)
        except Exception as e:
            st.warning(f"Error preparing features: {str(e)}")
            # Best-effort fallback: random vectors keep the app usable even
            # though semantic ranking becomes meaningless (384 matches the
            # MiniLM embedding size).
            self.text_embeddings = np.random.randn(len(self.df), 384)

    def search(self, query, column=None, top_k=20):
        """Search the dataset using semantic and keyword matching.

        Args:
            query: Free-text search query.
            column: Restrict keyword matching to this column, or None /
                "All Fields" to search every text column.
            top_k: Maximum number of results to return.

        Returns:
            List of result dicts (row fields plus 'relevance_score',
            'semantic_score' and 'keyword_score'), best match first.
        """
        if self.df.empty:
            return []

        # Semantic similarity between the query embedding and every row.
        query_embedding = self.text_model.encode([query])[0]
        similarities = cosine_similarity([query_embedding], self.text_embeddings)[0]

        # Keyword match scores. Series.str.count interprets its argument as a
        # regular expression, so the query must be escaped — otherwise input
        # such as "c++" or "what?" raises or miscounts.
        pattern = re.escape(query.lower())
        search_columns = [column] if column and column != "All Fields" else self.text_columns
        keyword_scores = np.zeros(len(self.df))
        for col in search_columns:
            if col in self.df.columns:
                matches = self.df[col].fillna('').astype(str).str.lower().str.count(pattern)
                keyword_scores += matches

        # 50/50 blend of semantic similarity and normalised keyword counts;
        # max(1, ...) avoids division by zero when nothing matches.
        combined_scores = 0.5 * similarities + 0.5 * (keyword_scores / max(1, keyword_scores.max()))

        # Indices of the top_k highest combined scores, best first.
        top_k = min(top_k, len(combined_scores))
        top_indices = np.argsort(combined_scores)[-top_k:][::-1]

        results = []
        for idx in top_indices:
            result = {
                'relevance_score': float(combined_scores[idx]),
                'semantic_score': float(similarities[idx]),
                'keyword_score': float(keyword_scores[idx]),
                **self.df.iloc[idx].to_dict()
            }
            results.append(result)
        return results

    def get_dataset_info(self):
        """Get information about the dataset.

        Returns:
            Dict with splits, total row count, column lists, sample size and
            embedding matrix shape; empty dict if no dataset is loaded.
        """
        if not self.dataset:
            return {}
        return {
            'splits': list(self.dataset.keys()),
            'total_rows': sum(split.num_rows for split in self.dataset.values()),
            'columns': self.columns,
            'text_columns': self.text_columns,
            'sample_rows': len(self.df),
            'embeddings_shape': self.text_embeddings.shape,
        }


def render_video_result(result):
    """Render a single search result with metadata, scores and video player.

    Args:
        result: Result dict produced by ``DatasetSearcher.search``.
    """
    col1, col2 = st.columns([2, 1])

    with col1:
        if 'title' in result:
            st.markdown(f"**Title:** {result['title']}")
        if 'description' in result:
            st.markdown("**Description:**")
            st.write(result['description'])

        # Show timing information if available.
        if 'start_time' in result and 'end_time' in result:
            st.markdown(f"**Time Range:** {result['start_time']}s - {result['end_time']}s")

        # Show any remaining metadata fields not already rendered above.
        for key, value in result.items():
            if key not in ['title', 'description', 'start_time', 'end_time',
                           'duration', 'relevance_score', 'semantic_score',
                           'keyword_score', 'video_id', 'youtube_id']:
                st.markdown(f"**{key.replace('_', ' ').title()}:** {value}")

    with col2:
        # Show search scores.
        st.markdown("**Search Scores:**")
        cols = st.columns(3)
        cols[0].metric("Overall", f"{result['relevance_score']:.2%}")
        cols[1].metric("Semantic", f"{result['semantic_score']:.2%}")
        cols[2].metric("Keyword", f"{result['keyword_score']:.0f} matches")

    # Display video if available, seeking to the clip's start time.
    if 'youtube_id' in result:
        st.video(f"https://youtube.com/watch?v={result['youtube_id']}&t={result.get('start_time', 0)}")


def main():
    """Streamlit entry point: search tab, dataset-info tab and history sidebar."""
    st.title("🎥 Video Dataset Search")

    # Initialize search class (loads dataset + embeddings on first run).
    searcher = DatasetSearcher()

    # Create tabs
    tab1, tab2 = st.tabs(["🔍 Search", "📊 Dataset Info"])

    # ---- Tab 1: Search ----
    with tab1:
        st.subheader("Search Videos")
        col1, col2 = st.columns([3, 1])
        with col1:
            # NOTE: the previous conditional default collapsed to "" in both
            # branches, so a plain empty default is behaviourally identical.
            query = st.text_input("Search query:", value="")
        with col2:
            search_column = st.selectbox("Search in field:",
                                         ["All Fields"] + st.session_state['search_columns'])

        col3, col4 = st.columns(2)
        with col3:
            num_results = st.slider("Number of results:", 1, 100, 20)
        with col4:
            search_button = st.button("🔍 Search")

        if search_button and query:
            st.session_state['initial_search_done'] = True
            selected_column = None if search_column == "All Fields" else search_column
            with st.spinner("Searching..."):
                results = searcher.search(query, selected_column, num_results)

            # Keep only the top five results per history entry to bound memory.
            st.session_state['search_history'].append({
                'query': query,
                'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                'results': results[:5]
            })

            for i, result in enumerate(results, 1):
                # str() guards against non-string titles/descriptions.
                label = str(result.get('title', result.get('description', 'No title')))[:100]
                with st.expander(f"Result {i}: {label}...", expanded=(i == 1)):
                    render_video_result(result)

    # ---- Tab 2: Dataset Info ----
    with tab2:
        st.subheader("Dataset Information")
        info = searcher.get_dataset_info()
        if info:
            st.write(f"### Dataset: {searcher.dataset_id}")
            st.write(f"- Total rows: {info['total_rows']:,}")
            st.write(f"- Available splits: {', '.join(info['splits'])}")
            st.write(f"- Number of columns: {len(info['columns'])}")
            st.write(f"- Searchable text columns: {', '.join(info['text_columns'])}")

            st.write("### Sample Data")
            st.dataframe(searcher.df.head())

            st.write("### Column Details")
            for col in info['columns']:
                st.write(f"- **{col}**: {searcher.df[col].dtype}")

    # Sidebar: recent search history with a clear button.
    with st.sidebar:
        st.subheader("⚙️ Search History")
        if st.button("🗑️ Clear History"):
            st.session_state['search_history'] = []
            # st.experimental_rerun was removed in recent Streamlit releases;
            # prefer st.rerun but fall back for older versions.
            if hasattr(st, "rerun"):
                st.rerun()
            else:
                st.experimental_rerun()

        st.markdown("### Recent Searches")
        for entry in reversed(st.session_state['search_history'][-5:]):
            with st.expander(f"{entry['timestamp']}: {entry['query']}"):
                for i, result in enumerate(entry['results'], 1):
                    label = str(result.get('title', result.get('description', 'No title')))[:100]
                    st.write(f"{i}. {label}...")


if __name__ == "__main__":
    main()