"""Streamlit app: semantic text search over a sample of a Hugging Face dataset.

The app fetches rows through the Hugging Face HTTP endpoints, embeds their
text with a SentenceTransformer model, and ranks rows against a user query by
cosine similarity.
"""

import os
from datetime import datetime

import numpy as np
import pandas as pd
import requests
import streamlit as st
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

HF_KEY = os.getenv('DATASET_KEY')

# Timeout (seconds) applied to every HTTP request made by this module.
_REQUEST_TIMEOUT = 30
# Column-name fragments that mark a column as holding vectors rather than text.
_EMBEDDING_COLUMN_TERMS = ('embed', 'vector', 'encoding')
# Width of the placeholder embedding matrix used when encoding fails.
_FALLBACK_EMBEDDING_DIM = 384
# Result keys rendered explicitly (or deliberately hidden) by
# render_video_result, and therefore skipped in its generic metadata listing.
_HIDDEN_METADATA_KEYS = frozenset({
    'title', 'description', 'start_time', 'end_time', 'duration',
    'relevance_score', 'video_id', '_config', '_split',
})

# Initialize session state variables
if 'search_history' not in st.session_state:
    st.session_state['search_history'] = []
if 'search_columns' not in st.session_state:
    st.session_state['search_columns'] = []
if 'initial_search_done' not in st.session_state:
    st.session_state['initial_search_done'] = False
if 'hf_token' not in st.session_state:
    st.session_state['hf_token'] = HF_KEY


def _get_json(url, hf_token, error_label, params=None):
    """GET *url* with bearer auth and return the decoded JSON body.

    Returns None (after showing a Streamlit warning) when the request fails,
    the response status is not 200, or the body is not valid JSON.
    """
    headers = {"Authorization": f"Bearer {hf_token}"} if hf_token else {}
    try:
        response = requests.get(
            url, headers=headers, params=params, timeout=_REQUEST_TIMEOUT
        )
        if response.status_code == 200:
            return response.json()
        # Surface the failure instead of returning an unexplained empty value.
        st.warning(f"Error fetching {error_label}: HTTP {response.status_code}")
    except (requests.RequestException, ValueError) as e:
        # ValueError covers a 200 response whose body is not valid JSON.
        st.warning(f"Error fetching {error_label}: {e}")
    return None


def fetch_dataset_info_auth(dataset_id, hf_token):
    """Fetch dataset information with authentication.

    Returns the decoded JSON payload, or None on any failure.
    """
    info_url = f"https://huggingface.co/api/datasets/{dataset_id}"
    return _get_json(info_url, hf_token, "dataset info")


def fetch_dataset_splits_auth(dataset_id, hf_token):
    """Fetch available splits for the dataset.

    Returns the payload's 'splits' list, or an empty list on any failure.
    """
    payload = _get_json(
        "https://datasets-server.huggingface.co/splits",
        hf_token,
        "splits",
        params={"dataset": dataset_id},
    )
    if isinstance(payload, dict):
        return payload.get('splits', [])
    return []


def fetch_parquet_urls_auth(dataset_id, config, split, hf_token):
    """Fetch Parquet file URLs for a specific split.

    Returns the decoded JSON payload, or an empty list on any failure.
    """
    parquet_url = (
        f"https://huggingface.co/api/datasets/{dataset_id}/parquet/{config}/{split}"
    )
    payload = _get_json(parquet_url, hf_token, "parquet URLs")
    return payload if payload is not None else []


def fetch_rows_auth(dataset_id, config, split, offset, length, hf_token):
    """Fetch rows with authentication.

    Returns the decoded JSON payload, or None on any failure.
    """
    # Passing the query via ``params`` lets requests URL-encode the values.
    return _get_json(
        "https://datasets-server.huggingface.co/rows",
        hf_token,
        "rows",
        params={
            "dataset": dataset_id,
            "config": config,
            "split": split,
            "offset": offset,
            "length": length,
        },
    )


class ParquetVideoSearch:
    """Embedding-based text search over a sample of dataset rows.

    Attributes:
        text_model: SentenceTransformer used to embed rows and queries.
        dataset_id / config: dataset and configuration requested from the API.
        hf_token: bearer token sent with every request.
        dataset: DataFrame of the loaded rows (or the example fallback row).
        text_embeds: one embedding per row of ``dataset``.
        loaded_from_hub: True only when ``dataset`` came from the rows API.
    """

    def __init__(self, hf_token):
        self.text_model = SentenceTransformer('all-MiniLM-L6-v2')
        self.dataset_id = "tomg-group-umd/cinepile"
        self.config = "v2"
        self.hf_token = hf_token
        self.load_dataset()

    def load_dataset(self):
        """Load initial dataset sample, falling back to example data on failure."""
        self.loaded_from_hub = False
        try:
            rows_data = fetch_rows_auth(
                self.dataset_id, self.config, "train", 0, 100, self.hf_token
            )
            rows = rows_data.get('rows') if isinstance(rows_data, dict) else None
            # An empty 'rows' list counts as a failed load: an empty DataFrame
            # would leave nothing to embed or rank.
            if rows:
                self.dataset = pd.DataFrame(
                    [entry.get('row', entry) for entry in rows]
                )
                st.session_state['search_columns'] = [
                    col for col in self.dataset.columns
                    if not any(
                        term in col.lower() for term in _EMBEDDING_COLUMN_TERMS
                    )
                ]
                self.loaded_from_hub = True
            else:
                self.dataset = self.load_example_data()
        except Exception as e:  # best-effort: any failure falls back to the example
            st.warning(f"Error loading dataset: {e}")
            self.dataset = self.load_example_data()
        self.prepare_features()

    def load_example_data(self):
        """Load example data as fallback."""
        return pd.DataFrame([{
            "video_id": "example",
            "title": "Example Video",
            "description": "Example video content",
            "duration": 120,
            "start_time": 0,
            "end_time": 120,
        }])

    def prepare_features(self):
        """Prepare text features for search.

        Embeds the 'title' and/or 'description' columns when they exist;
        otherwise embeds every column whose name does not look like an
        embedding column.
        """
        try:
            text_fields = [
                col for col in ('title', 'description')
                if col in self.dataset.columns
            ]
            if not text_fields:
                # Neither preferred column exists: use all non-vector columns
                # rather than failing and losing the embeddings entirely.
                text_fields = [
                    col for col in self.dataset.columns
                    if not any(
                        term in str(col).lower()
                        for term in _EMBEDDING_COLUMN_TERMS
                    )
                ]
            # Cast to str so non-text cells (numbers, lists) can be joined.
            combined_text = (
                self.dataset[text_fields]
                .fillna('')
                .astype(str)
                .agg(' '.join, axis=1)
            )
            self.text_embeds = self.text_model.encode(combined_text.tolist())
        except Exception as e:  # best-effort: keep the app usable without embeddings
            st.warning(f"Error preparing features: {e}")
            # Zeros give every row the same score instead of a random ranking
            # that would look like genuine search results.
            self.text_embeds = np.zeros(
                (len(self.dataset), _FALLBACK_EMBEDDING_DIM)
            )

    def search(self, query, column=None, top_k=20):
        """Search using text embeddings and optional column filtering.

        Args:
            query: free-text search string.
            column: optional column name; rows whose value in that column does
                not contain ``query`` (case-insensitive, literal match) have
                their similarity halved.
            top_k: maximum number of results to return.

        Returns:
            A list of dicts, best match first, each holding 'relevance_score'
            plus the row's columns.
        """
        if top_k <= 0 or len(self.text_embeds) == 0:
            return []

        query_embedding = self.text_model.encode([query])[0]
        similarities = cosine_similarity([query_embedding], self.text_embeds)[0]

        # Column filtering
        if column and column in self.dataset.columns and column != "All Fields":
            # regex=False: the query is user input, not a pattern, so
            # characters such as "(" or "*" must match literally.
            mask = (
                self.dataset[column]
                .astype(str)
                .str.contains(query, case=False, regex=False)
                .to_numpy()
            )
            # NOTE: halving moves negative similarities toward zero as well.
            similarities[~mask] *= 0.5

        top_k = min(top_k, len(similarities))
        top_indices = np.argsort(similarities)[-top_k:][::-1]

        return [
            {
                'relevance_score': float(similarities[idx]),
                **self.dataset.iloc[idx].to_dict(),
            }
            for idx in top_indices
        ]


def _result_label(result, max_len=100):
    """Return a short display label for *result* (title, else description)."""
    label = result.get('title', result.get('description', 'No title'))
    # Values may be non-strings (e.g. NaN or numbers); cast before slicing.
    return str(label)[:max_len]


def render_video_result(result):
    """Render a video result with enhanced display.

    Args:
        result: dict produced by ``ParquetVideoSearch.search``; must contain
            'relevance_score'.
    """
    start_time = result.get('start_time', 0)
    end_time = result.get('end_time', result.get('duration', 0))

    col1, col2 = st.columns([2, 1])

    with col1:
        if 'title' in result:
            st.markdown(f"**Title:** {result['title']}")
        st.markdown("**Description:**")
        st.write(result.get('description', 'No description available'))

        # Show timing information
        st.markdown(f"**Time Range:** {start_time}s - {end_time}s")

        # Show additional metadata
        for key, value in result.items():
            if key not in _HIDDEN_METADATA_KEYS:
                st.markdown(f"**{key.replace('_', ' ').title()}:** {value}")

    with col2:
        st.markdown(f"**Relevance Score:** {result['relevance_score']:.2%}")

        # Display video if URL is available
        video_url = None
        if 'video_url' in result:
            video_url = result['video_url']
        elif 'youtube_id' in result:
            video_url = (
                f"https://youtube.com/watch?v={result['youtube_id']}&t={start_time}"
            )

        if video_url:
            st.video(video_url)


def _token_from_secrets():
    """Return the HF_TOKEN secret, or None when it cannot be read."""
    try:
        return st.secrets.get("HF_TOKEN", None)
    except Exception:
        # Reading st.secrets raises when no secrets file is configured; the
        # caller then falls back to asking the user for a token.
        return None


def _get_search_engine(hf_token):
    """Return the session's search engine, building it at most once per token."""
    engine = st.session_state.get('search_engine')
    if engine is None or engine.hf_token != hf_token:
        # Constructing the engine loads the model and fetches rows, so keep it
        # across Streamlit reruns instead of rebuilding on every interaction.
        engine = ParquetVideoSearch(hf_token)
        st.session_state['search_engine'] = engine
    elif not getattr(engine, 'loaded_from_hub', False):
        # The previous load fell back to example data; retry the fetch.
        engine.load_dataset()
    return engine


def _rerun():
    """Trigger a Streamlit rerun on both current and older Streamlit versions."""
    if hasattr(st, "rerun"):
        st.rerun()
    else:
        st.experimental_rerun()


def main():
    """Build the Streamlit page: token prompt, search tab, info tab, sidebar."""
    st.title("🎥 Video Dataset Search")

    # Get HF token from secrets or user input
    if not st.session_state['hf_token']:
        st.session_state['hf_token'] = _token_from_secrets()

    if not st.session_state['hf_token']:
        hf_token = st.text_input(
            "Enter your Hugging Face API token:", type="password"
        )
        if hf_token:
            st.session_state['hf_token'] = hf_token

    if not st.session_state.get('hf_token'):
        st.warning("Please provide a Hugging Face API token to access the dataset.")
        return

    # Initialize search class
    search = _get_search_engine(st.session_state['hf_token'])

    # Create tabs
    tab1, tab2 = st.tabs(["🔍 Video Search", "📊 Dataset Info"])

    # ---- Tab 1: Video Search ----
    with tab1:
        st.subheader("Search Videos")

        col1, col2 = st.columns([3, 1])
        with col1:
            query = st.text_input("Enter your search query:", value="")
        with col2:
            search_column = st.selectbox(
                "Search in field:",
                ["All Fields"] + st.session_state['search_columns'],
            )

        col3, col4 = st.columns(2)
        with col3:
            num_results = st.slider("Number of results:", 1, 100, 20)
        with col4:
            search_button = st.button("🔍 Search")

        if search_button and query:
            st.session_state['initial_search_done'] = True
            selected_column = (
                None if search_column == "All Fields" else search_column
            )

            with st.spinner("Searching..."):
                results = search.search(query, selected_column, num_results)

            st.session_state['search_history'].append({
                'query': query,
                'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                'results': results[:5],
            })

            for i, result in enumerate(results, 1):
                with st.expander(
                    f"Result {i}: {_result_label(result)}...",
                    expanded=(i == 1),
                ):
                    render_video_result(result)

    # ---- Tab 2: Dataset Info ----
    with tab2:
        st.subheader("Dataset Information")

        # Show available splits
        splits = fetch_dataset_splits_auth(
            search.dataset_id, st.session_state['hf_token']
        )
        if splits:
            st.write("### Available Splits")
            for split in splits:
                st.write(
                    f"- {split.get('split', 'unknown')}: "
                    f"{split.get('num_rows', 'unknown')} rows"
                )

        # Show dataset statistics
        st.write("### Dataset Statistics")
        st.write(f"- Loaded rows: {len(search.dataset)}")
        st.write(
            f"- Available columns: {', '.join(map(str, search.dataset.columns))}"
        )

        # Show sample data
        st.write("### Sample Data")
        st.dataframe(search.dataset.head())

    # Sidebar
    with st.sidebar:
        st.subheader("⚙️ Search History")
        if st.button("🗑️ Clear History"):
            st.session_state['search_history'] = []
            _rerun()

        st.markdown("### Recent Searches")
        for entry in reversed(st.session_state['search_history'][-5:]):
            with st.expander(f"{entry['timestamp']}: {entry['query']}"):
                for i, result in enumerate(entry['results'], 1):
                    st.write(f"{i}. {_result_label(result)}...")


if __name__ == "__main__":
    main()