awacke1 committed
Commit 959152c · verified · 1 Parent(s): a0e1cbe

Update app.py

Files changed (1):
  1. app.py +162 -473
app.py CHANGED
@@ -3,383 +3,193 @@ import pandas as pd
  import numpy as np
  from sentence_transformers import SentenceTransformer
  from sklearn.metrics.pairwise import cosine_similarity
- import torch
- import json
  import os
- import glob
- from pathlib import Path
  from datetime import datetime
- import edge_tts
- import asyncio
- import base64
  import requests
- from collections import defaultdict
- from audio_recorder_streamlit import audio_recorder
- import streamlit.components.v1 as components
+ from datasets import load_dataset
  from urllib.parse import quote
- from xml.etree import ElementTree as ET

  # Initialize session state
  if 'search_history' not in st.session_state:
      st.session_state['search_history'] = []
- if 'last_voice_input' not in st.session_state:
-     st.session_state['last_voice_input'] = ""
- if 'transcript_history' not in st.session_state:
-     st.session_state['transcript_history'] = []
- if 'should_rerun' not in st.session_state:
-     st.session_state['should_rerun'] = False
  if 'search_columns' not in st.session_state:
      st.session_state['search_columns'] = []
  if 'initial_search_done' not in st.session_state:
      st.session_state['initial_search_done'] = False
- if 'tts_voice' not in st.session_state:
-     st.session_state['tts_voice'] = "en-US-AriaNeural"
- if 'arxiv_last_query' not in st.session_state:
-     st.session_state['arxiv_last_query'] = ""
+ if 'dataset' not in st.session_state:
+     st.session_state['dataset'] = None

- def fetch_dataset_info(dataset_id):
-     """Fetch dataset information including all available configs and splits"""
-     info_url = f"https://huggingface.co/api/datasets/{dataset_id}"
-     try:
-         response = requests.get(info_url, timeout=30)
-         if response.status_code == 200:
-             return response.json()
-     except Exception as e:
-         st.warning(f"Error fetching dataset info: {e}")
-     return None
-
- def fetch_dataset_rows(dataset_id, config="default", split="train", max_rows=100):
-     """Fetch rows from a specific config and split of a dataset"""
-     url = f"https://datasets-server.huggingface.co/first-rows?dataset={dataset_id}&config={config}&split={split}"
-     try:
-         response = requests.get(url, timeout=30)
-         if response.status_code == 200:
-             data = response.json()
-             if 'rows' in data:
-                 processed_rows = []
-                 for row_data in data['rows']:
-                     row = row_data.get('row', row_data)
-                     # Process embeddings if present
-                     for key in row:
-                         if any(term in key.lower() for term in ['embed', 'vector', 'encoding']):
-                             if isinstance(row[key], str):
-                                 try:
-                                     row[key] = [float(x.strip()) for x in row[key].strip('[]').split(',') if x.strip()]
-                                 except:
-                                     continue
-                     row['_config'] = config
-                     row['_split'] = split
-                     processed_rows.append(row)
-                 return processed_rows
-     except Exception as e:
-         st.warning(f"Error fetching rows for {config}/{split}: {e}")
-     return []
-
- def search_dataset(dataset_id, search_text, include_configs=None, include_splits=None):
-     """
-     Search across all configurations and splits of a dataset
-
-     Args:
-         dataset_id (str): The Hugging Face dataset ID
-         search_text (str): Text to search for in descriptions and queries
-         include_configs (list): List of specific configs to search, or None for all
-         include_splits (list): List of specific splits to search, or None for all
-
-     Returns:
-         tuple: (DataFrame of results, list of available configs, list of available splits)
-     """
-     # Get dataset info
-     dataset_info = fetch_dataset_info(dataset_id)
-     if not dataset_info:
-         return pd.DataFrame(), [], []
-
-     # Get available configs and splits
-     configs = include_configs if include_configs else dataset_info.get('config_names', ['default'])
-     all_rows = []
-     available_splits = set()
-
-     # Search across configs and splits
-     for config in configs:
-         try:
-             # First fetch split info for this config
-             splits_url = f"https://datasets-server.huggingface.co/splits?dataset={dataset_id}&config={config}"
-             splits_response = requests.get(splits_url, timeout=30)
-             if splits_response.status_code == 200:
-                 splits_data = splits_response.json()
-                 splits = [split['split'] for split in splits_data.get('splits', [])]
-                 if not splits:
-                     splits = ['train']  # fallback to train if no splits found
-
-                 # Filter splits if specified
-                 if include_splits:
-                     splits = [s for s in splits if s in include_splits]
-
-                 available_splits.update(splits)
-
-                 # Fetch and search rows for each split
-                 for split in splits:
-                     rows = fetch_dataset_rows(dataset_id, config, split)
-                     for row in rows:
-                         # Search in all text fields
-                         text_content = ' '.join(str(v) for v in row.values() if isinstance(v, (str, int, float)))
-                         if search_text.lower() in text_content.lower():
-                             row['_matched_text'] = text_content
-                             row['_relevance_score'] = text_content.lower().count(search_text.lower())
-                             all_rows.append(row)
-
-         except Exception as e:
-             st.warning(f"Error processing config {config}: {e}")
-             continue
-
-     # Convert to DataFrame and sort by relevance
-     if all_rows:
-         df = pd.DataFrame(all_rows)
-         df = df.sort_values('_relevance_score', ascending=False)
-         return df, configs, list(available_splits)
-
-     return pd.DataFrame(), configs, list(available_splits)
-
- class VideoSearch:
-     def __init__(self):
+ class DatasetSearcher:
+     def __init__(self, dataset_id="tomg-group-umd/cinepile"):
+         self.dataset_id = dataset_id
          self.text_model = SentenceTransformer('all-MiniLM-L6-v2')
-         self.dataset_id = "omegalabsinc/omega-multimodal"
+         self.token = os.environ.get('DATASET_KEY')
+         if not self.token:
+             st.error("Please set the DATASET_KEY environment variable with your Hugging Face token.")
+             st.stop()
          self.load_dataset()
-
-     def fetch_dataset_rows(self):
-         """Fetch dataset with enhanced search capabilities"""
+
+     def load_dataset(self):
+         """Load dataset using the datasets library"""
          try:
-             # First try to get all available data
-             df, configs, splits = search_dataset(
-                 self.dataset_id,
-                 "",  # empty search text to get all data
-                 include_configs=None,  # all configs
-                 include_splits=None  # all splits
-             )
+             if st.session_state['dataset'] is None:
+                 with st.spinner("Loading dataset..."):
+                     st.session_state['dataset'] = load_dataset(
+                         self.dataset_id,
+                         token=self.token,
+                         streaming=False
+                     )
+
+             self.dataset = st.session_state['dataset']
+             # Convert first split to DataFrame for easier processing
+             first_split = next(iter(self.dataset.values()))
+             self.df = pd.DataFrame(first_split)

-             if not df.empty:
-                 st.session_state['search_columns'] = [col for col in df.columns
-                                                       if col not in ['video_embed', 'description_embed', 'audio_embed']
-                                                       and not col.startswith('_')]
-                 return df
-
-             return self.load_example_data()
+             # Store column information
+             self.columns = list(self.df.columns)
+             self.text_columns = [col for col in self.columns
+                                  if self.df[col].dtype == 'object'
+                                  and not any(term in col.lower()
+                                              for term in ['embed', 'vector', 'encoding'])]
+
+             # Update session state columns
+             st.session_state['search_columns'] = self.text_columns
+
+             # Prepare text embeddings
+             self.prepare_features()

          except Exception as e:
-             st.warning(f"Error loading dataset: {e}")
-             return self.load_example_data()
-
-     def load_example_data(self):
-         """Load example data as fallback"""
-         example_data = [
-             {
-                 "video_id": "cd21da96-fcca-4c94-a60f-0b1e4e1e29fc",
-                 "youtube_id": "IO-vwtyicn4",
-                 "description": "This video shows a close-up of an ancient text carved into a surface.",
-                 "views": 45489,
-                 "start_time": 1452,
-                 "end_time": 1458,
-                 "video_embed": [0.014160037972033024, -0.003111184574663639, -0.016604168340563774],
-                 "description_embed": [-0.05835828185081482, 0.02589797042310238, 0.11952091753482819]
-             }
-         ]
-         return pd.DataFrame(example_data)
+             st.error(f"Error loading dataset: {str(e)}")
+             st.error("Please check your authentication token and internet connection.")
+             st.stop()

      def prepare_features(self):
-         """Prepare embeddings with adaptive field detection"""
+         """Prepare text embeddings for semantic search"""
          try:
-             embed_cols = [col for col in self.dataset.columns
-                           if any(term in col.lower() for term in ['embed', 'vector', 'encoding'])]
+             # Combine text columns for embedding
+             combined_text = self.df[self.text_columns].fillna('').agg(' '.join, axis=1)

-             embeddings = {}
-             for col in embed_cols:
-                 try:
-                     data = []
-                     for row in self.dataset[col]:
-                         if isinstance(row, str):
-                             values = [float(x.strip()) for x in row.strip('[]').split(',') if x.strip()]
-                         elif isinstance(row, list):
-                             values = row
-                         else:
-                             continue
-                         data.append(values)
-
-                     if data:
-                         embeddings[col] = np.array(data)
-                 except:
-                     continue
+             # Create embeddings in batches to manage memory
+             batch_size = 32
+             all_embeddings = []

-             # Set main embeddings for search
-             if 'video_embed' in embeddings:
-                 self.video_embeds = embeddings['video_embed']
-             else:
-                 self.video_embeds = next(iter(embeddings.values()))
-
-             if 'description_embed' in embeddings:
-                 self.text_embeds = embeddings['description_embed']
-             else:
-                 self.text_embeds = self.video_embeds
-
-         except:
-             # Fallback to random embeddings
-             num_rows = len(self.dataset)
-             self.video_embeds = np.random.randn(num_rows, 384)
-             self.text_embeds = np.random.randn(num_rows, 384)
-
-     def load_dataset(self):
-         self.dataset = self.fetch_dataset_rows()
-         self.prepare_features()
+             with st.spinner("Preparing search features..."):
+                 for i in range(0, len(combined_text), batch_size):
+                     batch = combined_text[i:i+batch_size].tolist()
+                     embeddings = self.text_model.encode(batch)
+                     all_embeddings.append(embeddings)
+
+             self.text_embeddings = np.vstack(all_embeddings)
+
+         except Exception as e:
+             st.warning(f"Error preparing features: {str(e)}")
+             self.text_embeddings = np.random.randn(len(self.df), 384)

      def search(self, query, column=None, top_k=20):
+         """Search the dataset using semantic and keyword matching"""
+         if self.df.empty:
+             return []
+
+         # Get semantic similarity scores
          query_embedding = self.text_model.encode([query])[0]
-         video_sims = cosine_similarity([query_embedding], self.video_embeds)[0]
-         text_sims = cosine_similarity([query_embedding], self.text_embeds)[0]
-         combined_sims = 0.5 * video_sims + 0.5 * text_sims
+         similarities = cosine_similarity([query_embedding], self.text_embeddings)[0]
+
+         # Get keyword match scores
+         search_columns = [column] if column and column != "All Fields" else self.text_columns
+         keyword_scores = np.zeros(len(self.df))
+
+         for col in search_columns:
+             if col in self.df.columns:
+                 matches = self.df[col].fillna('').str.lower().str.count(query.lower())
+                 keyword_scores += matches

-         # Column filtering
-         if column and column in self.dataset.columns and column != "All Fields":
-             mask = self.dataset[column].astype(str).str.contains(query, case=False)
-             combined_sims[~mask] *= 0.5
+         # Combine scores
+         combined_scores = 0.5 * similarities + 0.5 * (keyword_scores / max(1, keyword_scores.max()))

-         top_k = min(top_k, 100)
-         top_indices = np.argsort(combined_sims)[-top_k:][::-1]
+         # Get top results
+         top_k = min(top_k, len(combined_scores))
+         top_indices = np.argsort(combined_scores)[-top_k:][::-1]

+         # Format results
          results = []
          for idx in top_indices:
-             result = {'relevance_score': float(combined_sims[idx])}
-             for col in self.dataset.columns:
-                 if col not in ['video_embed', 'description_embed', 'audio_embed']:
-                     result[col] = self.dataset.iloc[idx][col]
+             result = {
+                 'relevance_score': float(combined_scores[idx]),
+                 'semantic_score': float(similarities[idx]),
+                 'keyword_score': float(keyword_scores[idx]),
+                 **self.df.iloc[idx].to_dict()
+             }
              results.append(result)

          return results

- @st.cache_resource
- def get_speech_model():
-     return edge_tts.Communicate
-
- async def generate_speech(text, voice=None):
-     if not text.strip():
-         return None
-     if not voice:
-         voice = st.session_state['tts_voice']
-     try:
-         communicate = get_speech_model()(text, voice)
-         audio_file = f"speech_{datetime.now().strftime('%Y%m%d_%H%M%S')}.mp3"
-         await communicate.save(audio_file)
-         return audio_file
-     except Exception as e:
-         st.error(f"Error generating speech: {e}")
-         return None
-
- def transcribe_audio(audio_path):
-     """Placeholder for ASR transcription"""
-     return "ASR not implemented. Integrate a local model or another service here."
+     def get_dataset_info(self):
+         """Get information about the dataset"""
+         if not self.dataset:
+             return {}
+
+         info = {
+             'splits': list(self.dataset.keys()),
+             'total_rows': sum(split.num_rows for split in self.dataset.values()),
+             'columns': self.columns,
+             'text_columns': self.text_columns,
+             'sample_rows': len(self.df),
+             'embeddings_shape': self.text_embeddings.shape
+         }
+
+         return info

- def show_file_manager():
-     """Display file manager interface"""
-     st.subheader("📂 File Manager")
-     col1, col2 = st.columns(2)
+ def render_video_result(result):
+     """Render a video result with enhanced display"""
+     col1, col2 = st.columns([2, 1])
+
      with col1:
-         uploaded_file = st.file_uploader("Upload File", type=['txt', 'md', 'mp3'])
-         if uploaded_file:
-             with open(uploaded_file.name, "wb") as f:
-                 f.write(uploaded_file.getvalue())
-             st.success(f"Uploaded: {uploaded_file.name}")
-             st.experimental_rerun()
+         if 'title' in result:
+             st.markdown(f"**Title:** {result['title']}")
+         if 'description' in result:
+             st.markdown("**Description:**")
+             st.write(result['description'])
+
+         # Show timing information if available
+         if 'start_time' in result and 'end_time' in result:
+             st.markdown(f"**Time Range:** {result['start_time']}s - {result['end_time']}s")
+
+         # Show additional metadata
+         for key, value in result.items():
+             if key not in ['title', 'description', 'start_time', 'end_time', 'duration',
+                            'relevance_score', 'semantic_score', 'keyword_score',
+                            'video_id', 'youtube_id']:
+                 st.markdown(f"**{key.replace('_', ' ').title()}:** {value}")

      with col2:
-         if st.button("🗑 Clear All Files"):
-             for f in glob.glob("*.txt") + glob.glob("*.md") + glob.glob("*.mp3"):
-                 os.remove(f)
-             st.success("All files cleared!")
-             st.experimental_rerun()
-
-     files = glob.glob("*.txt") + glob.glob("*.md") + glob.glob("*.mp3")
-     if files:
-         st.write("### Existing Files")
-         for f in files:
-             with st.expander(f"📄 {os.path.basename(f)}"):
-                 if f.endswith('.mp3'):
-                     st.audio(f)
-                 else:
-                     with open(f, 'r', encoding='utf-8') as file:
-                         st.text_area("Content", file.read(), height=100)
-                 if st.button(f"Delete {os.path.basename(f)}", key=f"del_{f}"):
-                     os.remove(f)
-                     st.experimental_rerun()
-
- def arxiv_search(query, max_results=5):
-     """Perform a simple Arxiv search using their API and return top results."""
-     base_url = "http://export.arxiv.org/api/query?"
-     search_url = base_url + f"search_query={quote(query)}&start=0&max_results={max_results}"
-     r = requests.get(search_url)
-     if r.status_code == 200:
-         root = ET.fromstring(r.text)
-         ns = {'atom': 'http://www.w3.org/2005/Atom'}
-         entries = root.findall('atom:entry', ns)
-         results = []
-         for entry in entries:
-             title = entry.find('atom:title', ns).text.strip()
-             summary = entry.find('atom:summary', ns).text.strip()
-             link = None
-             for l in entry.findall('atom:link', ns):
-                 if l.get('type') == 'text/html':
-                     link = l.get('href')
-                     break
-             results.append((title, summary, link))
-         return results
-     return []
-
- def perform_arxiv_lookup(q, vocal_summary=True, titles_summary=True, full_audio=False):
-     results = arxiv_search(q, max_results=5)
-     if not results:
-         st.write("No Arxiv results found.")
-         return
-     st.markdown(f"**Arxiv Search Results for '{q}':**")
-     for i, (title, summary, link) in enumerate(results, start=1):
-         st.markdown(f"**{i}. {title}**")
-         st.write(summary)
-         if link:
-             st.markdown(f"[View Paper]({link})")
-
-     if vocal_summary:
-         spoken_text = f"Here are some Arxiv results for {q}. "
-         if titles_summary:
-             spoken_text += " Titles: " + ", ".join([res[0] for res in results])
-         else:
-             # Just first summary if no titles_summary
-             spoken_text += " " + results[0][1][:200]
-
-         audio_file = asyncio.run(generate_speech(spoken_text))
-         if audio_file:
-             st.audio(audio_file)
-
-     if full_audio:
-         # Full audio of summaries
-         full_text = ""
-         for i, (title, summary, _) in enumerate(results, start=1):
-             full_text += f"Result {i}: {title}. {summary} "
-         audio_file_full = asyncio.run(generate_speech(full_text))
-         if audio_file_full:
-             st.write("### Full Audio")
-             st.audio(audio_file_full)
+         # Show search scores
+         st.markdown("**Search Scores:**")
+         cols = st.columns(3)
+         cols[0].metric("Overall", f"{result['relevance_score']:.2%}")
+         cols[1].metric("Semantic", f"{result['semantic_score']:.2%}")
+         cols[2].metric("Keyword", f"{result['keyword_score']:.0f} matches")
+
+         # Display video if available
+         if 'youtube_id' in result:
+             st.video(f"https://youtube.com/watch?v={result['youtube_id']}&t={result.get('start_time', 0)}")

  def main():
-     st.title("🎥 Video & Arxiv Search with Voice (No OpenAI/Anthropic)")
+     st.title("🎥 Video Dataset Search")

      # Initialize search class
-     search = VideoSearch()
+     searcher = DatasetSearcher()

      # Create tabs
-     tab1, tab2, tab3, tab4, tab5 = st.tabs(["🔍 Search", "🎙️ Voice Input", "📚 Arxiv", "📂 Files", "🔍 Advanced Search"])
+     tab1, tab2 = st.tabs(["🔍 Search", "📊 Dataset Info"])

-     # ---- Tab 1: Video Search ----
+     # ---- Tab 1: Search ----
      with tab1:
          st.subheader("Search Videos")
          col1, col2 = st.columns([3, 1])
+
          with col1:
-             query = st.text_input("Enter your search query:",
-                 value="ancient" if not st.session_state['initial_search_done'] else "")
+             query = st.text_input("Search query:",
+                 value="" if st.session_state['initial_search_done'] else "")
          with col2:
              search_column = st.selectbox("Search in field:",
                  ["All Fields"] + st.session_state['search_columns'])
@@ -390,11 +200,12 @@ def main():
          with col4:
              search_button = st.button("🔍 Search")

-         if (search_button or not st.session_state['initial_search_done']) and query:
+         if search_button and query:
              st.session_state['initial_search_done'] = True
              selected_column = None if search_column == "All Fields" else search_column
+
              with st.spinner("Searching..."):
-                 results = search.search(query, selected_column, num_results)
+                 results = searcher.search(query, selected_column, num_results)

              st.session_state['search_history'].append({
                  'query': query,
@@ -403,151 +214,34 @@ def main():
              })

              for i, result in enumerate(results, 1):
-                 with st.expander(f"Result {i}: {result['description'][:100]}...", expanded=(i==1)):
-                     cols = st.columns([2, 1])
-                     with cols[0]:
-                         st.markdown("**Description:**")
-                         st.write(result['description'])
-                         st.markdown(f"**Time Range:** {result['start_time']}s - {result['end_time']}s")
-                         st.markdown(f"**Views:** {result['views']:,}")
-
-                     with cols[1]:
-                         st.markdown(f"**Relevance Score:** {result['relevance_score']:.2%}")
-                         if result.get('youtube_id'):
-                             st.video(f"https://youtube.com/watch?v={result['youtube_id']}&t={result['start_time']}")
-
-                     if st.button(f"🔊 Audio Summary", key=f"audio_{i}"):
-                         summary = f"Video summary: {result['description'][:200]}"
-                         audio_file = asyncio.run(generate_speech(summary))
-                         if audio_file:
-                             st.audio(audio_file)
-
-     # ---- Tab 2: Voice Input ----
+                 with st.expander(
+                     f"Result {i}: {result.get('title', result.get('description', 'No title'))[:100]}...",
+                     expanded=(i==1)
+                 ):
+                     render_video_result(result)
+
+     # ---- Tab 2: Dataset Info ----
      with tab2:
-         st.subheader("Voice Input")
-         st.write("🎙️ Record your voice:")
-         audio_bytes = audio_recorder()
-         if audio_bytes:
-             audio_path = f"temp_audio_{datetime.now().strftime('%Y%m%d_%H%M%S')}.wav"
-             with open(audio_path, "wb") as f:
-                 f.write(audio_bytes)
-             st.success("Audio recorded successfully!")
-
-             voice_query = transcribe_audio(audio_path)
-             st.markdown("**Transcribed Text:**")
-             st.write(voice_query)
-             st.session_state['last_voice_input'] = voice_query
+         st.subheader("Dataset Information")
+
+         info = searcher.get_dataset_info()
+         if info:
+             st.write(f"### Dataset: {searcher.dataset_id}")
+             st.write(f"- Total rows: {info['total_rows']:,}")
+             st.write(f"- Available splits: {', '.join(info['splits'])}")
+             st.write(f"- Number of columns: {len(info['columns'])}")
+             st.write(f"- Searchable text columns: {', '.join(info['text_columns'])}")

-             if st.button("🔍 Search from Voice"):
-                 results = search.search(voice_query, None, 20)
-                 for i, result in enumerate(results, 1):
-                     with st.expander(f"Result {i}", expanded=(i==1)):
-                         st.write(result['description'])
-                         if result.get('youtube_id'):
-                             st.video(f"https://youtube.com/watch?v={result['youtube_id']}&t={result.get('start_time', 0)}")
+             st.write("### Sample Data")
+             st.dataframe(searcher.df.head())

-             if os.path.exists(audio_path):
-                 os.remove(audio_path)
-
-     # ---- Tab 3: Arxiv Search ----
-     with tab3:
-         st.subheader("Arxiv Search")
-         q = st.text_input("Enter your Arxiv search query:", value=st.session_state['arxiv_last_query'])
-         vocal_summary = st.checkbox("🎙 Short Audio Summary", value=True)
-         titles_summary = st.checkbox("🔖 Titles Only", value=True)
-         full_audio = st.checkbox("📚 Full Audio Results", value=False)
-
-         if st.button("🔍 Arxiv Search"):
-             st.session_state['arxiv_last_query'] = q
-             perform_arxiv_lookup(q, vocal_summary=vocal_summary, titles_summary=titles_summary, full_audio=full_audio)
-
-     # ---- Tab 4: File Manager ----
-     with tab4:
-         show_file_manager()
-
-     # ---- Tab 5: Advanced Dataset Search ----
-     with tab5:
-         st.subheader("Advanced Dataset Search")
-
-         # Dataset input
-         dataset_id = st.text_input("Dataset ID:", value="omegalabsinc/omega-multimodal")
-
-         # Search configuration
-         col1, col2 = st.columns([2, 1])
-         with col1:
-             search_text = st.text_input("Search text:",
-                 placeholder="Enter text to search across all fields")
-
-         # Get available configs and splits
-         if dataset_id:
-             dataset_info = fetch_dataset_info(dataset_id)
-             if dataset_info:
-                 configs = dataset_info.get('config_names', ['default'])
-                 with col2:
-                     selected_configs = st.multiselect(
-                         "Configurations:",
-                         options=configs,
-                         default=['default'] if 'default' in configs else None
-                     )
-
-                 # Fetch available splits
-                 if selected_configs:
-                     all_splits = set()
-                     for config in selected_configs:
-                         splits_url = f"https://datasets-server.huggingface.co/splits?dataset={dataset_id}&config={config}"
-                         try:
-                             response = requests.get(splits_url, timeout=30)
-                             if response.status_code == 200:
-                                 splits_data = response.json()
-                                 splits = [split['split'] for split in splits_data.get('splits', [])]
-                                 all_splits.update(splits)
-                         except Exception as e:
-                             st.warning(f"Error fetching splits for {config}: {e}")
-
-                     selected_splits = st.multiselect(
-                         "Splits:",
-                         options=list(all_splits),
-                         default=['train'] if 'train' in all_splits else None
-                     )
-
-                     if st.button("🔍 Search Dataset"):
-                         with st.spinner("Searching dataset..."):
-                             results_df, _, _ = search_dataset(
-                                 dataset_id,
-                                 search_text,
-                                 include_configs=selected_configs,
-                                 include_splits=selected_splits
-                             )
-
-                         if not results_df.empty:
-                             st.write(f"Found {len(results_df)} results")
-
-                             # Display results in expandable sections
-                             for idx, row in results_df.iterrows():
-                                 with st.expander(
-                                     f"Result {idx+1} (Config: {row['_config']}, Split: {row['_split']}, Score: {row['_relevance_score']})"
-                                 ):
-                                     # Display all fields except internal ones
-                                     for col in row.index:
-                                         if not col.startswith('_') and not any(
-                                             term in col.lower()
-                                             for term in ['embed', 'vector', 'encoding']
-                                         ):
-                                             st.write(f"**{col}:** {row[col]}")
-
-                                     # Add buttons for audio/video if available
-                                     if 'youtube_id' in row:
-                                         st.video(
-                                             f"https://youtube.com/watch?v={row['youtube_id']}&t={row.get('start_time', 0)}"
-                                         )
-                         else:
-                             st.warning("No results found.")
-             else:
-                 st.error("Unable to fetch dataset information.")
+             st.write("### Column Details")
+             for col in info['columns']:
+                 st.write(f"- **{col}**: {searcher.df[col].dtype}")

      # Sidebar
      with st.sidebar:
-         st.subheader("⚙️ Settings & History")
+         st.subheader("⚙️ Search History")
          if st.button("🗑️ Clear History"):
              st.session_state['search_history'] = []
              st.experimental_rerun()
@@ -556,12 +250,7 @@ def main():
          for entry in reversed(st.session_state['search_history'][-5:]):
              with st.expander(f"{entry['timestamp']}: {entry['query']}"):
                  for i, result in enumerate(entry['results'], 1):
-                     st.write(f"{i}. {result['description'][:100]}...")
-
-         st.markdown("### Voice Settings")
-         st.selectbox("TTS Voice:",
-             ["en-US-AriaNeural", "en-US-GuyNeural", "en-GB-SoniaNeural"],
-             key="tts_voice")
+                     st.write(f"{i}. {result.get('title', result.get('description', 'No title'))[:100]}...")

  if __name__ == "__main__":
      main()
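For reference, the new DatasetSearcher.load_dataset follows the standard datasets-library loading pattern, cached in st.session_state so the download happens only once per session. A minimal sketch of that pattern, runnable outside Streamlit; the dataset id and the DATASET_KEY variable name come from the diff above, and the printed fields are illustrative:

import os
from datasets import load_dataset

# The app reads a Hugging Face access token from the DATASET_KEY env var.
token = os.environ.get("DATASET_KEY")
if token is None:
    raise SystemExit("Set DATASET_KEY to a Hugging Face access token.")

# Non-streaming load pulls every split into memory, as the app does.
ds = load_dataset("tomg-group-umd/cinepile", token=token, streaming=False)
print(list(ds.keys()))                     # available splits
first_split = next(iter(ds.values()))      # the app indexes only this split
print(first_split.num_rows, first_split.column_names)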
 
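The heart of the rewrite is search(), which blends embedding similarity with literal keyword counts. A minimal sketch of that scoring scheme on toy data (the column names and example frame are illustrative, not taken from the dataset):

import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

# Toy stand-in for the DataFrame the app builds from the first split.
df = pd.DataFrame({
    "title": ["Ancient ruins tour", "Cooking pasta", "Desert survival"],
    "description": [
        "A walk through ancient carved stone ruins.",
        "How to cook pasta al dente.",
        "Finding water in the desert.",
    ],
})

model = SentenceTransformer("all-MiniLM-L6-v2")
embeddings = model.encode((df["title"] + " " + df["description"]).tolist())

query = "ancient stone carvings"
q = model.encode([query])[0]
semantic = cosine_similarity([q], embeddings)[0]          # cosine per row

# Literal occurrence counts of the query across the text columns.
keyword = np.zeros(len(df))
for col in ["title", "description"]:
    keyword += df[col].str.lower().str.count(query.lower()).to_numpy()

# Equal-weight blend; normalising by max(1, ...) keeps the keyword term
# in [0, 1] and avoids division by zero when nothing matches literally.
combined = 0.5 * semantic + 0.5 * (keyword / max(1, keyword.max()))
print(df.iloc[np.argsort(combined)[::-1]]["title"].tolist())

Two details worth noting about the app's version: it encodes the corpus in batches of 32 to bound memory, and pandas str.count treats the query as a regular expression, so queries containing regex metacharacters would need escaping (e.g. re.escape) before counting.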