import streamlit as st
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import os
from datetime import datetime
from datasets import load_dataset

# Initialize session state
if 'search_history' not in st.session_state:
    st.session_state['search_history'] = []
if 'search_columns' not in st.session_state:
    st.session_state['search_columns'] = []
if 'dataset_loaded' not in st.session_state:
    st.session_state['dataset_loaded'] = False
if 'current_page' not in st.session_state:
    st.session_state['current_page'] = 0
if 'data_cache' not in st.session_state:
    st.session_state['data_cache'] = None
if 'dataset_info' not in st.session_state:
    st.session_state['dataset_info'] = None

ROWS_PER_PAGE = 100 # Number of rows to load at a time
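
# 'all-MiniLM-L6-v2' (loaded below) is a small, fast sentence-transformers
# model that maps text to 384-dimensional embeddings; a reasonable default
# for quick semantic search over short text fields.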

@st.cache_resource
def get_model():
    """Cache the model loading"""
    return SentenceTransformer('all-MiniLM-L6-v2')

@st.cache_data
def load_dataset_page(dataset_id, token, page, rows_per_page):
    """Load and cache a specific page of data"""
    try:
        start_idx = page * rows_per_page
        end_idx = start_idx + rows_per_page
        dataset = load_dataset(
            dataset_id,
            token=token,
            streaming=False,
            split=f'train[{start_idx}:{end_idx}]'
        )
        return pd.DataFrame(dataset)
    except Exception as e:
        st.error(f"Error loading page {page}: {str(e)}")
        return pd.DataFrame()

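# For reference, the split string above uses the datasets library's slice
# syntax: e.g. split='train[0:100]' returns only the first 100 rows of the
# train split as a regular (non-streaming) Dataset.
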
@st.cache_data
def get_dataset_info(dataset_id, token):
    """Load and cache dataset information"""
    try:
        # streaming=True lets us read the split metadata without
        # materializing the full dataset locally
        dataset = load_dataset(
            dataset_id,
            token=token,
            streaming=True
        )
        return dataset['train'].info
    except Exception as e:
        st.error(f"Error loading dataset info: {str(e)}")
        return None

class FastDatasetSearcher:
    def __init__(self, dataset_id="tomg-group-umd/cinepile"):
        self.dataset_id = dataset_id
        self.text_model = get_model()
        self.token = os.environ.get('DATASET_KEY')
        if not self.token:
            st.error("Please set the DATASET_KEY environment variable with your Hugging Face token.")
            st.stop()
        # Load dataset info if not already loaded
        if st.session_state['dataset_info'] is None:
            st.session_state['dataset_info'] = get_dataset_info(self.dataset_id, self.token)

    def load_page(self, page=0):
        """Load a specific page of data using cached function"""
        return load_dataset_page(self.dataset_id, self.token, page, ROWS_PER_PAGE)

    def quick_search(self, query, df):
        """Fast search on current page"""
        if df.empty:
            return df

        scores = []
        query_embedding = self.text_model.encode([query], show_progress_bar=False)[0]
        for _, row in df.iterrows():
            # Combine all searchable fields into one text blob.
            # (row is a pandas Series, so iterate it directly; row.values is a
            # property, not a method, and row.values() raises TypeError.)
            text_values = [str(v) for v in row
                           if isinstance(v, (str, int, float, list, dict))]
            text = ' '.join(text_values)
            # Quick keyword match (add 1 to avoid division by zero)
            keyword_score = text.lower().count(query.lower()) / (len(text.split()) + 1)
            # Semantic similarity on the combined text
            text_embedding = self.text_model.encode([text], show_progress_bar=False)[0]
            semantic_score = cosine_similarity([query_embedding], [text_embedding])[0][0]
            # Equal-weight blend of keyword and semantic scores
            combined_score = 0.5 * semantic_score + 0.5 * keyword_score
            scores.append(combined_score)

        # Rank the page by combined score
        results_df = df.copy()
        results_df['score'] = scores
        return results_df.sort_values('score', ascending=False)
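
# Note: quick_search encodes one row at a time, which is slow for large pages.
# A minimal batched sketch with the same SentenceTransformer API (assuming the
# same text-blob construction as above) would look like:
#     texts = [' '.join(str(v) for v in row) for _, row in df.iterrows()]
#     embeddings = model.encode(texts, show_progress_bar=False)
#     semantic_scores = cosine_similarity([query_embedding], embeddings)[0]
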
def render_result(result):
    """Render a single search result"""
    # pandas Series.pop() takes no default value, so guard the lookup instead
    score = result.pop('score') if 'score' in result else 0

    # Display video if available
    if 'youtube_id' in result:
        st.video(
            f"https://youtube.com/watch?v={result['youtube_id']}&t={result.get('start_time', 0)}"
        )

    # Display other fields
    cols = st.columns([2, 1])
    with cols[0]:
        for key, value in result.items():
            if isinstance(value, (str, int, float)):
                st.write(f"**{key}:** {value}")
    with cols[1]:
        st.metric("Relevance Score", f"{score:.2%}")

def main():
    st.title("🎥 Fast Video Dataset Search")

    # Initialize search class
    searcher = FastDatasetSearcher()

    # Show dataset info
    if st.session_state['dataset_info']:
        st.sidebar.write("### Dataset Info")
        num_examples = st.session_state['dataset_info'].splits['train'].num_examples
        st.sidebar.write(f"Total examples: {num_examples:,}")
        # Last valid zero-based page index (number_input's max_value is inclusive)
        total_pages = max((num_examples - 1) // ROWS_PER_PAGE, 0)
        current_page = st.number_input("Page", min_value=0, max_value=total_pages,
                                       value=st.session_state['current_page'])
    else:
        current_page = st.number_input("Page", min_value=0,
                                       value=st.session_state['current_page'])

    # Load current page
    with st.spinner(f"Loading page {current_page}..."):
        df = searcher.load_page(current_page)

    if df.empty:
        st.warning("No data available for this page.")
        return

    # Search interface
    col1, col2 = st.columns([3, 1])
    with col1:
        query = st.text_input("Search in current page:",
                              help="Searches within currently loaded data")
    with col2:
        max_results = st.slider("Max results", 1, ROWS_PER_PAGE, 10)

    if query:
        with st.spinner("Searching..."):
            results = searcher.quick_search(query, df)

        # Display results
        st.write(f"Found {len(results)} results on this page:")
        for i, (_, result) in enumerate(results.head(max_results).iterrows(), 1):
            with st.expander(f"Result {i}", expanded=(i == 1)):
                render_result(result)

    # Show raw data
    with st.expander("Show Raw Data"):
        st.dataframe(df)

    # Navigation buttons
    cols = st.columns(2)
    with cols[0]:
        if st.button("⬅️ Previous Page") and current_page > 0:
            st.session_state['current_page'] = current_page - 1
            st.rerun()
    with cols[1]:
        if st.button("Next Page ➡️"):
            st.session_state['current_page'] = current_page + 1
            st.rerun()

if __name__ == "__main__":
    main()
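
# To run locally (sketch; requires a Hugging Face token with read access to
# the dataset):
#     export DATASET_KEY=hf_...   # token env var name this app reads
#     streamlit run app.py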