import streamlit as st
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import requests
from datetime import datetime
import os

# Hugging Face token resolution order at runtime: the DATASET_KEY environment
# variable here, then the HF_TOKEN entry in st.secrets, then manual input in
# the UI (see main()).
HF_KEY = os.getenv('DATASET_KEY')

# Initialize session state so reruns preserve history and settings.
if 'search_history' not in st.session_state:
    st.session_state['search_history'] = []
if 'search_columns' not in st.session_state:
    st.session_state['search_columns'] = []
if 'initial_search_done' not in st.session_state:
    st.session_state['initial_search_done'] = False
if 'hf_token' not in st.session_state:
    st.session_state['hf_token'] = HF_KEY

def fetch_dataset_info_auth(dataset_id, hf_token):
    """Fetch dataset information with authentication"""
    info_url = f"https://huggingface.co/api/datasets/{dataset_id}"
    headers = {"Authorization": f"Bearer {hf_token}"}
    try:
        response = requests.get(info_url, headers=headers, timeout=30)
        if response.status_code == 200:
            return response.json()
    except Exception as e:
        st.warning(f"Error fetching dataset info: {e}")
    return None

def fetch_dataset_splits_auth(dataset_id, hf_token):
    """Fetch available splits for the dataset"""
    splits_url = f"https://datasets-server.huggingface.co/splits?dataset={dataset_id}"
    headers = {"Authorization": f"Bearer {hf_token}"}
    try:
        response = requests.get(splits_url, headers=headers, timeout=30)
        if response.status_code == 200:
            return response.json().get('splits', [])
    except Exception as e:
        st.warning(f"Error fetching splits: {e}")
    return []

def fetch_parquet_urls_auth(dataset_id, config, split, hf_token):
    """Fetch Parquet file URLs for a specific split"""
    parquet_url = f"https://huggingface.co/api/datasets/{dataset_id}/parquet/{config}/{split}"
    headers = {"Authorization": f"Bearer {hf_token}"}
    try:
        response = requests.get(parquet_url, headers=headers, timeout=30)
        if response.status_code == 200:
            return response.json()
    except Exception as e:
        st.warning(f"Error fetching parquet URLs: {e}")
    return []

def fetch_rows_auth(dataset_id, config, split, offset, length, hf_token):
    """Fetch rows with authentication"""
    url = f"https://datasets-server.huggingface.co/rows?dataset={dataset_id}&config={config}&split={split}&offset={offset}&length={length}"
    headers = {"Authorization": f"Bearer {hf_token}"}
    try:
        response = requests.get(url, headers=headers, timeout=30)
        if response.status_code == 200:
            return response.json()
    except Exception as e:
        st.warning(f"Error fetching rows: {e}")
    return None

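# For reference, the /rows endpoint returns JSON shaped roughly like
# (abridged; check the datasets-server docs for the authoritative schema):
#   {"features": [...], "rows": [{"row_idx": 0, "row": {...}}, ...]}
# load_dataset() below relies on the nested 'row' key.
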
class ParquetVideoSearch:
    def __init__(self, hf_token):
        self.text_model = SentenceTransformer('all-MiniLM-L6-v2')
        self.dataset_id = "tomg-group-umd/cinepile"
        self.config = "v2"
        self.hf_token = hf_token
        self.load_dataset()
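    # Note: the embedding model is reloaded on every Streamlit rerun because
    # ParquetVideoSearch is re-instantiated in main(). A minimal sketch of one
    # way to avoid that (assumes a Streamlit version with st.cache_resource):
    #
    #   @st.cache_resource
    #   def get_text_model():
    #       return SentenceTransformer('all-MiniLM-L6-v2')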
    def load_dataset(self):
        """Load initial dataset sample"""
        try:
            rows_data = fetch_rows_auth(
                self.dataset_id,
                self.config,
                "train",
                0,
                100,
                self.hf_token
            )

            if rows_data and 'rows' in rows_data:
                processed_rows = []
                for row_data in rows_data['rows']:
                    # Each entry nests the actual record under 'row'.
                    row = row_data.get('row', row_data)
                    processed_rows.append(row)

                self.dataset = pd.DataFrame(processed_rows)
                # Expose text-like columns for filtering; skip embedding fields.
                st.session_state['search_columns'] = [
                    col for col in self.dataset.columns
                    if not any(term in col.lower() for term in ['embed', 'vector', 'encoding'])
                ]
            else:
                self.dataset = self.load_example_data()

        except Exception as e:
            st.warning(f"Error loading dataset: {e}")
            self.dataset = self.load_example_data()

        self.prepare_features()
    def load_example_data(self):
        """Load example data as fallback"""
        return pd.DataFrame([{
            "video_id": "example",
            "title": "Example Video",
            "description": "Example video content",
            "duration": 120,
            "start_time": 0,
            "end_time": 120
        }])
    def prepare_features(self):
        """Prepare text features for search"""
        try:
            # Concatenate available text fields into one string per row.
            text_fields = ['title', 'description'] if 'title' in self.dataset.columns else ['description']
            combined_text = self.dataset[text_fields].fillna('').astype(str).agg(' '.join, axis=1)
            self.text_embeds = self.text_model.encode(combined_text.tolist())
        except Exception as e:
            st.warning(f"Error preparing features: {e}")
            # Fallback: random vectors matching all-MiniLM-L6-v2's 384-dim output.
            self.text_embeds = np.random.randn(len(self.dataset), 384)
    def search(self, query, column=None, top_k=20):
        """Search using text embeddings and optional column filtering"""
        query_embedding = self.text_model.encode([query])[0]
        similarities = cosine_similarity([query_embedding], self.text_embeds)[0]

        # Down-weight rows whose chosen column does not contain the query text.
        # regex=False treats the query as a literal string, so regex
        # metacharacters in user input cannot break the filter.
        if column and column in self.dataset.columns and column != "All Fields":
            mask = self.dataset[column].astype(str).str.contains(query, case=False, regex=False)
            similarities[~mask] *= 0.5

        top_k = min(top_k, len(similarities))
        top_indices = np.argsort(similarities)[-top_k:][::-1]

        results = []
        for idx in top_indices:
            result = {
                'relevance_score': float(similarities[idx]),
                **self.dataset.iloc[idx].to_dict()
            }
            results.append(result)

        return results
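# Example: given a constructed instance, a call such as
#   results = search.search("detective interrogation scene", column="title", top_k=5)
# returns dicts sorted by relevance_score (cosine similarity, halved where the
# column filter does not match). The query string here is illustrative.
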
def render_video_result(result):
    """Render a video result with enhanced display"""
    col1, col2 = st.columns([2, 1])

    with col1:
        if 'title' in result:
            st.markdown(f"**Title:** {result['title']}")
        st.markdown("**Description:**")
        st.write(result.get('description', 'No description available'))

        start_time = result.get('start_time', 0)
        end_time = result.get('end_time', result.get('duration', 0))
        st.markdown(f"**Time Range:** {start_time}s - {end_time}s")

        # Show any remaining metadata fields.
        for key, value in result.items():
            if key not in ['title', 'description', 'start_time', 'end_time', 'duration',
                           'relevance_score', 'video_id', '_config', '_split']:
                st.markdown(f"**{key.replace('_', ' ').title()}:** {value}")

    with col2:
        st.markdown(f"**Relevance Score:** {result['relevance_score']:.2%}")

        video_url = None
        if 'video_url' in result:
            video_url = result['video_url']
        elif 'youtube_id' in result:
            video_url = f"https://youtube.com/watch?v={result['youtube_id']}&t={start_time}"

        if video_url:
            st.video(video_url)

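# Note: st.video accepts both direct media URLs and YouTube links; the &t=
# parameter above is a YouTube convention for the start offset and has no
# effect on plain video files.
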
def main():
    st.title("🎥 Video Dataset Search")

    # Fall back to Streamlit secrets, then manual input, for the token.
    if not st.session_state['hf_token']:
        try:
            st.session_state['hf_token'] = st.secrets.get("HF_TOKEN", None)
        except Exception:
            # st.secrets raises if no secrets file is configured.
            st.session_state['hf_token'] = None

    if not st.session_state['hf_token']:
        hf_token = st.text_input("Enter your Hugging Face API token:", type="password")
        if hf_token:
            st.session_state['hf_token'] = hf_token

    if not st.session_state.get('hf_token'):
        st.warning("Please provide a Hugging Face API token to access the dataset.")
        return

    search = ParquetVideoSearch(st.session_state['hf_token'])

    tab1, tab2 = st.tabs(["🔍 Video Search", "📊 Dataset Info"])
    with tab1:
        st.subheader("Search Videos")
        col1, col2 = st.columns([3, 1])

        with col1:
            query = st.text_input("Enter your search query:")
        with col2:
            search_column = st.selectbox("Search in field:",
                                         ["All Fields"] + st.session_state['search_columns'])

        col3, col4 = st.columns(2)
        with col3:
            num_results = st.slider("Number of results:", 1, 100, 20)
        with col4:
            search_button = st.button("🔍 Search")

        if search_button and query:
            st.session_state['initial_search_done'] = True
            selected_column = None if search_column == "All Fields" else search_column

            with st.spinner("Searching..."):
                results = search.search(query, selected_column, num_results)

            # Keep only the top five results per query in the history.
            st.session_state['search_history'].append({
                'query': query,
                'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                'results': results[:5]
            })

            for i, result in enumerate(results, 1):
                with st.expander(
                    f"Result {i}: {result.get('title', result.get('description', 'No title'))[:100]}...",
                    expanded=(i == 1)
                ):
                    render_video_result(result)
    with tab2:
        st.subheader("Dataset Information")

        splits = fetch_dataset_splits_auth(search.dataset_id, st.session_state['hf_token'])
        if splits:
            st.write("### Available Splits")
            for split in splits:
                st.write(f"- {split['split']}: {split.get('num_rows', 'unknown')} rows")

        st.write("### Dataset Statistics")
        st.write(f"- Loaded rows: {len(search.dataset)}")
        st.write(f"- Available columns: {', '.join(search.dataset.columns)}")

        st.write("### Sample Data")
        st.dataframe(search.dataset.head())
    with st.sidebar:
        st.subheader("⏳ Search History")
        if st.button("🗑️ Clear History"):
            st.session_state['search_history'] = []
            st.rerun()  # st.experimental_rerun() is deprecated in current Streamlit

        st.markdown("### Recent Searches")
        # Show the five most recent queries, newest first.
        for entry in reversed(st.session_state['search_history'][-5:]):
            with st.expander(f"{entry['timestamp']}: {entry['query']}"):
                for i, result in enumerate(entry['results'], 1):
                    st.write(f"{i}. {result.get('title', result.get('description', 'No title'))[:100]}...")

if __name__ == "__main__":
    main()
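# To try the app locally (illustrative; use your actual filename):
#   streamlit run app.py
# Streamlit serves on http://localhost:8501 by default.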