import streamlit as st
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import os
from datetime import datetime
from datasets import load_dataset
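
# Streamlit re-executes this script from the top on every user interaction,
# so anything that must survive a rerun (search history, the cached dataset,
# UI flags) is kept in st.session_state.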
if 'search_history' not in st.session_state:
    st.session_state['search_history'] = []
if 'search_columns' not in st.session_state:
    st.session_state['search_columns'] = []
if 'initial_search_done' not in st.session_state:
    st.session_state['initial_search_done'] = False
if 'dataset' not in st.session_state:
    st.session_state['dataset'] = None


class DatasetSearcher:
    """Loads a Hugging Face dataset and supports hybrid semantic/keyword search."""

    def __init__(self, dataset_id="tomg-group-umd/cinepile"):
        self.dataset_id = dataset_id
        self.text_model = SentenceTransformer('all-MiniLM-L6-v2')
        self.token = os.environ.get('DATASET_KEY')
        if not self.token:
            st.error("Please set the DATASET_KEY environment variable with your Hugging Face token.")
            st.stop()
        self.load_dataset()
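
    # Note: the SentenceTransformer model is re-instantiated on every rerun.
    # A common Streamlit pattern (a sketch, not wired in here) is to cache it
    # as a global resource:
    #
    #   @st.cache_resource
    #   def get_text_model():
    #       return SentenceTransformer('all-MiniLM-L6-v2')
    #
    # and call get_text_model() in __init__ instead.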

    def load_dataset(self):
        """Load the dataset with the datasets library, caching it in session state."""
        try:
            if st.session_state['dataset'] is None:
                with st.spinner("Loading dataset..."):
                    st.session_state['dataset'] = load_dataset(
                        self.dataset_id,
                        token=self.token,
                        streaming=False
                    )

            self.dataset = st.session_state['dataset']

            # Only the first split is turned into a DataFrame and searched.
            first_split = next(iter(self.dataset.values()))
            self.df = pd.DataFrame(first_split)

            self.columns = list(self.df.columns)

            # Treat every column as searchable except ones that look like raw
            # embedding/vector columns.
            self.text_columns = []
            for col in self.columns:
                if col.lower() not in ['embed', 'vector', 'encoding']:
                    sample_val = self.df[col].iloc[0] if not self.df.empty else None
                    if isinstance(sample_val, (str, int, float, list, dict)) or sample_val is None:
                        self.text_columns.append(col)

            st.session_state['search_columns'] = self.text_columns

            self.prepare_features()

        except Exception as e:
            st.error(f"Error loading dataset: {str(e)}")
            st.error("Please check your authentication token and internet connection.")
            st.stop()
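
    # Note: streaming=False materializes the entire dataset locally. For very
    # large datasets, one alternative (not used here) is streaming mode with a
    # bounded sample, e.g. load_dataset(..., streaming=True) and then take(n)
    # on the resulting iterable split.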

    def prepare_features(self):
        """Prepare text embeddings for semantic search."""
        try:
            # Concatenate all searchable fields of each row into one string.
            processed_texts = []
            for _, row in self.df.iterrows():
                row_texts = []
                for col in self.text_columns:
                    value = row[col]
                    # Lists, dicts, and numbers are stringified; None becomes ''.
                    row_texts.append('' if value is None else str(value))
                processed_texts.append(' '.join(row_texts))

            # Encode in batches to keep memory bounded.
            batch_size = 32
            all_embeddings = []

            with st.spinner("Preparing search features..."):
                for i in range(0, len(processed_texts), batch_size):
                    batch = processed_texts[i:i + batch_size]
                    embeddings = self.text_model.encode(batch)
                    all_embeddings.append(embeddings)

            self.text_embeddings = np.vstack(all_embeddings)

        except Exception as e:
            st.warning(f"Error preparing features: {str(e)}")
            # Fallback: random 384-dim vectors (the output size of
            # all-MiniLM-L6-v2) keep the app usable, but semantic scores will
            # be meaningless and ranking effectively falls back to keywords.
            self.text_embeddings = np.random.randn(len(self.df), 384)
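
    # Note: embeddings are recomputed each time a DatasetSearcher is built.
    # One way to avoid that across reruns (a sketch, not wired in) is to key
    # them in session state:
    #
    #   key = f"embeddings_{self.dataset_id}"
    #   if key in st.session_state:
    #       self.text_embeddings = st.session_state[key]
    #   else:
    #       ...  # compute as above, then st.session_state[key] = self.text_embeddings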

    def search(self, query, column=None, top_k=20):
        """Search the dataset using semantic and keyword matching."""
        if self.df.empty:
            return []

        # Semantic score: cosine similarity between the query embedding and
        # every row embedding.
        query_embedding = self.text_model.encode([query])[0]
        similarities = cosine_similarity([query_embedding], self.text_embeddings)[0]

        # Keyword score: count occurrences of the query string in the
        # selected column(s).
        search_columns = [column] if column and column != "All Fields" else self.text_columns
        keyword_scores = np.zeros(len(self.df))

        query_lower = query.lower()
        for col in search_columns:
            if col in self.df.columns:
                for idx, value in enumerate(self.df[col]):
                    text = '' if value is None else str(value).lower()
                    keyword_scores[idx] += text.count(query_lower)

        # Hybrid ranking: equal-weight blend of semantic similarity and
        # max-normalized keyword counts; max(1, ...) guards against division
        # by zero when there are no keyword hits at all.
        combined_scores = 0.5 * similarities + 0.5 * (keyword_scores / max(1, keyword_scores.max()))

        # Take the top_k highest combined scores, best first.
        top_k = min(top_k, len(combined_scores))
        top_indices = np.argsort(combined_scores)[-top_k:][::-1]

        results = []
        for idx in top_indices:
            result = {
                'relevance_score': float(combined_scores[idx]),
                'semantic_score': float(similarities[idx]),
                'keyword_score': float(keyword_scores[idx]),
                **self.df.iloc[idx].to_dict()
            }
            results.append(result)

        return results
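
    # Illustrative usage (assuming DATASET_KEY is set and the dataset loads):
    #
    #   searcher = DatasetSearcher()
    #   hits = searcher.search("car chase at night", top_k=5)
    #   for h in hits:
    #       print(f"{h['relevance_score']:.2f}", h.get('title', ''))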

    def get_dataset_info(self):
        """Get information about the dataset."""
        if not self.dataset:
            return {}

        info = {
            'splits': list(self.dataset.keys()),
            # total_rows counts every split; sample_rows and the embeddings
            # cover only the first split loaded into self.df.
            'total_rows': sum(split.num_rows for split in self.dataset.values()),
            'columns': self.columns,
            'text_columns': self.text_columns,
            'sample_rows': len(self.df),
            'embeddings_shape': self.text_embeddings.shape
        }

        return info


def render_video_result(result):
    """Render a video result with enhanced display."""
    col1, col2 = st.columns([2, 1])

    with col1:
        if 'title' in result:
            st.markdown(f"**Title:** {result['title']}")
        if 'description' in result:
            st.markdown("**Description:**")
            st.write(result['description'])

        if 'start_time' in result and 'end_time' in result:
            st.markdown(f"**Time Range:** {result['start_time']}s - {result['end_time']}s")

        # Show any remaining fields that are not scores or identifiers.
        for key, value in result.items():
            if key not in ['title', 'description', 'start_time', 'end_time', 'duration',
                           'relevance_score', 'semantic_score', 'keyword_score',
                           'video_id', 'youtube_id']:
                st.markdown(f"**{key.replace('_', ' ').title()}:** {value}")

    with col2:
        st.markdown("**Search Scores:**")
        cols = st.columns(3)
        cols[0].metric("Overall", f"{result['relevance_score']:.2%}")
        cols[1].metric("Semantic", f"{result['semantic_score']:.2%}")
        cols[2].metric("Keyword", f"{result['keyword_score']:.0f} matches")

        # st.video can play YouTube URLs directly; the clip's start time is
        # appended as a URL parameter.
        if 'youtube_id' in result:
            st.video(f"https://youtube.com/watch?v={result['youtube_id']}&t={result.get('start_time', 0)}")


def main():
    st.title("🎥 Video Dataset Search")

    searcher = DatasetSearcher()

    tab1, tab2 = st.tabs(["🔍 Search", "📊 Dataset Info"])

    with tab1:
        st.subheader("Search Videos")
        col1, col2 = st.columns([3, 1])

        with col1:
            query = st.text_input("Search query:")
        with col2:
            search_column = st.selectbox("Search in field:",
                                         ["All Fields"] + st.session_state['search_columns'])

        col3, col4 = st.columns(2)
        with col3:
            num_results = st.slider("Number of results:", 1, 100, 20)
        with col4:
            search_button = st.button("🔍 Search")

        if search_button and query:
            st.session_state['initial_search_done'] = True
            selected_column = None if search_column == "All Fields" else search_column

            with st.spinner("Searching..."):
                results = searcher.search(query, selected_column, num_results)

            # Only the top five results per query are kept, so session state
            # stays small.
            st.session_state['search_history'].append({
                'query': query,
                'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                'results': results[:5]
            })

            for i, result in enumerate(results, 1):
                with st.expander(
                    f"Result {i}: {result.get('title', result.get('description', 'No title'))[:100]}...",
                    expanded=(i == 1)
                ):
                    render_video_result(result)

    with tab2:
        st.subheader("Dataset Information")

        info = searcher.get_dataset_info()
        if info:
            st.write(f"### Dataset: {searcher.dataset_id}")
            st.write(f"- Total rows: {info['total_rows']:,}")
            st.write(f"- Available splits: {', '.join(info['splits'])}")
            st.write(f"- Number of columns: {len(info['columns'])}")
            st.write(f"- Searchable text columns: {', '.join(info['text_columns'])}")

            st.write("### Sample Data")
            st.dataframe(searcher.df.head())

            st.write("### Column Details")
            for col in info['columns']:
                st.write(f"- **{col}**: {searcher.df[col].dtype}")

    with st.sidebar:
        st.subheader("⚙️ Search History")
        if st.button("🗑️ Clear History"):
            st.session_state['search_history'] = []
            st.rerun()

        st.markdown("### Recent Searches")
        for entry in reversed(st.session_state['search_history'][-5:]):
            with st.expander(f"{entry['timestamp']}: {entry['query']}"):
                for i, result in enumerate(entry['results'], 1):
                    st.write(f"{i}. {result.get('title', result.get('description', 'No title'))[:100]}...")


if __name__ == "__main__":
    main()
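

# To launch the app (assuming this file is saved as app.py):
#   streamlit run app.py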