|
import streamlit as st |
|
import pandas as pd |
|
import numpy as np |
|
from sentence_transformers import SentenceTransformer |
|
from sklearn.metrics.pairwise import cosine_similarity |
|
import os |
|
from datetime import datetime |
|
from datasets import load_dataset |
|
|
|
|
|
# Seed st.session_state with a default for every key the app reads. Keys that
# already hold a value are left untouched.
for _state_key, _state_default in (
    ('search_history', []),
    ('search_columns', []),
    ('dataset_loaded', False),
    ('current_page', 0),
    ('data_cache', None),
    ('dataset_info', None),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
del _state_key, _state_default

# Number of dataset rows fetched and shown per page.
ROWS_PER_PAGE = 100
|
|
|
@st.cache_resource
def get_model():
    """Build the ``all-MiniLM-L6-v2`` sentence encoder.

    Wrapped in ``st.cache_resource`` so the constructed model is reused
    instead of being rebuilt on each call.
    """
    encoder = SentenceTransformer('all-MiniLM-L6-v2')
    return encoder
|
|
|
@st.cache_data
def load_dataset_page(dataset_id, token, page, rows_per_page):
    """Fetch one page of the ``train`` split as a DataFrame.

    Args:
        dataset_id: Dataset identifier passed to ``load_dataset``.
        token: Access token passed to ``load_dataset``.
        page: Zero-based page index.
        rows_per_page: Number of rows in each page.

    Returns:
        A DataFrame built from rows ``[page * rows_per_page,
        (page + 1) * rows_per_page)`` of the ``train`` split, or an empty
        DataFrame if anything goes wrong (the error is shown via
        ``st.error``).

    NOTE(review): the function is decorated with ``st.cache_data`` and the
    failure path returns normally, so an empty frame produced by a failure
    is presumably cached like any other result -- confirm this is intended.
    """
    try:
        first_row = page * rows_per_page
        last_row = first_row + rows_per_page
        split_slice = f'train[{first_row}:{last_row}]'
        page_rows = load_dataset(
            dataset_id,
            token=token,
            streaming=False,
            split=split_slice,
        )
        frame = pd.DataFrame(page_rows)
    except Exception as e:
        st.error(f"Error loading page {page}: {str(e)}")
        frame = pd.DataFrame()
    return frame
|
|
|
@st.cache_data
def get_dataset_info(dataset_id, token):
    """Return the ``info`` object of the dataset's ``train`` split.

    The dataset is opened with ``streaming=True``. On any exception the
    error is shown via ``st.error`` and ``None`` is returned.

    Args:
        dataset_id: Dataset identifier passed to ``load_dataset``.
        token: Access token passed to ``load_dataset``.
    """
    train_info = None
    try:
        streamed = load_dataset(dataset_id, token=token, streaming=True)
        train_info = streamed['train'].info
    except Exception as e:
        st.error(f"Error loading dataset info: {str(e)}")
    return train_info
|
|
|
class FastDatasetSearcher:
    """Load a dataset page by page and rank the rows of a page against a query.

    Attributes are set in ``__init__``:
        dataset_id: Dataset identifier handed to the loading helpers.
        text_model: Sentence encoder returned by ``get_model()``.
        token: Value of the ``DATASET_KEY`` environment variable.
    """

    def __init__(self, dataset_id="tomg-group-umd/cinepile"):
        """Set up the encoder and token, and populate the cached dataset info.

        If ``DATASET_KEY`` is unset or empty, an error is shown and the
        Streamlit script is stopped via ``st.stop()``.
        """
        self.dataset_id = dataset_id
        self.text_model = get_model()
        self.token = os.environ.get('DATASET_KEY')
        if not self.token:
            st.error("Please set the DATASET_KEY environment variable with your Hugging Face token.")
            st.stop()

        # Only fetch dataset info when the session does not hold it yet.
        if st.session_state['dataset_info'] is None:
            st.session_state['dataset_info'] = get_dataset_info(self.dataset_id, self.token)

    def load_page(self, page=0):
        """Return page ``page`` (zero-based) as a DataFrame via ``load_dataset_page``."""
        return load_dataset_page(self.dataset_id, self.token, page, ROWS_PER_PAGE)

    @staticmethod
    def _row_to_text(row):
        """Join the str/int/float/list/dict cells of ``row`` into one string."""
        # Series.values is an ndarray attribute, not a method, so calling it
        # raises TypeError. tolist() yields the cells and unboxes NumPy
        # scalars to Python ones, so the isinstance filter matches them.
        return ' '.join(
            str(cell)
            for cell in row.tolist()
            if isinstance(cell, (str, int, float, list, dict))
        )

    def quick_search(self, query, df):
        """Score every row of ``df`` against ``query`` and sort by score.

        The score is the mean of two parts:
          * keyword score: case-insensitive occurrences of ``query`` in the
            row text divided by (word count + 1);
          * semantic score: cosine similarity between the encoded query and
            the encoded row text.

        Args:
            query: Search string.
            df: DataFrame holding the rows to rank.

        Returns:
            ``df`` itself when it is empty; otherwise a copy with an added
            ``score`` column, sorted by that column in descending order.
        """
        if df.empty:
            return df

        needle = query.lower()  # loop-invariant, computed once
        texts = [self._row_to_text(row) for _, row in df.iterrows()]
        keyword_scores = [
            text.lower().count(needle) / (len(text.split()) + 1)
            for text in texts
        ]

        query_embedding = self.text_model.encode([query], show_progress_bar=False)[0]
        # Encode all row texts in one call instead of one call per row.
        text_embeddings = self.text_model.encode(texts, show_progress_bar=False)
        semantic_scores = cosine_similarity([query_embedding], text_embeddings)[0]

        results_df = df.copy()
        results_df['score'] = [
            0.5 * semantic + 0.5 * keyword
            for semantic, keyword in zip(semantic_scores, keyword_scores)
        ]
        return results_df.sort_values('score', ascending=False)
|
|
|
def render_result(result):
    """Render a single search result.

    Shows an embedded video when the result has a ``youtube_id``, then lists
    every str/int/float field next to a relevance-score metric.

    Args:
        result: A dict or pandas Series of result fields. Its ``score``
            entry, if present, is removed and shown as the relevance score
            (0 when absent).
    """
    # pandas Series.pop() accepts only the key (no default), so the
    # dict-style pop('score', 0) raises TypeError for the Series rows that
    # main() passes in. Membership test + single-argument pop works for
    # both dicts and Series.
    score = result.pop('score') if 'score' in result else 0

    if 'youtube_id' in result:
        st.video(
            f"https://youtube.com/watch?v={result['youtube_id']}&t={result.get('start_time', 0)}"
        )

    cols = st.columns([2, 1])
    with cols[0]:
        for key, value in result.items():
            # Only scalar text/number fields are listed.
            if isinstance(value, (str, int, float)):
                st.write(f"**{key}:** {value}")

    with cols[1]:
        st.metric("Relevance Score", f"{score:.2%}")
|
|
|
def main():
    """Run the app: page selector, in-page search, raw data view, navigation."""
    st.title("🎥 Fast Video Dataset Search")

    searcher = FastDatasetSearcher()

    info = st.session_state['dataset_info']
    # Split metadata may be missing; fall back to an unbounded page selector
    # instead of failing on the lookup.
    splits = info.splits if info else None
    train_split = splits.get('train') if splits else None

    last_page = None  # highest valid zero-based page index, when known
    if train_split is not None:
        num_examples = train_split.num_examples
        st.sidebar.write("### Dataset Info")
        st.sidebar.write(f"Total examples: {num_examples:,}")

        # (n - 1) // size is the index of the last non-empty page. Plain
        # n // size points one page past the data when n is an exact
        # multiple of the page size.
        last_page = max((num_examples - 1) // ROWS_PER_PAGE, 0)
        # st.number_input rejects a value above max_value, so clamp the
        # stored page before handing it over.
        start_page = min(st.session_state['current_page'], last_page)
        current_page = st.number_input("Page", min_value=0, max_value=last_page, value=start_page)
    else:
        current_page = st.number_input("Page", min_value=0, value=st.session_state['current_page'])

    with st.spinner(f"Loading page {current_page}..."):
        df = searcher.load_page(current_page)

    if df.empty:
        st.warning("No data available for this page.")
        return

    col1, col2 = st.columns([3, 1])
    with col1:
        query = st.text_input("Search in current page:",
                              help="Searches within currently loaded data")
    with col2:
        max_results = st.slider("Max results", 1, ROWS_PER_PAGE, 10)

    if query:
        with st.spinner("Searching..."):
            results = searcher.quick_search(query, df)

        st.write(f"Found {len(results)} results on this page:")
        for i, (_, result) in enumerate(results.head(max_results).iterrows(), 1):
            # Only the top hit starts expanded.
            with st.expander(f"Result {i}", expanded=i == 1):
                render_result(result)

    with st.expander("Show Raw Data"):
        st.dataframe(df)

    cols = st.columns(2)
    with cols[0]:
        if st.button("⬅️ Previous Page") and current_page > 0:
            st.session_state['current_page'] = current_page - 1
            st.rerun()
    with cols[1]:
        # Do not step past the last known page; unbounded when it is unknown.
        can_advance = last_page is None or current_page < last_page
        if st.button("Next Page ➡️") and can_advance:
            st.session_state['current_page'] = current_page + 1
            st.rerun()
|
|
|
# Start the app only when this file is executed as the main module.
if __name__ == "__main__":
    main()