import streamlit as st
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import os
import re
from datetime import datetime
from datasets import load_dataset
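# NOTE: the sentence-transformers model used below ('all-MiniLM-L6-v2') is
# downloaded from the Hugging Face Hub on first run and cached locally.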
# Initialize session state
if 'search_history' not in st.session_state:
    st.session_state['search_history'] = []
if 'search_columns' not in st.session_state:
    st.session_state['search_columns'] = []
if 'initial_search_done' not in st.session_state:
    st.session_state['initial_search_done'] = False
if 'dataset' not in st.session_state:
    st.session_state['dataset'] = None
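# Streamlit re-executes this script from the top on every user interaction;
# st.session_state is what lets the loaded dataset and the search history
# survive those reruns instead of being rebuilt each time.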
class DatasetSearcher:
    def __init__(self, dataset_id="tomg-group-umd/cinepile"):
        self.dataset_id = dataset_id
        self.text_model = SentenceTransformer('all-MiniLM-L6-v2')
        self.token = os.environ.get('DATASET_KEY')
        if not self.token:
            st.error("Please set the DATASET_KEY environment variable with your Hugging Face token.")
            st.stop()
        self.load_dataset()

    def load_dataset(self):
        """Load the dataset using the datasets library."""
        try:
            if st.session_state['dataset'] is None:
                with st.spinner("Loading dataset..."):
                    st.session_state['dataset'] = load_dataset(
                        self.dataset_id,
                        token=self.token,
                        streaming=False
                    )
            self.dataset = st.session_state['dataset']
            # Convert the first split to a DataFrame for easier processing
            first_split = next(iter(self.dataset.values()))
            self.df = first_split.to_pandas()
            # Store column information
            self.columns = list(self.df.columns)
            self.text_columns = [col for col in self.columns
                                 if self.df[col].dtype == 'object'
                                 and not any(term in col.lower()
                                             for term in ['embed', 'vector', 'encoding'])]
            # Update session state columns
            st.session_state['search_columns'] = self.text_columns
            # Prepare text embeddings
            self.prepare_features()
        except Exception as e:
            st.error(f"Error loading dataset: {str(e)}")
            st.error("Please check your authentication token and internet connection.")
            st.stop()
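    # Column-selection heuristic: pandas reports string columns as dtype
    # 'object', so the filter above treats every object column as searchable
    # text unless its name suggests a precomputed embedding/vector field.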
    def prepare_features(self):
        """Prepare text embeddings for semantic search."""
        try:
            # Combine text columns into one string per row for embedding
            combined_text = self.df[self.text_columns].fillna('').astype(str).agg(' '.join, axis=1)
            # Create embeddings in batches to manage memory
            batch_size = 32
            all_embeddings = []
            with st.spinner("Preparing search features..."):
                for i in range(0, len(combined_text), batch_size):
                    batch = combined_text[i:i + batch_size].tolist()
                    embeddings = self.text_model.encode(batch)
                    all_embeddings.append(embeddings)
            self.text_embeddings = np.vstack(all_embeddings)
        except Exception as e:
            st.warning(f"Error preparing features: {str(e)}")
            # Fall back to random 384-dim vectors (the output size of
            # all-MiniLM-L6-v2) so the rest of the app can still run;
            # semantic scores are meaningless in this case.
            self.text_embeddings = np.random.randn(len(self.df), 384)
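    # Batching note: the corpus is encoded in 32-row chunks so the whole
    # text never has to be embedded in one call; np.vstack then stacks the
    # per-batch (batch, 384) arrays into a single (n_rows, 384) matrix.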
    def search(self, query, column=None, top_k=20):
        """Search the dataset using a blend of semantic and keyword matching."""
        if self.df.empty:
            return []
        # Semantic similarity: embed the query and compare against row embeddings
        query_embedding = self.text_model.encode([query])[0]
        similarities = cosine_similarity([query_embedding], self.text_embeddings)[0]
        # Keyword match scores: count query occurrences per searchable column
        search_columns = [column] if column and column != "All Fields" else self.text_columns
        keyword_scores = np.zeros(len(self.df))
        pattern = re.escape(query.lower())  # str.count treats its pattern as a regex
        for col in search_columns:
            if col in self.df.columns:
                matches = self.df[col].fillna('').astype(str).str.lower().str.count(pattern)
                keyword_scores += matches
        # Combine scores: equal weight to semantic and (normalized) keyword signals
        combined_scores = 0.5 * similarities + 0.5 * (keyword_scores / max(1, keyword_scores.max()))
        # Take the top-k indices in descending score order
        top_k = min(top_k, len(combined_scores))
        top_indices = np.argsort(combined_scores)[-top_k:][::-1]
        # Format results
        results = []
        for idx in top_indices:
            result = {
                'relevance_score': float(combined_scores[idx]),
                'semantic_score': float(similarities[idx]),
                'keyword_score': float(keyword_scores[idx]),
                **self.df.iloc[idx].to_dict()
            }
            results.append(result)
        return results
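    # Scoring sketch: for row i,
    #   score_i = 0.5 * cosine(q, e_i) + 0.5 * kw_i / max(1, max_j kw_j)
    # where e_i is the row's embedding and kw_i its raw keyword-hit count.
    # The 50/50 weighting is this app's choice, not a tuned constant.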
    def get_dataset_info(self):
        """Summarize the loaded dataset."""
        if not self.dataset:
            return {}
        info = {
            'splits': list(self.dataset.keys()),
            'total_rows': sum(split.num_rows for split in self.dataset.values()),
            'columns': self.columns,
            'text_columns': self.text_columns,
            'sample_rows': len(self.df),
            'embeddings_shape': self.text_embeddings.shape
        }
        return info
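    # num_rows is a property of each datasets.Dataset split, so total_rows
    # covers the whole DatasetDict even though search only uses the first split.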
def render_video_result(result):
    """Render a single video result with scores and metadata."""
    col1, col2 = st.columns([2, 1])
    with col1:
        if 'title' in result:
            st.markdown(f"**Title:** {result['title']}")
        if 'description' in result:
            st.markdown("**Description:**")
            st.write(result['description'])
        # Show timing information if available
        if 'start_time' in result and 'end_time' in result:
            st.markdown(f"**Time Range:** {result['start_time']}s - {result['end_time']}s")
        # Show any remaining metadata fields
        for key, value in result.items():
            if key not in ['title', 'description', 'start_time', 'end_time', 'duration',
                           'relevance_score', 'semantic_score', 'keyword_score',
                           'video_id', 'youtube_id']:
                st.markdown(f"**{key.replace('_', ' ').title()}:** {value}")
    with col2:
        # Show search scores
        st.markdown("**Search Scores:**")
        cols = st.columns(3)
        cols[0].metric("Overall", f"{result['relevance_score']:.2%}")
        cols[1].metric("Semantic", f"{result['semantic_score']:.2%}")
        cols[2].metric("Keyword", f"{result['keyword_score']:.0f} matches")
    # Display the video if a YouTube ID is available
    if 'youtube_id' in result:
        st.video(f"https://youtube.com/watch?v={result['youtube_id']}&t={result.get('start_time', 0)}")
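# st.video understands YouTube watch URLs directly; the &t= query parameter
# is a hint to start playback at the clip's start_time (player support for it
# in embedded contexts can vary).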
def main():
    st.title("🎥 Video Dataset Search")
    # Initialize the search helper (loads the dataset on first run)
    searcher = DatasetSearcher()
    # Create tabs
    tab1, tab2 = st.tabs(["🔍 Search", "📊 Dataset Info"])

    # ---- Tab 1: Search ----
    with tab1:
        st.subheader("Search Videos")
        col1, col2 = st.columns([3, 1])
        with col1:
            query = st.text_input("Search query:")
        with col2:
            search_column = st.selectbox("Search in field:",
                                         ["All Fields"] + st.session_state['search_columns'])
        col3, col4 = st.columns(2)
        with col3:
            num_results = st.slider("Number of results:", 1, 100, 20)
        with col4:
            search_button = st.button("🔍 Search")
        if search_button and query:
            st.session_state['initial_search_done'] = True
            selected_column = None if search_column == "All Fields" else search_column
            with st.spinner("Searching..."):
                results = searcher.search(query, selected_column, num_results)
            st.session_state['search_history'].append({
                'query': query,
                'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                'results': results[:5]
            })
            for i, result in enumerate(results, 1):
                with st.expander(
                    f"Result {i}: {result.get('title', result.get('description', 'No title'))[:100]}...",
                    expanded=(i == 1)
                ):
                    render_video_result(result)

    # ---- Tab 2: Dataset Info ----
    with tab2:
        st.subheader("Dataset Information")
        info = searcher.get_dataset_info()
        if info:
            st.write(f"### Dataset: {searcher.dataset_id}")
            st.write(f"- Total rows: {info['total_rows']:,}")
            st.write(f"- Available splits: {', '.join(info['splits'])}")
            st.write(f"- Number of columns: {len(info['columns'])}")
            st.write(f"- Searchable text columns: {', '.join(info['text_columns'])}")
            st.write("### Sample Data")
            st.dataframe(searcher.df.head())
            st.write("### Column Details")
            for col in info['columns']:
                st.write(f"- **{col}**: {searcher.df[col].dtype}")
    # Sidebar
    with st.sidebar:
        st.subheader("⏱️ Search History")
        if st.button("🗑️ Clear History"):
            st.session_state['search_history'] = []
            st.rerun()  # st.experimental_rerun() was removed in recent Streamlit releases

        st.markdown("### Recent Searches")
        for entry in reversed(st.session_state['search_history'][-5:]):
            with st.expander(f"{entry['timestamp']}: {entry['query']}"):
                for i, result in enumerate(entry['results'], 1):
                    st.write(f"{i}. {result.get('title', result.get('description', 'No title'))[:100]}...")

if __name__ == "__main__":
    main()
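# Running locally (a sketch; exact package versions are not pinned by this app,
# and app.py is the conventional entry-point name on Hugging Face Spaces):
#   pip install streamlit pandas numpy sentence-transformers scikit-learn datasets
#   DATASET_KEY=<your_hf_token> streamlit run app.py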