import os

import numpy as np
import pandas as pd
import streamlit as st
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

# Initialize session state
if 'search_history' not in st.session_state:
    st.session_state['search_history'] = []
if 'search_columns' not in st.session_state:
    st.session_state['search_columns'] = []
if 'dataset_loaded' not in st.session_state:
    st.session_state['dataset_loaded'] = False
if 'current_page' not in st.session_state:
    st.session_state['current_page'] = 0
if 'data_cache' not in st.session_state:
    st.session_state['data_cache'] = None
if 'dataset_info' not in st.session_state:
    st.session_state['dataset_info'] = None

ROWS_PER_PAGE = 100  # Number of rows to load at a time


@st.cache_resource
def get_model():
    """Cache the sentence-embedding model across reruns."""
    return SentenceTransformer('all-MiniLM-L6-v2')


@st.cache_data
def load_dataset_page(dataset_id, token, page, rows_per_page):
    """Load and cache a specific page of data."""
    try:
        start_idx = page * rows_per_page
        end_idx = start_idx + rows_per_page
        dataset = load_dataset(
            dataset_id,
            token=token,
            streaming=False,
            split=f'train[{start_idx}:{end_idx}]'
        )
        return pd.DataFrame(dataset)
    except Exception as e:
        st.error(f"Error loading page {page}: {str(e)}")
        return pd.DataFrame()


@st.cache_data
def get_dataset_info(dataset_id, token):
    """Load and cache dataset metadata without downloading the full data."""
    try:
        dataset = load_dataset(dataset_id, token=token, streaming=True)
        return dataset['train'].info
    except Exception as e:
        st.error(f"Error loading dataset info: {str(e)}")
        return None


class FastDatasetSearcher:
    def __init__(self, dataset_id="tomg-group-umd/cinepile"):
        self.dataset_id = dataset_id
        self.text_model = get_model()
        self.token = os.environ.get('DATASET_KEY')
        if not self.token:
            st.error("Please set the DATASET_KEY environment variable with your Hugging Face token.")
            st.stop()

        # Load dataset info if not already loaded
        if st.session_state['dataset_info'] is None:
            st.session_state['dataset_info'] = get_dataset_info(self.dataset_id, self.token)

    def load_page(self, page=0):
        """Load a specific page of data using the cached loader."""
        return load_dataset_page(self.dataset_id, self.token, page, ROWS_PER_PAGE)

    def quick_search(self, query, df):
        """Fast hybrid (keyword + semantic) search on the current page."""
        if df.empty:
            return df

        try:
            # Columns to search, excluding binary/array-valued columns
            searchable_cols = [
                col for col in df.columns
                if not isinstance(df[col].iloc[0], (np.ndarray, bytes))
            ]

            # Prepare the query once
            query_lower = query.lower()
            query_embedding = self.text_model.encode([query], show_progress_bar=False)[0]

            scores = []
            for _, row in df.iterrows():
                # Combine text from all searchable columns in this row
                text_parts = [str(row[col]) for col in searchable_cols if row[col] is not None]
                text = ' '.join(text_parts)

                if text.strip():
                    # Keyword score: query frequency normalized by text length
                    keyword_score = text.lower().count(query_lower) / max(len(text.split()), 1)

                    # Semantic score: cosine similarity of embeddings
                    text_embedding = self.text_model.encode([text], show_progress_bar=False)[0]
                    semantic_score = float(cosine_similarity([query_embedding], [text_embedding])[0][0])

                    # Equal-weight blend of the two signals
                    combined_score = 0.5 * semantic_score + 0.5 * keyword_score
                else:
                    combined_score = 0.0

                scores.append(combined_score)

            # Rank rows by combined score
            results_df = df.copy()
            results_df['score'] = scores
            return results_df.sort_values('score', ascending=False)
        except Exception as e:
            st.error(f"Search error: {str(e)}")
            return df


def render_result(result):
    """Render a single search result."""
    score = result.get('score', 0)
    result_filtered = result.drop('score') if 'score' in result else result

    # Display video if available
    if 'youtube_id' in result:
        st.video(
            f"https://youtube.com/watch?v={result['youtube_id']}&t={result.get('start_time', 0)}"
        )

    # Display scalar fields alongside the relevance score
    cols = st.columns([2, 1])
    with cols[0]:
        for key, value in result_filtered.items():
            if isinstance(value, (str, int, float)):
                st.write(f"**{key}:** {value}")
    with cols[1]:
        st.metric("Relevance Score", f"{score:.2%}")


def main():
    st.title("🎥 Fast Video Dataset Search")

    # Initialize search class
    searcher = FastDatasetSearcher()

    # Show dataset info and page selector
    if st.session_state['dataset_info']:
        num_examples = st.session_state['dataset_info'].splits['train'].num_examples
        st.sidebar.write("### Dataset Info")
        st.sidebar.write(f"Total examples: {num_examples:,}")

        # Last valid zero-based page index; clamp the stored page so the
        # widget never receives a value above its max
        total_pages = max(0, (num_examples - 1) // ROWS_PER_PAGE)
        current_page = st.number_input(
            "Page", min_value=0, max_value=total_pages,
            value=min(st.session_state['current_page'], total_pages)
        )
    else:
        current_page = st.number_input("Page", min_value=0, value=st.session_state['current_page'])

    # Load current page
    with st.spinner(f"Loading page {current_page}..."):
        df = searcher.load_page(current_page)

    if df.empty:
        st.warning("No data available for this page.")
        return

    # Search interface
    col1, col2 = st.columns([3, 1])
    with col1:
        query = st.text_input("Search in current page:", help="Searches within currently loaded data")
    with col2:
        max_results = st.slider("Max results", 1, ROWS_PER_PAGE, 10)

    if query:
        with st.spinner("Searching..."):
            results = searcher.quick_search(query, df)

        # Display results
        st.write(f"Found {len(results)} results on this page:")
        for i, (_, result) in enumerate(results.head(max_results).iterrows(), 1):
            with st.expander(f"Result {i}", expanded=i == 1):
                render_result(result)

    # Show raw data
    with st.expander("Show Raw Data"):
        st.dataframe(df)

    # Navigation buttons
    cols = st.columns(2)
    with cols[0]:
        if st.button("⬅️ Previous Page") and current_page > 0:
            st.session_state['current_page'] = current_page - 1
            st.rerun()
    with cols[1]:
        if st.button("Next Page ➡️"):
            st.session_state['current_page'] = current_page + 1
            st.rerun()


if __name__ == "__main__":
    main()
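
# ---------------------------------------------------------------------------
# Usage sketch (an assumption: the file is saved as app.py; the token shown is
# a placeholder, not a real credential):
#
#   $ export DATASET_KEY=<your-huggingface-read-token>
#   $ streamlit run app.py
#
# Each rerun fetches only one ROWS_PER_PAGE-row slice of the dataset via the
# split expression f'train[{start}:{end}]', so memory use stays flat no matter
# how large the dataset is; searches are scoped to the currently loaded page.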