File size: 5,431 Bytes
67c01ec 4e1519e 8427637 959152c 67c01ec 8427637 67c01ec bdefc08 67c01ec bdefc08 959152c bdefc08 959152c bdefc08 959152c bdefc08 6afbfac bdefc08 6afbfac 959152c bdefc08 6afbfac bdefc08 8427637 bdefc08 959152c bdefc08 959152c bdefc08 6afbfac 959152c bdefc08 8427637 bdefc08 6afbfac bdefc08 6afbfac bdefc08 6afbfac bdefc08 959152c bdefc08 8427637 bdefc08 6afbfac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
import streamlit as st
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import os
from datetime import datetime
from datasets import load_dataset
def _init_session_state():
    """Seed ``st.session_state`` with a default for every key the app reads.

    Keys that already exist (set on an earlier rerun of this script) keep
    their current value; only missing keys are filled in.
    """
    defaults = {
        'search_history': [],
        'search_columns': [],
        'dataset_loaded': False,
        'current_page': 0,
        'data_cache': None,
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value


_init_session_state()

# Number of dataset rows fetched per page.
ROWS_PER_PAGE = 100
@st.cache_resource
def get_model():
    """Return the sentence-embedding model used for semantic search.

    Wrapped in ``st.cache_resource`` so the model is constructed once and the
    same instance is reused across Streamlit reruns.
    """
    model_name = 'all-MiniLM-L6-v2'
    model = SentenceTransformer(model_name)
    return model
class FastDatasetSearcher:
    """Page through a Hugging Face dataset and search the currently loaded page.

    Rows are fetched ``ROWS_PER_PAGE`` at a time and the most recent page is
    kept in ``st.session_state`` so reruns don't re-fetch it. Search mixes a
    keyword-frequency score with sentence-embedding cosine similarity.
    """

    def __init__(self, dataset_id="tomg-group-umd/cinepile"):
        """Create a searcher for ``dataset_id``.

        Reads the Hugging Face token from the ``DATASET_KEY`` environment
        variable; if it is unset, shows an error and halts the script with
        ``st.stop()``. Then loads the dataset metadata.
        """
        self.dataset_id = dataset_id
        self.text_model = get_model()
        # Defined up front so the attribute exists even if metadata loading fails.
        self.dataset_info = None
        self.token = os.environ.get('DATASET_KEY')
        if not self.token:
            st.error("Please set the DATASET_KEY environment variable with your Hugging Face token.")
            st.stop()
        self.load_dataset_info()

    @staticmethod
    @st.cache_resource
    def _fetch_dataset_info(dataset_id, _token):
        """Fetch the train-split metadata of ``dataset_id`` (cached across reruns).

        The leading underscore on ``_token`` tells Streamlit not to hash it,
        so the secret is not part of the cache key. Caching a plain function
        (instead of a method) avoids hashing ``self`` and its embedding model.
        """
        dataset = load_dataset(dataset_id, token=_token, streaming=True)
        return dataset['train'].info

    def load_dataset_info(self):
        """Load dataset metadata only (no rows) into ``self.dataset_info``.

        Returns:
            True on success; False after showing ``st.error`` on failure.
        """
        try:
            self.dataset_info = self._fetch_dataset_info(self.dataset_id, self.token)
        except Exception as e:
            st.error(f"Error loading dataset: {str(e)}")
            return False
        return True

    def load_page(self, page=0):
        """Load one page (``ROWS_PER_PAGE`` rows) of the train split.

        The loaded page is cached in ``st.session_state['data_cache']``.
        The number of the page it holds is cached alongside it.

        Returns:
            A DataFrame of the page's rows. On failure, returns an empty
            DataFrame after showing ``st.error``.
        """
        cached = st.session_state.get('data_cache')
        # Validate the cache against the page it was filled for, not
        # 'current_page': other code may change 'current_page' without
        # refreshing the cache, which would otherwise serve stale rows.
        if cached is not None and st.session_state.get('data_cache_page') == page:
            return cached
        start = page * ROWS_PER_PAGE
        end = (page + 1) * ROWS_PER_PAGE
        try:
            # NOTE(review): split slicing with streaming=False presumably
            # prepares the whole split before slicing — confirm cost.
            dataset = load_dataset(
                self.dataset_id,
                token=self.token,
                streaming=False,
                split=f'train[{start}:{end}]'
            )
            df = pd.DataFrame(dataset)
        except Exception as e:
            st.error(f"Error loading page {page}: {str(e)}")
            return pd.DataFrame()
        st.session_state['data_cache'] = df
        st.session_state['data_cache_page'] = page
        st.session_state['current_page'] = page
        return df

    def quick_search(self, query, df):
        """Score every row of ``df`` against ``query`` and rank the rows.

        A row's text is the space-joined str/int/float values of that row.
        Its score is the mean of:

        - the case-insensitive occurrence count of ``query`` in the text,
          divided by the text's whitespace-separated word count (0 for a
          row with no text);
        - the cosine similarity between the query and row embeddings.

        Returns:
            A new DataFrame: ``df`` plus a ``score`` column, sorted by score
            descending. ``df`` itself (possibly the session cache) is not modified.
        """
        if df.empty:
            return df.assign(score=np.empty(0))
        # Series.values is an ndarray property, not a method.
        texts = [
            ' '.join(str(v) for v in row.values if isinstance(v, (str, int, float)))
            for _, row in df.iterrows()
        ]
        needle = query.lower()
        keyword_scores = []
        for text in texts:
            n_words = len(text.split())
            # Rows with no text would otherwise divide by zero.
            keyword_scores.append(text.lower().count(needle) / n_words if n_words else 0.0)
        # One batched encode for all rows instead of one model call per row.
        query_embedding = self.text_model.encode([query], show_progress_bar=False)
        text_embeddings = self.text_model.encode(texts, show_progress_bar=False)
        semantic_scores = cosine_similarity(query_embedding, text_embeddings)[0]
        scores = [0.5 * sem + 0.5 * kw for sem, kw in zip(semantic_scores, keyword_scores)]
        # assign() returns a copy, so the cached page never gains a 'score' column.
        return df.assign(score=scores).sort_values('score', ascending=False)
def _render_result(rank, result):
    """Show one ranked search hit (a row Series with a 'score' entry) in an expander."""
    score = result.pop('score')
    with st.expander(f"Result {rank} (Score: {score:.2%})", expanded=rank == 1):
        # Display video if available
        if 'youtube_id' in result:
            # NOTE(review): assumes start_time is in seconds as YouTube's t= expects — confirm.
            st.video(
                f"https://youtube.com/watch?v={result['youtube_id']}&t={result.get('start_time', 0)}"
            )
        # Display other scalar fields
        for key, value in result.items():
            if isinstance(value, (str, int, float)):
                st.write(f"**{key}:** {value}")


def _shift_page(delta):
    """Move ``current_page`` by ``delta`` and rerun the script."""
    st.session_state['current_page'] += delta
    # Drop the cached rows so the next load_page call fetches the new page
    # instead of reusing the previous page's data.
    st.session_state['data_cache'] = None
    st.rerun()


def main():
    """Render the paged dataset browser with in-page search and navigation."""
    st.title("🎥 Fast Video Dataset Search")
    searcher = FastDatasetSearcher()

    # Page navigation
    page = st.number_input("Page", min_value=0, value=st.session_state['current_page'])

    # Load current page
    with st.spinner(f"Loading page {page}..."):
        df = searcher.load_page(page)
    if df.empty:
        st.warning("No data available for this page.")
        return

    # Search interface
    query = st.text_input("Search in current page:", help="Searches within currently loaded data")
    if query:
        with st.spinner("Searching..."):
            results = searcher.quick_search(query, df)
        st.write(f"Found {len(results)} results on this page:")
        for rank, (_, result) in enumerate(results.iterrows(), 1):
            _render_result(rank, result)

    # Show raw data
    st.subheader("Raw Data")
    st.dataframe(df)

    # Navigation buttons
    cols = st.columns(2)
    with cols[0]:
        if st.button("Previous Page") and page > 0:
            _shift_page(-1)
    with cols[1]:
        if st.button("Next Page"):
            _shift_page(1)
# Run the app when this file is executed as a script (e.g. via `streamlit run`).
if __name__ == "__main__":
    main()