awacke1's picture
Update app.py
959152c verified
raw
history blame
10.1 kB
# Standard library
import os
import re
from datetime import datetime
from urllib.parse import quote

# Third-party
import numpy as np
import pandas as pd
import requests
import streamlit as st
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
# Seed per-session defaults exactly once; existing values are never overwritten.
_SESSION_DEFAULTS = {
    'search_history': [],
    'search_columns': [],
    'initial_search_done': False,
    'dataset': None,
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
class DatasetSearcher:
    """Load a Hugging Face dataset and search it with a blend of semantic
    (sentence-embedding cosine similarity) and keyword (substring count) scores.

    Requires the HF access token in the DATASET_KEY environment variable;
    the loaded dataset is cached in Streamlit session state across reruns.
    """

    def __init__(self, dataset_id="tomg-group-umd/cinepile"):
        """Initialize the embedding model, read the auth token, and load the dataset.

        Calls st.stop() (halting the script run) if DATASET_KEY is unset.
        """
        self.dataset_id = dataset_id
        self.text_model = SentenceTransformer('all-MiniLM-L6-v2')
        self.token = os.environ.get('DATASET_KEY')
        if not self.token:
            st.error("Please set the DATASET_KEY environment variable with your Hugging Face token.")
            st.stop()
        self.load_dataset()

    def load_dataset(self):
        """Load dataset using the datasets library.

        Populates self.dataset, self.df (first split as a DataFrame),
        self.columns / self.text_columns, and session-state search columns,
        then builds embeddings. On any failure, reports the error and stops.
        """
        try:
            if st.session_state['dataset'] is None:
                with st.spinner("Loading dataset..."):
                    st.session_state['dataset'] = load_dataset(
                        self.dataset_id,
                        token=self.token,
                        streaming=False
                    )
            self.dataset = st.session_state['dataset']

            # Convert first split to DataFrame for easier processing.
            first_split = next(iter(self.dataset.values()))
            self.df = pd.DataFrame(first_split)

            # Text columns are object-dtype columns that are not embeddings.
            self.columns = list(self.df.columns)
            self.text_columns = [col for col in self.columns
                                 if self.df[col].dtype == 'object'
                                 and not any(term in col.lower()
                                             for term in ['embed', 'vector', 'encoding'])]

            # Expose searchable columns to the UI.
            st.session_state['search_columns'] = self.text_columns

            # Prepare text embeddings for semantic search.
            self.prepare_features()
        except Exception as e:
            st.error(f"Error loading dataset: {str(e)}")
            st.error("Please check your authentication token and internet connection.")
            st.stop()

    def prepare_features(self):
        """Prepare text embeddings for semantic search.

        Concatenates all text columns per row and encodes them in batches.
        Falls back to random embeddings (dim 384 = MiniLM output size) so the
        app still renders if encoding fails.
        """
        try:
            # astype(str) guards against non-string objects in object-dtype
            # columns, which would make ' '.join raise TypeError.
            combined_text = (self.df[self.text_columns]
                             .fillna('')
                             .astype(str)
                             .agg(' '.join, axis=1))

            # Encode in batches to bound peak memory.
            batch_size = 32
            all_embeddings = []
            with st.spinner("Preparing search features..."):
                for i in range(0, len(combined_text), batch_size):
                    batch = combined_text[i:i + batch_size].tolist()
                    embeddings = self.text_model.encode(batch)
                    all_embeddings.append(embeddings)
            self.text_embeddings = np.vstack(all_embeddings)
        except Exception as e:
            st.warning(f"Error preparing features: {str(e)}")
            self.text_embeddings = np.random.randn(len(self.df), 384)

    def search(self, query, column=None, top_k=20):
        """Search the dataset using semantic and keyword matching.

        Args:
            query: free-text search string.
            column: restrict keyword matching to this column; None or
                "All Fields" searches every text column.
            top_k: maximum number of results to return.

        Returns:
            List of row dicts (descending relevance), each augmented with
            'relevance_score', 'semantic_score', and 'keyword_score'.
        """
        if self.df.empty:
            return []

        # Semantic similarity between the query and every row embedding.
        query_embedding = self.text_model.encode([query])[0]
        similarities = cosine_similarity([query_embedding], self.text_embeddings)[0]

        # Keyword match scores. Series.str.count interprets its pattern as a
        # REGEX, so the raw query must be escaped — otherwise user input like
        # "what?" or "(scene)" raises re.error or silently mis-matches.
        search_columns = [column] if column and column != "All Fields" else self.text_columns
        pattern = re.escape(query.lower())
        keyword_scores = np.zeros(len(self.df))
        for col in search_columns:
            if col in self.df.columns:
                matches = (self.df[col]
                           .fillna('')
                           .astype(str)  # .str on non-str objects yields NaN
                           .str.lower()
                           .str.count(pattern))
                # fillna(0) keeps a stray NaN from poisoning the whole score.
                keyword_scores += matches.fillna(0).to_numpy()

        # Blend: equal weight, keyword counts normalized to [0, 1]
        # (max(1, ...) avoids division by zero when nothing matched).
        combined_scores = 0.5 * similarities + 0.5 * (keyword_scores / max(1, keyword_scores.max()))

        # Top-k indices, best first.
        top_k = min(top_k, len(combined_scores))
        top_indices = np.argsort(combined_scores)[-top_k:][::-1]

        results = []
        for idx in top_indices:
            result = {
                'relevance_score': float(combined_scores[idx]),
                'semantic_score': float(similarities[idx]),
                'keyword_score': float(keyword_scores[idx]),
                **self.df.iloc[idx].to_dict()
            }
            results.append(result)
        return results

    def get_dataset_info(self):
        """Return summary info (splits, row counts, columns, embedding shape),
        or an empty dict if no dataset is loaded."""
        if not self.dataset:
            return {}
        info = {
            'splits': list(self.dataset.keys()),
            'total_rows': sum(split.num_rows for split in self.dataset.values()),
            'columns': self.columns,
            'text_columns': self.text_columns,
            'sample_rows': len(self.df),
            'embeddings_shape': self.text_embeddings.shape
        }
        return info
def render_video_result(result):
    """Render a video result with enhanced display: title/description and
    extra metadata on the left, search scores on the right, plus an embedded
    YouTube player when a youtube_id is present."""
    left, right = st.columns([2, 1])

    with left:
        if 'title' in result:
            st.markdown(f"**Title:** {result['title']}")
        if 'description' in result:
            st.markdown("**Description:**")
            st.write(result['description'])

        # Timing information, when both endpoints are present.
        if 'start_time' in result and 'end_time' in result:
            st.markdown(f"**Time Range:** {result['start_time']}s - {result['end_time']}s")

        # Remaining metadata, skipping fields already shown or internal.
        skip_keys = {'title', 'description', 'start_time', 'end_time', 'duration',
                     'relevance_score', 'semantic_score', 'keyword_score',
                     'video_id', 'youtube_id'}
        for key, value in result.items():
            if key in skip_keys:
                continue
            st.markdown(f"**{key.replace('_', ' ').title()}:** {value}")

    with right:
        st.markdown("**Search Scores:**")
        overall_col, semantic_col, keyword_col = st.columns(3)
        overall_col.metric("Overall", f"{result['relevance_score']:.2%}")
        semantic_col.metric("Semantic", f"{result['semantic_score']:.2%}")
        keyword_col.metric("Keyword", f"{result['keyword_score']:.0f} matches")

    # Embedded player, seeking to the clip's start time when known.
    if 'youtube_id' in result:
        start = result.get('start_time', 0)
        st.video(f"https://youtube.com/watch?v={result['youtube_id']}&t={start}")
def main():
    """Streamlit entry point: search tab, dataset-info tab, and history sidebar."""
    st.title("πŸŽ₯ Video Dataset Search")

    # Loads (or reuses cached) dataset; may st.stop() on auth/load failure.
    searcher = DatasetSearcher()

    tab1, tab2 = st.tabs(["πŸ” Search", "πŸ“Š Dataset Info"])
    with tab1:
        _render_search_tab(searcher)
    with tab2:
        _render_info_tab(searcher)
    _render_sidebar()


def _render_search_tab(searcher):
    """Render the query form and, after a search, the expandable result list."""
    st.subheader("Search Videos")
    col1, col2 = st.columns([3, 1])
    with col1:
        # The original conditional here returned "" on both branches — the
        # initial value is always empty, so the dead conditional is removed.
        query = st.text_input("Search query:", value="")
    with col2:
        search_column = st.selectbox("Search in field:",
                                     ["All Fields"] + st.session_state['search_columns'])

    col3, col4 = st.columns(2)
    with col3:
        num_results = st.slider("Number of results:", 1, 100, 20)
    with col4:
        search_button = st.button("πŸ” Search")

    if search_button and query:
        st.session_state['initial_search_done'] = True
        selected_column = None if search_column == "All Fields" else search_column
        with st.spinner("Searching..."):
            results = searcher.search(query, selected_column, num_results)

        # Keep only the top 5 results per entry to bound history size.
        st.session_state['search_history'].append({
            'query': query,
            'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            'results': results[:5]
        })

        for i, result in enumerate(results, 1):
            with st.expander(
                f"Result {i}: {result.get('title', result.get('description', 'No title'))[:100]}...",
                expanded=(i == 1)  # only the best match starts expanded
            ):
                render_video_result(result)


def _render_info_tab(searcher):
    """Render dataset summary statistics, a sample of rows, and column dtypes."""
    st.subheader("Dataset Information")
    info = searcher.get_dataset_info()
    if info:
        st.write(f"### Dataset: {searcher.dataset_id}")
        st.write(f"- Total rows: {info['total_rows']:,}")
        st.write(f"- Available splits: {', '.join(info['splits'])}")
        st.write(f"- Number of columns: {len(info['columns'])}")
        st.write(f"- Searchable text columns: {', '.join(info['text_columns'])}")

        st.write("### Sample Data")
        st.dataframe(searcher.df.head())

        st.write("### Column Details")
        for col in info['columns']:
            st.write(f"- **{col}**: {searcher.df[col].dtype}")


def _render_sidebar():
    """Render the clear-history control and the five most recent searches."""
    with st.sidebar:
        st.subheader("βš™οΈ Search History")
        if st.button("πŸ—‘οΈ Clear History"):
            st.session_state['search_history'] = []
            # NOTE(review): experimental_rerun is deprecated in newer Streamlit
            # releases (use st.rerun) — kept as-is to avoid a version dependency.
            st.experimental_rerun()

        st.markdown("### Recent Searches")
        for entry in reversed(st.session_state['search_history'][-5:]):
            with st.expander(f"{entry['timestamp']}: {entry['query']}"):
                for i, result in enumerate(entry['results'], 1):
                    st.write(f"{i}. {result.get('title', result.get('description', 'No title'))[:100]}...")
# Script entry point: run the Streamlit app when executed directly.
if __name__ == "__main__":
    main()