# NOTE: scraped Hugging Face Spaces page header (app status "Sleeping") — not part of the program.
| import streamlit as st | |
| import pandas as pd | |
| import numpy as np | |
| from sentence_transformers import SentenceTransformer | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| import os | |
| from datetime import datetime | |
| from datasets import load_dataset | |
# Seed every session-state key the app relies on with its default value.
# Table-driven so adding a new key is a one-line change.
_SESSION_DEFAULTS = {
    'search_history': [],
    'search_columns': [],
    'dataset_loaded': False,
    'current_page': 0,
    'data_cache': None,
    'dataset_info': None,
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default

ROWS_PER_PAGE = 100  # Number of rows to load at a time
@st.cache_resource
def get_model():
    """Load and cache the sentence-embedding model.

    Returns:
        SentenceTransformer: the 'all-MiniLM-L6-v2' encoder used for
        semantic scoring in search.

    Decorated with ``st.cache_resource`` so the model is loaded once per
    process — the previous version's docstring claimed caching, but
    without the decorator Streamlit re-loaded the model on every rerun.
    """
    return SentenceTransformer('all-MiniLM-L6-v2')
@st.cache_data
def load_dataset_page(dataset_id, token, page, rows_per_page):
    """Load and cache one page (a row slice of the 'train' split) as a DataFrame.

    Args:
        dataset_id: Hugging Face dataset identifier.
        token: HF access token passed through to ``load_dataset``.
        page: zero-based page index.
        rows_per_page: number of rows per page.

    Returns:
        pd.DataFrame with the requested rows, or an empty DataFrame if
        loading fails (the error is surfaced via ``st.error``).

    Decorated with ``st.cache_data`` so revisiting a page does not
    re-download the slice — the docstring claimed caching but the
    decorator was missing.
    """
    try:
        start_idx = page * rows_per_page
        end_idx = start_idx + rows_per_page
        dataset = load_dataset(
            dataset_id,
            token=token,
            streaming=False,
            # Slice syntax fetches only the rows for this page.
            split=f'train[{start_idx}:{end_idx}]'
        )
        return pd.DataFrame(dataset)
    except Exception as e:
        st.error(f"Error loading page {page}: {str(e)}")
        return pd.DataFrame()
@st.cache_data
def get_dataset_info(dataset_id, token):
    """Load and cache dataset metadata for the 'train' split.

    Args:
        dataset_id: Hugging Face dataset identifier.
        token: HF access token.

    Returns:
        The split's ``DatasetInfo`` (e.g. ``.splits['train'].num_examples``),
        or None on failure (error reported via ``st.error``).

    Uses ``streaming=True`` so only metadata is fetched, and
    ``st.cache_data`` so the info call happens once — the docstring
    claimed caching but the decorator was missing.
    """
    try:
        dataset = load_dataset(
            dataset_id,
            token=token,
            streaming=True
        )
        return dataset['train'].info
    except Exception as e:
        st.error(f"Error loading dataset info: {str(e)}")
        return None
class FastDatasetSearcher:
    """Page-by-page loader and hybrid keyword/semantic searcher for a HF dataset."""

    def __init__(self, dataset_id="tomg-group-umd/cinepile"):
        self.dataset_id = dataset_id
        self.text_model = get_model()
        self.token = os.environ.get('DATASET_KEY')
        if not self.token:
            st.error("Please set the DATASET_KEY environment variable with your Hugging Face token.")
            st.stop()
        # Load dataset info once per session.
        if st.session_state['dataset_info'] is None:
            st.session_state['dataset_info'] = get_dataset_info(self.dataset_id, self.token)

    def load_page(self, page=0):
        """Load a specific page of data using cached function"""
        return load_dataset_page(self.dataset_id, self.token, page, ROWS_PER_PAGE)

    def quick_search(self, query, df):
        """Rank the rows of *df* against *query*.

        Each row's score is the average of a keyword-frequency score and
        the cosine similarity between sentence embeddings of the query
        and the row's concatenated text. Returns a copy of *df* with a
        'score' column, sorted descending; returns *df* unchanged if empty.

        Bug fix: the previous version called ``row.values()`` —
        ``Series.values`` is a property (an ndarray), so calling it
        raised TypeError on every non-empty page.
        """
        if df.empty:
            return df

        # Flatten each row into one searchable string.
        texts = []
        for _, row in df.iterrows():
            text_values = [
                str(v) for v in row.values
                if isinstance(v, (str, int, float, list, dict))
            ]
            texts.append(' '.join(text_values))

        # Encode the query and all rows in one batch (much faster than
        # one encode() call per row).
        query_embedding = np.asarray(
            self.text_model.encode([query], show_progress_bar=False)[0], dtype=float
        )
        text_embeddings = self.text_model.encode(texts, show_progress_bar=False)
        q_norm = np.linalg.norm(query_embedding) or 1.0  # guard zero-norm

        scores = []
        query_lower = query.lower()
        for text, embedding in zip(texts, text_embeddings):
            # Keyword score: query frequency normalized by row length
            # (+1 avoids division by zero).
            keyword_score = text.lower().count(query_lower) / (len(text.split()) + 1)
            emb = np.asarray(embedding, dtype=float)
            denom = q_norm * (np.linalg.norm(emb) or 1.0)
            semantic_score = float(np.dot(query_embedding, emb) / denom)
            # Equal-weight blend of the two signals.
            scores.append(0.5 * semantic_score + 0.5 * keyword_score)

        results_df = df.copy()
        results_df['score'] = scores
        return results_df.sort_values('score', ascending=False)
def render_result(result):
    """Render a single search result"""
    relevance = result.pop('score', 0)

    # Embed the clip when the row carries a YouTube id; start_time (if
    # present) seeks the player to the relevant moment.
    if 'youtube_id' in result:
        video_url = (
            f"https://youtube.com/watch?v={result['youtube_id']}"
            f"&t={result.get('start_time', 0)}"
        )
        st.video(video_url)

    # Left column: scalar fields; right column: the relevance metric.
    left, right = st.columns([2, 1])
    with left:
        for field, value in result.items():
            if isinstance(value, (str, int, float)):
                st.write(f"**{field}:** {value}")
    with right:
        st.metric("Relevance Score", f"{relevance:.2%}")
def main():
    """Streamlit entry point: page selection, per-page search, and navigation."""
    st.title("🎥 Fast Video Dataset Search")

    searcher = FastDatasetSearcher()

    # Page selector — bounded when the dataset size is known, otherwise open.
    info = st.session_state['dataset_info']
    if info:
        st.sidebar.write("### Dataset Info")
        num_examples = info.splits['train'].num_examples
        st.sidebar.write(f"Total examples: {num_examples:,}")
        page_count = num_examples // ROWS_PER_PAGE
        current_page = st.number_input(
            "Page", min_value=0, max_value=page_count,
            value=st.session_state['current_page'])
    else:
        current_page = st.number_input(
            "Page", min_value=0, value=st.session_state['current_page'])

    # Fetch the selected page.
    with st.spinner(f"Loading page {current_page}..."):
        page_df = searcher.load_page(current_page)
    if page_df.empty:
        st.warning("No data available for this page.")
        return

    # Search controls.
    query_col, limit_col = st.columns([3, 1])
    with query_col:
        query = st.text_input("Search in current page:",
                              help="Searches within currently loaded data")
    with limit_col:
        max_results = st.slider("Max results", 1, ROWS_PER_PAGE, 10)

    if query:
        with st.spinner("Searching..."):
            results = searcher.quick_search(query, page_df)
        st.write(f"Found {len(results)} results on this page:")
        for rank, (_, hit) in enumerate(results.head(max_results).iterrows(), 1):
            with st.expander(f"Result {rank}", expanded=rank == 1):
                render_result(hit)

    with st.expander("Show Raw Data"):
        st.dataframe(page_df)

    # Page navigation buttons.
    prev_col, next_col = st.columns(2)
    with prev_col:
        if st.button("⬅️ Previous Page") and current_page > 0:
            st.session_state['current_page'] = current_page - 1
            st.rerun()
    with next_col:
        if st.button("Next Page ➡️"):
            st.session_state['current_page'] = current_page + 1
            st.rerun()
# Script entry point: run the Streamlit app.
if __name__ == "__main__":
    main()