|
import streamlit as st |
|
import pandas as pd |
|
import numpy as np |
|
from sentence_transformers import SentenceTransformer |
|
from sklearn.metrics.pairwise import cosine_similarity |
|
import os |
|
from datetime import datetime |
|
from datasets import load_dataset |
|
|
|
|
|
# One-time per-session initialization: seed every key this app reads from
# st.session_state so later accesses never raise KeyError on a fresh session.
_SESSION_DEFAULTS = {
    'search_history': [],   # past queries (reserved for future UI use)
    'search_columns': [],   # searchable column names (reserved)
    'dataset_loaded': False,  # whether dataset metadata has been fetched
    'current_page': 0,      # zero-based page index currently displayed
    'data_cache': None,     # DataFrame for the current page, or None
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default

# Number of dataset rows fetched and displayed per page.
ROWS_PER_PAGE = 100
|
|
|
@st.cache_resource
def get_model():
    """Return the shared sentence-embedding model.

    Cached with st.cache_resource so the model is loaded once per server
    process and reused across sessions and reruns.
    """
    model = SentenceTransformer('all-MiniLM-L6-v2')
    return model
|
|
|
class FastDatasetSearcher:
    """Paged loader and hybrid keyword/semantic search over a HF dataset.

    Pages of ROWS_PER_PAGE rows are loaded on demand and cached in
    st.session_state; search runs only against the currently loaded page.
    """

    def __init__(self, dataset_id="tomg-group-umd/cinepile"):
        """Set up the searcher.

        Args:
            dataset_id: Hugging Face dataset repository id.

        Side effects: reads the DATASET_KEY env var (Hugging Face token)
        and halts the Streamlit script with an error if it is missing.
        """
        self.dataset_id = dataset_id
        self.text_model = get_model()
        self.token = os.environ.get('DATASET_KEY')
        if not self.token:
            st.error("Please set the DATASET_KEY environment variable with your Hugging Face token.")
            st.stop()
        self.load_dataset_info()

    def load_dataset_info(self):
        """Load dataset metadata only.

        Returns True on success, False on failure (error shown in the UI).

        NOTE: the original @st.cache_data decorator was removed — caching an
        instance method makes Streamlit hash `self`, which raises
        UnhashableParamError for this class. The call is cheap (streaming
        metadata only) and runs once per constructed searcher.
        """
        try:
            dataset = load_dataset(
                self.dataset_id,
                token=self.token,
                streaming=True,
            )
            self.dataset_info = dataset['train'].info
            return True
        except Exception as e:
            st.error(f"Error loading dataset: {str(e)}")
            return False

    def load_page(self, page=0):
        """Load a specific page of data.

        Returns the cached DataFrame when `page` is already loaded;
        otherwise fetches rows [page*ROWS_PER_PAGE, (page+1)*ROWS_PER_PAGE)
        and updates the session cache. Returns an empty DataFrame on error.
        """
        if st.session_state['data_cache'] is not None and st.session_state['current_page'] == page:
            return st.session_state['data_cache']

        try:
            dataset = load_dataset(
                self.dataset_id,
                token=self.token,
                streaming=False,
                # Slice syntax fetches only this page's rows from the hub.
                split=f'train[{page*ROWS_PER_PAGE}:{(page+1)*ROWS_PER_PAGE}]'
            )
            df = pd.DataFrame(dataset)
            st.session_state['data_cache'] = df
            st.session_state['current_page'] = page
            return df
        except Exception as e:
            st.error(f"Error loading page {page}: {str(e)}")
            return pd.DataFrame()

    def quick_search(self, query, df):
        """Score every row of `df` against `query` and return it sorted.

        Each row's score is the mean of a keyword score (relative frequency
        of the query substring in the row text) and a semantic score (cosine
        similarity between sentence embeddings). Adds a 'score' column to
        `df` in place and returns a descending-sorted copy.

        Fixes vs. the original implementation:
        - `row.values()` called the pandas Series `.values` property as a
          function, raising TypeError on the first row; now uses `row.values`.
        - Guarded the keyword score against rows with no text (division by
          zero) and against zero-norm embeddings.
        - All row texts are embedded in ONE batched encode call instead of
          one model call per row, and cosine similarity is computed with
          numpy in a single vectorized expression.
        """
        if df.empty:
            df['score'] = []
            return df

        query_vec = np.asarray(
            self.text_model.encode([query], show_progress_bar=False)[0], dtype=float
        )

        # Flatten each row's text-like values into one searchable string.
        texts = [
            ' '.join(str(v) for v in row.values if isinstance(v, (str, int, float)))
            for _, row in df.iterrows()
        ]

        # Keyword score: query occurrences normalized by word count (0 if empty).
        query_lower = query.lower()
        keyword_scores = np.array([
            (text.lower().count(query_lower) / n_words) if (n_words := len(text.split())) else 0.0
            for text in texts
        ])

        # Semantic score: batched embedding + vectorized cosine similarity.
        embeddings = np.asarray(
            self.text_model.encode(texts, show_progress_bar=False), dtype=float
        )
        query_norm = np.linalg.norm(query_vec)
        row_norms = np.linalg.norm(embeddings, axis=1)
        denom = row_norms * query_norm
        safe_denom = np.where(denom > 0, denom, 1.0)  # avoid 0/0 → treat as 0 similarity
        semantic_scores = np.where(denom > 0, embeddings @ query_vec / safe_denom, 0.0)

        df['score'] = 0.5 * semantic_scores + 0.5 * keyword_scores
        return df.sort_values('score', ascending=False)
|
|
|
def main():
    """Streamlit entry point: paged dataset browser with per-page search.

    Flow: pick a page, load it (cached in session state), optionally run
    quick_search over the loaded rows, render results and raw data, and
    offer Previous/Next paging buttons.
    """
    # Fix: the title's leading emoji was mojibake ("π₯") from an encoding
    # mix-up; restored to the intended movie-camera emoji.
    st.title("🎥 Fast Video Dataset Search")

    # Construction validates DATASET_KEY and loads dataset metadata;
    # st.stop() inside halts the script if the token is missing.
    searcher = FastDatasetSearcher()

    # NOTE(review): the widget default tracks session state; the paging
    # buttons below mutate session state and rerun, which re-defaults this
    # widget to the new page.
    page = st.number_input("Page", min_value=0, value=st.session_state['current_page'])

    with st.spinner(f"Loading page {page}..."):
        df = searcher.load_page(page)

    if df.empty:
        st.warning("No data available for this page.")
        return

    query = st.text_input("Search in current page:", help="Searches within currently loaded data")

    if query:
        with st.spinner("Searching..."):
            results = searcher.quick_search(query, df)

        st.write(f"Found {len(results)} results on this page:")
        for i, (_, result) in enumerate(results.iterrows(), 1):
            # Remove the synthetic score from the row before display.
            score = result.pop('score')
            with st.expander(f"Result {i} (Score: {score:.2%})", expanded=i==1):
                # Embed the video when the row identifies a YouTube clip,
                # seeking to start_time when available.
                if 'youtube_id' in result:
                    st.video(
                        f"https://youtube.com/watch?v={result['youtube_id']}&t={result.get('start_time', 0)}"
                    )

                # Show remaining scalar fields as labeled key/value lines.
                for key, value in result.items():
                    if isinstance(value, (str, int, float)):
                        st.write(f"**{key}:** {value}")

    st.subheader("Raw Data")
    st.dataframe(df)

    # Paging controls: mutate session state and rerun so the number_input
    # above picks up the new page as its default.
    cols = st.columns(2)
    with cols[0]:
        if st.button("Previous Page") and page > 0:
            st.session_state['current_page'] -= 1
            st.rerun()
    with cols[1]:
        # NOTE(review): no upper bound is enforced here; paging past the end
        # surfaces as the empty-page warning above.
        if st.button("Next Page"):
            st.session_state['current_page'] += 1
            st.rerun()
|
|
|
# Script entry point; Streamlit re-executes this file top-to-bottom on each rerun.
if __name__ == "__main__":
    main()