import streamlit as st
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import os
from datetime import datetime
from datasets import load_dataset
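# Assumed environment (a sketch, versions not pinned): the imports above map to
# the pip packages streamlit, pandas, numpy, sentence-transformers,
# scikit-learn, and datasets, e.g.:
#   pip install streamlit pandas numpy sentence-transformers scikit-learn datasets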
# Initialize session state
if 'search_history' not in st.session_state:
    st.session_state['search_history'] = []
if 'search_columns' not in st.session_state:
    st.session_state['search_columns'] = []
if 'dataset_loaded' not in st.session_state:
    st.session_state['dataset_loaded'] = False
if 'current_page' not in st.session_state:
    st.session_state['current_page'] = 0
if 'data_cache' not in st.session_state:
    st.session_state['data_cache'] = None
if 'dataset_info' not in st.session_state:
    st.session_state['dataset_info'] = None
ROWS_PER_PAGE = 100 # Number of rows to load at a time
@st.cache_resource
def get_model():
"""Cache the model loading"""
return SentenceTransformer('all-MiniLM-L6-v2')
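# Note: 'all-MiniLM-L6-v2' is a compact SentenceTransformer checkpoint that
# produces 384-dimensional sentence embeddings, a common default when fast
# on-page semantic search matters more than peak retrieval quality.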
@st.cache_data
def load_dataset_page(dataset_id, token, page, rows_per_page):
"""Load and cache a specific page of data"""
try:
start_idx = page * rows_per_page
end_idx = start_idx + rows_per_page
dataset = load_dataset(
dataset_id,
token=token,
streaming=False,
split=f'train[{start_idx}:{end_idx}]'
)
return pd.DataFrame(dataset)
except Exception as e:
st.error(f"Error loading page {page}: {str(e)}")
return pd.DataFrame()
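# Note: the f'train[{start}:{end}]' string above uses the datasets library's
# split-slicing syntax, so each call materializes only one page of rows in
# memory (the underlying files are still fetched to the local cache on first use).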
@st.cache_data
def get_dataset_info(dataset_id, token):
"""Load and cache dataset information"""
try:
dataset = load_dataset(
dataset_id,
token=token,
streaming=True
)
return dataset['train'].info
except Exception as e:
st.error(f"Error loading dataset info: {str(e)}")
return None
class FastDatasetSearcher:
    def __init__(self, dataset_id="tomg-group-umd/cinepile"):
        self.dataset_id = dataset_id
        self.text_model = get_model()
        self.token = os.environ.get('DATASET_KEY')
        if not self.token:
            st.error("Please set the DATASET_KEY environment variable with your Hugging Face token.")
            st.stop()
        # Load dataset info if not already loaded
        if st.session_state['dataset_info'] is None:
            st.session_state['dataset_info'] = get_dataset_info(self.dataset_id, self.token)

    def load_page(self, page=0):
        """Load a specific page of data using the cached loader"""
        return load_dataset_page(self.dataset_id, self.token, page, ROWS_PER_PAGE)
    def quick_search(self, query, df):
        """Fast hybrid keyword + semantic search on the current page"""
        if df.empty:
            return df
        try:
            # Get columns to search (excluding numpy array and bytes columns)
            searchable_cols = [
                col for col in df.columns
                if not isinstance(df[col].iloc[0], (np.ndarray, bytes))
            ]

            # Prepare the query once
            query_lower = query.lower()
            query_embedding = self.text_model.encode([query], show_progress_bar=False)[0]

            # Build one text blob per row from the searchable columns
            texts = []
            for _, row in df.iterrows():
                text_parts = [str(row[col]) for col in searchable_cols if row[col] is not None]
                texts.append(' '.join(text_parts))

            # Encode all rows in one batch instead of one encode() call per row
            text_embeddings = self.text_model.encode(texts, show_progress_bar=False)
            semantic_scores = cosine_similarity([query_embedding], text_embeddings)[0]

            scores = []
            for text, semantic_score in zip(texts, semantic_scores):
                if text.strip():
                    # Keyword score: query frequency normalized by word count
                    keyword_score = text.lower().count(query_lower) / max(len(text.split()), 1)
                    # Equal-weight blend of semantic and keyword relevance
                    combined_score = 0.5 * float(semantic_score) + 0.5 * keyword_score
                else:
                    combined_score = 0.0
                scores.append(combined_score)

            # Rank this page's rows by combined score
            results_df = df.copy()
            results_df['score'] = scores
            return results_df.sort_values('score', ascending=False)
        except Exception as e:
            st.error(f"Search error: {str(e)}")
            return df
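# Hypothetical usage of the class above (the query string is illustrative):
#   searcher = FastDatasetSearcher()
#   page_df = searcher.load_page(0)
#   top_hits = searcher.quick_search("car chase", page_df).head(5)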
def render_result(result):
    """Render a single search result"""
    # iterrows() yields a pandas Series; convert to a dict so pop() accepts a default
    result = result.to_dict()
    score = result.pop('score', 0)

    # Display video if available
    if 'youtube_id' in result:
        st.video(
            f"https://youtube.com/watch?v={result['youtube_id']}&t={result.get('start_time', 0)}"
        )

    # Display other fields
    cols = st.columns([2, 1])
    with cols[0]:
        for key, value in result.items():
            if isinstance(value, (str, int, float)):
                st.write(f"**{key}:** {value}")
    with cols[1]:
        st.metric("Relevance Score", f"{score:.2%}")
def main():
    st.title("🎥 Fast Video Dataset Search")

    # Initialize search class
    searcher = FastDatasetSearcher()

    # Show dataset info and pick a page
    if st.session_state['dataset_info']:
        num_examples = st.session_state['dataset_info'].splits['train'].num_examples
        st.sidebar.write("### Dataset Info")
        st.sidebar.write(f"Total examples: {num_examples:,}")
        # Last valid zero-based page index (avoids offering an empty trailing page)
        total_pages = max((num_examples - 1) // ROWS_PER_PAGE, 0)
        current_page = st.number_input("Page", min_value=0, max_value=total_pages, value=st.session_state['current_page'])
    else:
        current_page = st.number_input("Page", min_value=0, value=st.session_state['current_page'])

    # Load current page
    with st.spinner(f"Loading page {current_page}..."):
        df = searcher.load_page(current_page)
    if df.empty:
        st.warning("No data available for this page.")
        return

    # Search interface
    col1, col2 = st.columns([3, 1])
    with col1:
        query = st.text_input("Search in current page:",
                              help="Searches within currently loaded data")
    with col2:
        max_results = st.slider("Max results", 1, ROWS_PER_PAGE, 10)

    if query:
        with st.spinner("Searching..."):
            results = searcher.quick_search(query, df)

        # Display results
        st.write(f"Found {len(results)} results on this page:")
        for i, (_, result) in enumerate(results.head(max_results).iterrows(), 1):
            with st.expander(f"Result {i}", expanded=i == 1):
                render_result(result)

    # Show raw data
    with st.expander("Show Raw Data"):
        st.dataframe(df)

    # Navigation buttons
    cols = st.columns(2)
    with cols[0]:
        if st.button("⬅️ Previous Page") and current_page > 0:
            st.session_state['current_page'] = current_page - 1
            st.rerun()
    with cols[1]:
        if st.button("Next Page ➡️"):
            st.session_state['current_page'] = current_page + 1
            st.rerun()
if __name__ == "__main__":
    main()