import os
from datetime import datetime

import numpy as np
import pandas as pd
import streamlit as st
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

# Initialize session state
if 'search_history' not in st.session_state:
    st.session_state['search_history'] = []
if 'search_columns' not in st.session_state:
    st.session_state['search_columns'] = []
if 'dataset_loaded' not in st.session_state:
    st.session_state['dataset_loaded'] = False
if 'current_page' not in st.session_state:
    st.session_state['current_page'] = 0
if 'data_cache' not in st.session_state:
    st.session_state['data_cache'] = None

ROWS_PER_PAGE = 100  # Number of rows to load at a time


@st.cache_resource
def get_model():
    """Load the sentence-embedding model once and reuse it across reruns."""
    return SentenceTransformer('all-MiniLM-L6-v2')


class FastDatasetSearcher:
    def __init__(self, dataset_id="tomg-group-umd/cinepile"):
        self.dataset_id = dataset_id
        self.text_model = get_model()
        self.token = os.environ.get('DATASET_KEY')
        if not self.token:
            st.error("Please set the DATASET_KEY environment variable with your Hugging Face token.")
            st.stop()
        self.load_dataset_info()

    def load_dataset_info(self):
        """Load dataset metadata only (streaming avoids downloading the full dataset)."""
        try:
            dataset = load_dataset(
                self.dataset_id,
                token=self.token,
                streaming=True
            )
            self.dataset_info = dataset['train'].info
            return True
        except Exception as e:
            st.error(f"Error loading dataset: {str(e)}")
            return False

    def load_page(self, page=0):
        """Load a specific page of data, reusing the cached page when possible."""
        if st.session_state['data_cache'] is not None and st.session_state['current_page'] == page:
            return st.session_state['data_cache']
        try:
            dataset = load_dataset(
                self.dataset_id,
                token=self.token,
                streaming=False,
                split=f'train[{page * ROWS_PER_PAGE}:{(page + 1) * ROWS_PER_PAGE}]'
            )
            df = pd.DataFrame(dataset)
            st.session_state['data_cache'] = df
            st.session_state['current_page'] = page
            return df
        except Exception as e:
            st.error(f"Error loading page {page}: {str(e)}")
            return pd.DataFrame()

    def quick_search(self, query, df):
        """Fast keyword + semantic search over the currently loaded page."""
        scores = []
        query_embedding = self.text_model.encode([query], show_progress_bar=False)[0]

        for _, row in df.iterrows():
            # Combine all searchable text fields
            text = ' '.join(str(v) for v in row.values if isinstance(v, (str, int, float)))

            # Quick keyword match (guard against empty text)
            keyword_score = text.lower().count(query.lower()) / max(len(text.split()), 1)

            # Semantic similarity on the combined text
            text_embedding = self.text_model.encode([text], show_progress_bar=False)[0]
            semantic_score = cosine_similarity([query_embedding], [text_embedding])[0][0]

            # Combine scores
            combined_score = 0.5 * semantic_score + 0.5 * keyword_score
            scores.append(combined_score)

        # Rank results by combined score (copy so the cached page is not mutated)
        df = df.copy()
        df['score'] = scores
        return df.sort_values('score', ascending=False)


def main():
    st.title("🎥 Fast Video Dataset Search")

    # Initialize search class
    searcher = FastDatasetSearcher()

    # Page navigation
    page = st.number_input("Page", min_value=0, value=st.session_state['current_page'])

    # Load current page
    with st.spinner(f"Loading page {page}..."):
        df = searcher.load_page(page)

    if df.empty:
        st.warning("No data available for this page.")
        return

    # Search interface
    query = st.text_input("Search in current page:", help="Searches within currently loaded data")

    if query:
        with st.spinner("Searching..."):
            results = searcher.quick_search(query, df)

        # Display results
        st.write(f"Found {len(results)} results on this page:")
        for i, (_, result) in enumerate(results.iterrows(), 1):
            score = result.pop('score')
            with st.expander(f"Result {i} (Score: {score:.2%})", expanded=(i == 1)):
                # Display video if available
                if 'youtube_id' in result:
                    st.video(
                        f"https://youtube.com/watch?v={result['youtube_id']}&t={result.get('start_time', 0)}"
                    )
                # Display other fields
                for key, value in result.items():
                    if isinstance(value, (str, int, float)):
                        st.write(f"**{key}:** {value}")

    # Show raw data
    st.subheader("Raw Data")
    st.dataframe(df)

    # Navigation buttons
    cols = st.columns(2)
    with cols[0]:
        if st.button("Previous Page") and page > 0:
            st.session_state['current_page'] -= 1
            st.rerun()
    with cols[1]:
        if st.button("Next Page"):
            st.session_state['current_page'] += 1
            st.rerun()


if __name__ == "__main__":
    main()