awacke1's picture
Update app.py
bdefc08 verified
raw
history blame
5.43 kB
import streamlit as st
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import os
from datetime import datetime
from datasets import load_dataset
# Seed per-session defaults on the first script run; keys already present in
# st.session_state (set by an earlier rerun) keep their current values.
for _state_key, _state_default in (
    ('search_history', []),
    ('search_columns', []),
    ('dataset_loaded', False),
    ('current_page', 0),
    ('data_cache', None),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
del _state_key, _state_default

# Number of dataset rows fetched per page.
ROWS_PER_PAGE = 100
@st.cache_resource
def get_model():
    """Return the sentence-embedding model used for semantic search.

    Decorated with ``st.cache_resource`` so the constructed model object is
    reused by later calls instead of being rebuilt.
    """
    model_name = 'all-MiniLM-L6-v2'
    model = SentenceTransformer(model_name)
    return model
@st.cache_data
def _fetch_dataset_info(dataset_id, token):
    """Return the ``info`` metadata of the ``train`` split of *dataset_id*.

    A module-level function with plain string arguments so Streamlit can hash
    them as the cache key (caching an instance method would hash ``self``, and
    a cache hit would skip any attribute assignments in the method body).
    Exceptions from ``load_dataset`` propagate to the caller.
    """
    dataset = load_dataset(dataset_id, token=token, streaming=True)
    return dataset['train'].info


class FastDatasetSearcher:
    """Load one page of a Hugging Face dataset at a time and rank its rows.

    Rows are scored against a query by an equal blend of keyword frequency and
    sentence-embedding cosine similarity (see ``quick_search``).
    """

    # Value types turned into searchable text; other values (lists, dicts, ...)
    # are skipped.  np.number covers numpy scalars such as np.int64, which are
    # not subclasses of int.
    SEARCHABLE_TYPES = (str, int, float, np.number)

    def __init__(self, dataset_id="tomg-group-umd/cinepile"):
        """Read the access token, load the embedding model and dataset metadata.

        Stops the Streamlit script (``st.stop``) when ``DATASET_KEY`` is unset.
        """
        self.dataset_id = dataset_id
        self.token = os.environ.get('DATASET_KEY')
        # Check the token first so the model is not loaded for a run that stops.
        if not self.token:
            st.error("Please set the DATASET_KEY environment variable with your Hugging Face token.")
            st.stop()
        self.text_model = get_model()
        self.dataset_info = None
        self.load_dataset_info()

    def load_dataset_info(self):
        """Load dataset metadata (no rows) into ``self.dataset_info``.

        Returns True on success. On failure, shows an error in the app, leaves
        ``self.dataset_info`` unchanged and returns False.
        """
        try:
            self.dataset_info = _fetch_dataset_info(self.dataset_id, self.token)
        except Exception as e:
            st.error(f"Error loading dataset: {str(e)}")
            return False
        return True

    def load_page(self, page=0):
        """Return page *page* (``ROWS_PER_PAGE`` rows of ``train``) as a DataFrame.

        The most recently loaded page is memoized in ``st.session_state``
        (``'data_cache'`` plus ``'data_cache_page'``), and ``'current_page'`` is
        set to *page* on success. Returns an empty DataFrame (after showing an
        error) if loading fails.
        """
        # The cache is keyed by its own page entry. main() bumps 'current_page'
        # before rerunning, so comparing against 'current_page' made Next/Previous
        # return the previously cached page's rows.
        cached = st.session_state.get('data_cache')
        if cached is not None and st.session_state.get('data_cache_page') == page:
            return cached
        start = page * ROWS_PER_PAGE
        stop = (page + 1) * ROWS_PER_PAGE
        try:
            dataset = load_dataset(
                self.dataset_id,
                token=self.token,
                streaming=False,
                split=f'train[{start}:{stop}]'
            )
            df = pd.DataFrame(dataset)
        except Exception as e:
            st.error(f"Error loading page {page}: {str(e)}")
            return pd.DataFrame()
        st.session_state['data_cache'] = df
        st.session_state['data_cache_page'] = page
        st.session_state['current_page'] = page
        return df

    @classmethod
    def _row_text(cls, row):
        """Join a row's scalar values (see SEARCHABLE_TYPES) into one string."""
        return ' '.join(str(v) for v in row.tolist() if isinstance(v, cls.SEARCHABLE_TYPES))

    def quick_search(self, query, df):
        """Score every row of *df* against *query* and sort best-first.

        Each row's score is ``0.5 * cosine(query, row) + 0.5 * keyword``, where
        keyword is the number of case-insensitive occurrences of *query* in the
        row text divided by its word count (0 for a row with no text).

        Returns a new DataFrame with the rows of *df* plus a ``score`` column,
        sorted by descending score. *df* itself is not modified: it is usually
        the page cached in session state, and a leftover ``score`` column would
        leak into later searches and the raw-data view.
        """
        if df.empty:
            return df.assign(score=np.array([], dtype=float))
        texts = [self._row_text(row) for _, row in df.iterrows()]
        needle = query.lower()
        keyword_scores = []
        for text in texts:
            words = text.split()
            keyword_scores.append(text.lower().count(needle) / len(words) if words else 0.0)
        # Encode all rows in one batch instead of one model call per row.
        query_embedding = self.text_model.encode([query], show_progress_bar=False)
        text_embeddings = self.text_model.encode(texts, show_progress_bar=False)
        semantic_scores = cosine_similarity(query_embedding, text_embeddings)[0]
        scores = 0.5 * semantic_scores + 0.5 * np.asarray(keyword_scores, dtype=float)
        return df.assign(score=scores).sort_values('score', ascending=False)
# Value types rendered as "**key:** value" lines in a result; np.number covers
# numpy scalars such as np.int64, which are not subclasses of int.
_DISPLAYABLE_TYPES = (str, int, float, np.number)


def _youtube_url(result):
    """Build a YouTube watch URL for a result row that has a ``youtube_id``.

    ``start_time`` (default 0) is truncated to whole seconds; a missing,
    non-numeric, NaN or infinite value falls back to 0 so the URL never
    contains ``t=nan``.
    """
    try:
        start = int(float(result.get('start_time', 0)))
    except (TypeError, ValueError, OverflowError):
        start = 0
    return f"https://youtube.com/watch?v={result['youtube_id']}&t={start}"


def _render_result(rank, result):
    """Render one ranked row (a Series carrying a 'score' entry) in an expander."""
    score = result.pop('score')
    with st.expander(f"Result {rank} (Score: {score:.2%})", expanded=rank == 1):
        # Display video if available
        if 'youtube_id' in result:
            st.video(_youtube_url(result))
        # Display other scalar fields
        for key, value in result.items():
            if isinstance(value, _DISPLAYABLE_TYPES):
                st.write(f"**{key}:** {value}")


def _render_navigation(page, has_next=True):
    """Draw Previous/Next buttons; a click shifts 'current_page' by one and reruns.

    Previous is disabled on page 0 and never moves below 0 (the page input has
    ``min_value=0``); Next is disabled when *has_next* is False.
    """
    prev_col, next_col = st.columns(2)
    with prev_col:
        if st.button("Previous Page", disabled=page <= 0) and page > 0:
            st.session_state['current_page'] = max(0, st.session_state['current_page'] - 1)
            st.rerun()
    with next_col:
        if st.button("Next Page", disabled=not has_next):
            st.session_state['current_page'] += 1
            st.rerun()


def main():
    """Streamlit entry point: pick a page, load it, search it, show results."""
    st.title("🎥 Fast Video Dataset Search")
    searcher = FastDatasetSearcher()
    page = st.number_input("Page", min_value=0, value=st.session_state['current_page'])
    with st.spinner(f"Loading page {page}..."):
        df = searcher.load_page(page)
    if df.empty:
        st.warning("No data available for this page.")
        # Keep navigation available so the user can step back from an empty page.
        _render_navigation(page, has_next=False)
        return
    query = st.text_input("Search in current page:", help="Searches within currently loaded data")
    if query:
        with st.spinner("Searching..."):
            results = searcher.quick_search(query, df)
        st.write(f"Found {len(results)} results on this page:")
        for rank, (_, result) in enumerate(results.iterrows(), 1):
            _render_result(rank, result)
    st.subheader("Raw Data")
    st.dataframe(df)
    _render_navigation(page)


if __name__ == "__main__":
    main()