Update app.py
app.py
CHANGED
@@ -3,383 +3,193 @@ import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
-import torch
-import json
import os
-import glob
-from pathlib import Path
from datetime import datetime
-import edge_tts
-import asyncio
-import base64
import requests
-from
-from audio_recorder_streamlit import audio_recorder
-import streamlit.components.v1 as components
from urllib.parse import quote
-from xml.etree import ElementTree as ET

# Initialize session state
if 'search_history' not in st.session_state:
    st.session_state['search_history'] = []
-if 'last_voice_input' not in st.session_state:
-    st.session_state['last_voice_input'] = ""
-if 'transcript_history' not in st.session_state:
-    st.session_state['transcript_history'] = []
-if 'should_rerun' not in st.session_state:
-    st.session_state['should_rerun'] = False
if 'search_columns' not in st.session_state:
    st.session_state['search_columns'] = []
if 'initial_search_done' not in st.session_state:
    st.session_state['initial_search_done'] = False
-if '
-    st.session_state['
-if 'arxiv_last_query' not in st.session_state:
-    st.session_state['arxiv_last_query'] = ""

-def fetch_dataset_info(dataset_id):
-
-
-    try:
-        response = requests.get(info_url, timeout=30)
-        if response.status_code == 200:
-            return response.json()
-    except Exception as e:
-        st.warning(f"Error fetching dataset info: {e}")
-    return None
-
-def fetch_dataset_rows(dataset_id, config="default", split="train", max_rows=100):
-    """Fetch rows from a specific config and split of a dataset"""
-    url = f"https://datasets-server.huggingface.co/first-rows?dataset={dataset_id}&config={config}&split={split}"
-    try:
-        response = requests.get(url, timeout=30)
-        if response.status_code == 200:
-            data = response.json()
-            if 'rows' in data:
-                processed_rows = []
-                for row_data in data['rows']:
-                    row = row_data.get('row', row_data)
-                    # Process embeddings if present
-                    for key in row:
-                        if any(term in key.lower() for term in ['embed', 'vector', 'encoding']):
-                            if isinstance(row[key], str):
-                                try:
-                                    row[key] = [float(x.strip()) for x in row[key].strip('[]').split(',') if x.strip()]
-                                except:
-                                    continue
-                    row['_config'] = config
-                    row['_split'] = split
-                    processed_rows.append(row)
-                return processed_rows
-    except Exception as e:
-        st.warning(f"Error fetching rows for {config}/{split}: {e}")
-    return []
-
-def search_dataset(dataset_id, search_text, include_configs=None, include_splits=None):
-    """
-    Search across all configurations and splits of a dataset
-
-    Args:
-        dataset_id (str): The Hugging Face dataset ID
-        search_text (str): Text to search for in descriptions and queries
-        include_configs (list): List of specific configs to search, or None for all
-        include_splits (list): List of specific splits to search, or None for all
-
-    Returns:
-        tuple: (DataFrame of results, list of available configs, list of available splits)
-    """
-    # Get dataset info
-    dataset_info = fetch_dataset_info(dataset_id)
-    if not dataset_info:
-        return pd.DataFrame(), [], []
-
-    # Get available configs and splits
-    configs = include_configs if include_configs else dataset_info.get('config_names', ['default'])
-    all_rows = []
-    available_splits = set()
-
-    # Search across configs and splits
-    for config in configs:
-        try:
-            # First fetch split info for this config
-            splits_url = f"https://datasets-server.huggingface.co/splits?dataset={dataset_id}&config={config}"
-            splits_response = requests.get(splits_url, timeout=30)
-            if splits_response.status_code == 200:
-                splits_data = splits_response.json()
-                splits = [split['split'] for split in splits_data.get('splits', [])]
-                if not splits:
-                    splits = ['train']  # fallback to train if no splits found
-
-                # Filter splits if specified
-                if include_splits:
-                    splits = [s for s in splits if s in include_splits]
-
-                available_splits.update(splits)
-
-                # Fetch and search rows for each split
-                for split in splits:
-                    rows = fetch_dataset_rows(dataset_id, config, split)
-                    for row in rows:
-                        # Search in all text fields
-                        text_content = ' '.join(str(v) for v in row.values() if isinstance(v, (str, int, float)))
-                        if search_text.lower() in text_content.lower():
-                            row['_matched_text'] = text_content
-                            row['_relevance_score'] = text_content.lower().count(search_text.lower())
-                            all_rows.append(row)
-
-        except Exception as e:
-            st.warning(f"Error processing config {config}: {e}")
-            continue
-
-    # Convert to DataFrame and sort by relevance
-    if all_rows:
-        df = pd.DataFrame(all_rows)
-        df = df.sort_values('_relevance_score', ascending=False)
-        return df, configs, list(available_splits)
-
-    return pd.DataFrame(), configs, list(available_splits)
-
-class VideoSearch:
-    def __init__(self):
        self.text_model = SentenceTransformer('all-MiniLM-L6-v2')
-        self.
        self.load_dataset()
-
-    def
-        """
        try:
-
-
-
-
-
-
-

-
-
-
-
-
-

        except Exception as e:
-            st.
-
-
-    def load_example_data(self):
-        """Load example data as fallback"""
-        example_data = [
-            {
-                "video_id": "cd21da96-fcca-4c94-a60f-0b1e4e1e29fc",
-                "youtube_id": "IO-vwtyicn4",
-                "description": "This video shows a close-up of an ancient text carved into a surface.",
-                "views": 45489,
-                "start_time": 1452,
-                "end_time": 1458,
-                "video_embed": [0.014160037972033024, -0.003111184574663639, -0.016604168340563774],
-                "description_embed": [-0.05835828185081482, 0.02589797042310238, 0.11952091753482819]
-            }
-        ]
-        return pd.DataFrame(example_data)

    def prepare_features(self):
-        """Prepare embeddings
        try:
-
-

-            embeddings
-
-
-                data = []
-                for row in self.dataset[col]:
-                    if isinstance(row, str):
-                        values = [float(x.strip()) for x in row.strip('[]').split(',') if x.strip()]
-                    elif isinstance(row, list):
-                        values = row
-                    else:
-                        continue
-                    data.append(values)
-
-                if data:
-                    embeddings[col] = np.array(data)
-            except:
-                continue

-
-
-
-
-
-
-
-
-
-
-
-        except:
-            # Fallback to random embeddings
-            num_rows = len(self.dataset)
-            self.video_embeds = np.random.randn(num_rows, 384)
-            self.text_embeds = np.random.randn(num_rows, 384)
-
-    def load_dataset(self):
-        self.dataset = self.fetch_dataset_rows()
-        self.prepare_features()

def search(self, query, column=None, top_k=20):
|
|
|
|
|
|
|
|
|
|
|
233 |
query_embedding = self.text_model.encode([query])[0]
|
234 |
-
|
235 |
-
|
236 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
237 |
|
238 |
-
#
|
239 |
-
|
240 |
-
mask = self.dataset[column].astype(str).str.contains(query, case=False)
|
241 |
-
combined_sims[~mask] *= 0.5
|
242 |
|
243 |
-
|
244 |
-
|
|
|
245 |
|
|
|
246 |
results = []
|
247 |
for idx in top_indices:
|
248 |
-
result = {
|
249 |
-
|
250 |
-
|
251 |
-
|
|
|
|
|
252 |
results.append(result)
|
253 |
|
254 |
return results
|
255 |
|
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            st.error(f"Error generating speech: {e}")
-            return None
-
-def transcribe_audio(audio_path):
-    """Placeholder for ASR transcription"""
-    return "ASR not implemented. Integrate a local model or another service here."

-def show_file_manager():
-    """
-    st.
-
    with col1:
-
-
-
-
-        st.
-

    with col2:
-
-
-
-
-
-
-
-
-
-
-            with st.expander(f"📄 {os.path.basename(f)}"):
-                if f.endswith('.mp3'):
-                    st.audio(f)
-                else:
-                    with open(f, 'r', encoding='utf-8') as file:
-                        st.text_area("Content", file.read(), height=100)
-                if st.button(f"Delete {os.path.basename(f)}", key=f"del_{f}"):
-                    os.remove(f)
-                    st.experimental_rerun()
-
-def arxiv_search(query, max_results=5):
-    """Perform a simple Arxiv search using their API and return top results."""
-    base_url = "http://export.arxiv.org/api/query?"
-    search_url = base_url + f"search_query={quote(query)}&start=0&max_results={max_results}"
-    r = requests.get(search_url)
-    if r.status_code == 200:
-        root = ET.fromstring(r.text)
-        ns = {'atom': 'http://www.w3.org/2005/Atom'}
-        entries = root.findall('atom:entry', ns)
-        results = []
-        for entry in entries:
-            title = entry.find('atom:title', ns).text.strip()
-            summary = entry.find('atom:summary', ns).text.strip()
-            link = None
-            for l in entry.findall('atom:link', ns):
-                if l.get('type') == 'text/html':
-                    link = l.get('href')
-                    break
-            results.append((title, summary, link))
-        return results
-    return []
-
-def perform_arxiv_lookup(q, vocal_summary=True, titles_summary=True, full_audio=False):
-    results = arxiv_search(q, max_results=5)
-    if not results:
-        st.write("No Arxiv results found.")
-        return
-    st.markdown(f"**Arxiv Search Results for '{q}':**")
-    for i, (title, summary, link) in enumerate(results, start=1):
-        st.markdown(f"**{i}. {title}**")
-        st.write(summary)
-        if link:
-            st.markdown(f"[View Paper]({link})")
-
-    if vocal_summary:
-        spoken_text = f"Here are some Arxiv results for {q}. "
-        if titles_summary:
-            spoken_text += " Titles: " + ", ".join([res[0] for res in results])
-        else:
-            # Just first summary if no titles_summary
-            spoken_text += " " + results[0][1][:200]
-
-        audio_file = asyncio.run(generate_speech(spoken_text))
-        if audio_file:
-            st.audio(audio_file)
-
-    if full_audio:
-        # Full audio of summaries
-        full_text = ""
-        for i,(title, summary, _) in enumerate(results, start=1):
-            full_text += f"Result {i}: {title}. {summary} "
-        audio_file_full = asyncio.run(generate_speech(full_text))
-        if audio_file_full:
-            st.write("### Full Audio")
-            st.audio(audio_file_full)

def main():
-    st.title("🎥 Video

    # Initialize search class
-

    # Create tabs
-    tab1, tab2

-    # ---- Tab 1:
    with tab1:
        st.subheader("Search Videos")
        col1, col2 = st.columns([3, 1])
        with col1:
-            query = st.text_input("
-
        with col2:
            search_column = st.selectbox("Search in field:",
                ["All Fields"] + st.session_state['search_columns'])
@@ -390,11 +200,12 @@ def main():
        with col4:
            search_button = st.button("🔍 Search")

-        if
            st.session_state['initial_search_done'] = True
            selected_column = None if search_column == "All Fields" else search_column
            with st.spinner("Searching..."):
-                results =

            st.session_state['search_history'].append({
                'query': query,
@@ -403,151 +214,34 @@ def main():
            })

            for i, result in enumerate(results, 1):
-                with st.expander(
-
-
-
-
-
-
-
-                    with cols[1]:
-                        st.markdown(f"**Relevance Score:** {result['relevance_score']:.2%}")
-                        if result.get('youtube_id'):
-                            st.video(f"https://youtube.com/watch?v={result['youtube_id']}&t={result['start_time']}")
-
-                        if st.button(f"🔊 Audio Summary", key=f"audio_{i}"):
-                            summary = f"Video summary: {result['description'][:200]}"
-                            audio_file = asyncio.run(generate_speech(summary))
-                            if audio_file:
-                                st.audio(audio_file)
-
-    # ---- Tab 2: Voice Input ----
    with tab2:
-        st.subheader("
-
-
-        if
-
-
-
-            st.
-
-            voice_query = transcribe_audio(audio_path)
-            st.markdown("**Transcribed Text:**")
-            st.write(voice_query)
-            st.session_state['last_voice_input'] = voice_query

-
-
-                for i, result in enumerate(results, 1):
-                    with st.expander(f"Result {i}", expanded=(i==1)):
-                        st.write(result['description'])
-                        if result.get('youtube_id'):
-                            st.video(f"https://youtube.com/watch?v={result['youtube_id']}&t={result.get('start_time', 0)}")

-
-
-
-    # ---- Tab 3: Arxiv Search ----
-    with tab3:
-        st.subheader("Arxiv Search")
-        q = st.text_input("Enter your Arxiv search query:", value=st.session_state['arxiv_last_query'])
-        vocal_summary = st.checkbox("🎙 Short Audio Summary", value=True)
-        titles_summary = st.checkbox("🔖 Titles Only", value=True)
-        full_audio = st.checkbox("📚 Full Audio Results", value=False)
-
-        if st.button("🔍 Arxiv Search"):
-            st.session_state['arxiv_last_query'] = q
-            perform_arxiv_lookup(q, vocal_summary=vocal_summary, titles_summary=titles_summary, full_audio=full_audio)
-
-    # ---- Tab 4: File Manager ----
-    with tab4:
-        show_file_manager()
-
-    # ---- Tab 5: Advanced Dataset Search ----
-    with tab5:
-        st.subheader("Advanced Dataset Search")
-
-        # Dataset input
-        dataset_id = st.text_input("Dataset ID:", value="omegalabsinc/omega-multimodal")
-
-        # Search configuration
-        col1, col2 = st.columns([2, 1])
-        with col1:
-            search_text = st.text_input("Search text:",
-                placeholder="Enter text to search across all fields")
-
-        # Get available configs and splits
-        if dataset_id:
-            dataset_info = fetch_dataset_info(dataset_id)
-            if dataset_info:
-                configs = dataset_info.get('config_names', ['default'])
-                with col2:
-                    selected_configs = st.multiselect(
-                        "Configurations:",
-                        options=configs,
-                        default=['default'] if 'default' in configs else None
-                    )
-
-                # Fetch available splits
-                if selected_configs:
-                    all_splits = set()
-                    for config in selected_configs:
-                        splits_url = f"https://datasets-server.huggingface.co/splits?dataset={dataset_id}&config={config}"
-                        try:
-                            response = requests.get(splits_url, timeout=30)
-                            if response.status_code == 200:
-                                splits_data = response.json()
-                                splits = [split['split'] for split in splits_data.get('splits', [])]
-                                all_splits.update(splits)
-                        except Exception as e:
-                            st.warning(f"Error fetching splits for {config}: {e}")
-
-                    selected_splits = st.multiselect(
-                        "Splits:",
-                        options=list(all_splits),
-                        default=['train'] if 'train' in all_splits else None
-                    )
-
-                    if st.button("🔍 Search Dataset"):
-                        with st.spinner("Searching dataset..."):
-                            results_df, _, _ = search_dataset(
-                                dataset_id,
-                                search_text,
-                                include_configs=selected_configs,
-                                include_splits=selected_splits
-                            )
-
-                        if not results_df.empty:
-                            st.write(f"Found {len(results_df)} results")
-
-                            # Display results in expandable sections
-                            for idx, row in results_df.iterrows():
-                                with st.expander(
-                                    f"Result {idx+1} (Config: {row['_config']}, Split: {row['_split']}, Score: {row['_relevance_score']})"
-                                ):
-                                    # Display all fields except internal ones
-                                    for col in row.index:
-                                        if not col.startswith('_') and not any(
-                                            term in col.lower()
-                                            for term in ['embed', 'vector', 'encoding']
-                                        ):
-                                            st.write(f"**{col}:** {row[col]}")
-
-                                    # Add buttons for audio/video if available
-                                    if 'youtube_id' in row:
-                                        st.video(
-                                            f"https://youtube.com/watch?v={row['youtube_id']}&t={row.get('start_time', 0)}"
-                                        )
-                        else:
-                            st.warning("No results found.")
-            else:
-                st.error("Unable to fetch dataset information.")

    # Sidebar
    with st.sidebar:
-        st.subheader("⚙️
        if st.button("🗑️ Clear History"):
            st.session_state['search_history'] = []
            st.experimental_rerun()
@@ -556,12 +250,7 @@ def main():
        for entry in reversed(st.session_state['search_history'][-5:]):
            with st.expander(f"{entry['timestamp']}: {entry['query']}"):
                for i, result in enumerate(entry['results'], 1):
-                    st.write(f"{i}. {result
-
-        st.markdown("### Voice Settings")
-        st.selectbox("TTS Voice:",
-            ["en-US-AriaNeural", "en-US-GuyNeural", "en-GB-SoniaNeural"],
-            key="tts_voice")

if __name__ == "__main__":
    main()

import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import os
from datetime import datetime
import requests
+from datasets import load_dataset
from urllib.parse import quote

# Initialize session state
if 'search_history' not in st.session_state:
    st.session_state['search_history'] = []
if 'search_columns' not in st.session_state:
    st.session_state['search_columns'] = []
if 'initial_search_done' not in st.session_state:
    st.session_state['initial_search_done'] = False
+if 'dataset' not in st.session_state:
+    st.session_state['dataset'] = None

+class DatasetSearcher:
+    def __init__(self, dataset_id="tomg-group-umd/cinepile"):
+        self.dataset_id = dataset_id
        self.text_model = SentenceTransformer('all-MiniLM-L6-v2')
+        self.token = os.environ.get('DATASET_KEY')
+        if not self.token:
+            st.error("Please set the DATASET_KEY environment variable with your Hugging Face token.")
+            st.stop()
        self.load_dataset()
+
+    def load_dataset(self):
+        """Load dataset using the datasets library"""
        try:
+            if st.session_state['dataset'] is None:
+                with st.spinner("Loading dataset..."):
+                    st.session_state['dataset'] = load_dataset(
+                        self.dataset_id,
+                        token=self.token,
+                        streaming=False
+                    )
+
+            self.dataset = st.session_state['dataset']
+            # Convert first split to DataFrame for easier processing
+            first_split = next(iter(self.dataset.values()))
+            self.df = pd.DataFrame(first_split)

+            # Store column information
+            self.columns = list(self.df.columns)
+            self.text_columns = [col for col in self.columns
+                                 if self.df[col].dtype == 'object'
+                                 and not any(term in col.lower()
+                                             for term in ['embed', 'vector', 'encoding'])]
+
+            # Update session state columns
+            st.session_state['search_columns'] = self.text_columns
+
+            # Prepare text embeddings
+            self.prepare_features()

        except Exception as e:
+            st.error(f"Error loading dataset: {str(e)}")
+            st.error("Please check your authentication token and internet connection.")
+            st.stop()

    def prepare_features(self):
+        """Prepare text embeddings for semantic search"""
        try:
+            # Combine text columns for embedding
+            combined_text = self.df[self.text_columns].fillna('').agg(' '.join, axis=1)

+            # Create embeddings in batches to manage memory
+            batch_size = 32
+            all_embeddings = []

+            with st.spinner("Preparing search features..."):
+                for i in range(0, len(combined_text), batch_size):
+                    batch = combined_text[i:i+batch_size].tolist()
+                    embeddings = self.text_model.encode(batch)
+                    all_embeddings.append(embeddings)
+
+            self.text_embeddings = np.vstack(all_embeddings)
+
+        except Exception as e:
+            st.warning(f"Error preparing features: {str(e)}")
+            self.text_embeddings = np.random.randn(len(self.df), 384)

    def search(self, query, column=None, top_k=20):
+        """Search the dataset using semantic and keyword matching"""
+        if self.df.empty:
+            return []
+
+        # Get semantic similarity scores
        query_embedding = self.text_model.encode([query])[0]
+        similarities = cosine_similarity([query_embedding], self.text_embeddings)[0]
+
+        # Get keyword match scores
+        search_columns = [column] if column and column != "All Fields" else self.text_columns
+        keyword_scores = np.zeros(len(self.df))
+
+        for col in search_columns:
+            if col in self.df.columns:
+                matches = self.df[col].fillna('').str.lower().str.count(query.lower())
+                keyword_scores += matches

+        # Combine scores
+        combined_scores = 0.5 * similarities + 0.5 * (keyword_scores / max(1, keyword_scores.max()))

+        # Get top results
+        top_k = min(top_k, len(combined_scores))
+        top_indices = np.argsort(combined_scores)[-top_k:][::-1]

+        # Format results
        results = []
        for idx in top_indices:
+            result = {
+                'relevance_score': float(combined_scores[idx]),
+                'semantic_score': float(similarities[idx]),
+                'keyword_score': float(keyword_scores[idx]),
+                **self.df.iloc[idx].to_dict()
+            }
            results.append(result)

        return results

+    def get_dataset_info(self):
+        """Get information about the dataset"""
+        if not self.dataset:
+            return {}
+
+        info = {
+            'splits': list(self.dataset.keys()),
+            'total_rows': sum(split.num_rows for split in self.dataset.values()),
+            'columns': self.columns,
+            'text_columns': self.text_columns,
+            'sample_rows': len(self.df),
+            'embeddings_shape': self.text_embeddings.shape
+        }
+
+        return info

+def render_video_result(result):
+    """Render a video result with enhanced display"""
+    col1, col2 = st.columns([2, 1])
+
    with col1:
+        if 'title' in result:
+            st.markdown(f"**Title:** {result['title']}")
+        if 'description' in result:
+            st.markdown("**Description:**")
+            st.write(result['description'])
+
+        # Show timing information if available
+        if 'start_time' in result and 'end_time' in result:
+            st.markdown(f"**Time Range:** {result['start_time']}s - {result['end_time']}s")
+
+        # Show additional metadata
+        for key, value in result.items():
+            if key not in ['title', 'description', 'start_time', 'end_time', 'duration',
+                           'relevance_score', 'semantic_score', 'keyword_score',
+                           'video_id', 'youtube_id']:
+                st.markdown(f"**{key.replace('_', ' ').title()}:** {value}")

    with col2:
+        # Show search scores
+        st.markdown("**Search Scores:**")
+        cols = st.columns(3)
+        cols[0].metric("Overall", f"{result['relevance_score']:.2%}")
+        cols[1].metric("Semantic", f"{result['semantic_score']:.2%}")
+        cols[2].metric("Keyword", f"{result['keyword_score']:.0f} matches")
+
+        # Display video if available
+        if 'youtube_id' in result:
+            st.video(f"https://youtube.com/watch?v={result['youtube_id']}&t={result.get('start_time', 0)}")

def main():
+    st.title("🎥 Video Dataset Search")

    # Initialize search class
+    searcher = DatasetSearcher()

    # Create tabs
+    tab1, tab2 = st.tabs(["🔍 Search", "📊 Dataset Info"])

+    # ---- Tab 1: Search ----
    with tab1:
        st.subheader("Search Videos")
        col1, col2 = st.columns([3, 1])
+
        with col1:
+            query = st.text_input("Search query:",
+                value="" if st.session_state['initial_search_done'] else "")
        with col2:
            search_column = st.selectbox("Search in field:",
                ["All Fields"] + st.session_state['search_columns'])
        with col4:
            search_button = st.button("🔍 Search")

+        if search_button and query:
            st.session_state['initial_search_done'] = True
            selected_column = None if search_column == "All Fields" else search_column
+
            with st.spinner("Searching..."):
+                results = searcher.search(query, selected_column, num_results)

            st.session_state['search_history'].append({
                'query': query,
            })

            for i, result in enumerate(results, 1):
+                with st.expander(
+                    f"Result {i}: {result.get('title', result.get('description', 'No title'))[:100]}...",
+                    expanded=(i==1)
+                ):
+                    render_video_result(result)
+
+    # ---- Tab 2: Dataset Info ----
    with tab2:
+        st.subheader("Dataset Information")
+
+        info = searcher.get_dataset_info()
+        if info:
+            st.write(f"### Dataset: {searcher.dataset_id}")
+            st.write(f"- Total rows: {info['total_rows']:,}")
+            st.write(f"- Available splits: {', '.join(info['splits'])}")
+            st.write(f"- Number of columns: {len(info['columns'])}")
+            st.write(f"- Searchable text columns: {', '.join(info['text_columns'])}")

+            st.write("### Sample Data")
+            st.dataframe(searcher.df.head())

+            st.write("### Column Details")
+            for col in info['columns']:
+                st.write(f"- **{col}**: {searcher.df[col].dtype}")

    # Sidebar
    with st.sidebar:
+        st.subheader("⚙️ Search History")
        if st.button("🗑️ Clear History"):
            st.session_state['search_history'] = []
            st.experimental_rerun()

        for entry in reversed(st.session_state['search_history'][-5:]):
            with st.expander(f"{entry['timestamp']}: {entry['query']}"):
                for i, result in enumerate(entry['results'], 1):
+                    st.write(f"{i}. {result.get('title', result.get('description', 'No title'))[:100]}...")

if __name__ == "__main__":
    main()
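Below the diff, a minimal standalone sketch of the scoring scheme the new search() method uses: an equal-weight blend of all-MiniLM-L6-v2 cosine similarity and normalized keyword-match counts. The toy rows and the query here are invented for illustration and are not part of the commit.

# Hedged sketch of DatasetSearcher.search() scoring on a toy DataFrame.
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

df = pd.DataFrame({"description": [
    "A close-up of an ancient text carved into a stone surface.",
    "A cooking tutorial for homemade pasta.",
    "Tourists photograph ancient ruins at sunset.",
]})

model = SentenceTransformer("all-MiniLM-L6-v2")
text_embeddings = model.encode(df["description"].tolist())

query = "ancient text"
query_embedding = model.encode([query])[0]
similarities = cosine_similarity([query_embedding], text_embeddings)[0]

# Keyword score: occurrences of the whole lowered query, as in the app
keyword_scores = df["description"].str.lower().str.count(query.lower()).to_numpy(dtype=float)

# Same 50/50 combination used by DatasetSearcher.search()
combined = 0.5 * similarities + 0.5 * (keyword_scores / max(1, keyword_scores.max()))

for idx in np.argsort(combined)[::-1]:
    print(f"{combined[idx]:.2%}  {df['description'][idx]}")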