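"""Streamlit app for semantic video search with voice input and arXiv lookup.

Embeds queries with a SentenceTransformer, searches the
omegalabsinc/omega-multimodal dataset, fetches arXiv results via the Atom
API, and reads summaries aloud with edge-tts. Also includes a simple file
manager and an HTML voice-input component.
"""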
import streamlit as st
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import os
import glob
from datetime import datetime
import edge_tts
import asyncio
import requests
from audio_recorder_streamlit import audio_recorder
import streamlit.components.v1 as components
import re
from urllib.parse import quote
from xml.etree import ElementTree as ET
# Initialize session state
for _key, _default in [
    ('search_history', []),
    ('last_voice_input', ""),
    ('transcript_history', []),
    ('should_rerun', False),
    ('search_columns', []),
    ('initial_search_done', False),
    ('tts_voice', "en-US-AriaNeural"),
    ('arxiv_last_query', ""),
    ('old_val', None),
]:
    if _key not in st.session_state:
        st.session_state[_key] = _default
def highlight_text(text, query):
    """Highlight case-insensitive occurrences of query in text with bold formatting."""
    if not query:
        return text
    pattern = re.compile(re.escape(query), re.IGNORECASE)
    return pattern.sub(lambda m: f"**{m.group(0)}**", text)
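# Example: highlight_text("Ancient text carved into a surface", "ancient")
# returns "**Ancient** text carved into a surface".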

class VideoSearch:
    def __init__(self):
        self.text_model = SentenceTransformer('all-MiniLM-L6-v2')
        self.load_dataset()

    def fetch_dataset_rows(self):
        """Fetch the first dataset rows from the Hugging Face datasets-server API."""
        try:
            url = ("https://datasets-server.huggingface.co/first-rows"
                   "?dataset=omegalabsinc%2Fomega-multimodal&config=default&split=train")
            response = requests.get(url, timeout=30)
            if response.status_code == 200:
                data = response.json()
                if 'rows' in data:
                    processed_rows = []
                    for row_data in data['rows']:
                        row = row_data.get('row', row_data)
                        # Embedding columns may arrive as stringified lists; parse them to floats.
                        for key in row:
                            if any(term in key.lower() for term in ['embed', 'vector', 'encoding']):
                                if isinstance(row[key], str):
                                    try:
                                        row[key] = [float(x.strip())
                                                    for x in row[key].strip('[]').split(',')
                                                    if x.strip()]
                                    except ValueError:
                                        continue
                        processed_rows.append(row)
                    df = pd.DataFrame(processed_rows)
                    st.session_state['search_columns'] = [
                        col for col in df.columns
                        if col not in ['video_embed', 'description_embed', 'audio_embed']
                    ]
                    return df
            return self.load_example_data()
        except Exception:
            return self.load_example_data()
    def prepare_features(self):
        """Prepare embedding matrices with adaptive field detection."""
        try:
            embed_cols = [col for col in self.dataset.columns
                          if any(term in col.lower() for term in ['embed', 'vector', 'encoding'])]
            embeddings = {}
            for col in embed_cols:
                try:
                    data = []
                    for row in self.dataset[col]:
                        if isinstance(row, str):
                            values = [float(x.strip()) for x in row.strip('[]').split(',') if x.strip()]
                        elif isinstance(row, list):
                            values = row
                        else:
                            continue
                        data.append(values)
                    if data:
                        embeddings[col] = np.array(data)
                except (ValueError, TypeError):
                    continue
            if 'video_embed' in embeddings:
                self.video_embeds = embeddings['video_embed']
            else:
                # No canonical video embedding column: use whichever was found first.
                self.video_embeds = next(iter(embeddings.values()))
            self.text_embeds = embeddings.get('description_embed', self.video_embeds)
        except Exception:
            # Fallback to random embeddings so the app still runs without usable vectors.
            num_rows = len(self.dataset)
            self.video_embeds = np.random.randn(num_rows, 384)
            self.text_embeds = np.random.randn(num_rows, 384)
    def load_example_data(self):
        """Load example data as a fallback (embedding vectors are short illustrative stubs)."""
        example_data = [
            {
                "video_id": "cd21da96-fcca-4c94-a60f-0b1e4e1e29fc",
                "youtube_id": "IO-vwtyicn4",
                "description": "This video shows a close-up of an ancient text carved into a surface.",
                "views": 45489,
                "start_time": 1452,
                "end_time": 1458,
                "video_embed": [0.014160037972033024, -0.003111184574663639, -0.016604168340563774],
                "description_embed": [-0.05835828185081482, 0.02589797042310238, 0.11952091753482819]
            }
        ]
        return pd.DataFrame(example_data)

    def load_dataset(self):
        self.dataset = self.fetch_dataset_rows()
        self.prepare_features()
    def search(self, query, column=None, top_k=20):
        """Semantic search blending video- and text-embedding similarity to the query."""
        query_embedding = self.text_model.encode([query])[0]
        video_sims = cosine_similarity([query_embedding], self.video_embeds)[0]
        text_sims = cosine_similarity([query_embedding], self.text_embeds)[0]
        # Equal-weight blend of the two similarity signals.
        combined_sims = 0.5 * video_sims + 0.5 * text_sims

        # If a specific column is selected (not "All Fields"), keep only rows
        # whose value in that column contains the query text.
        if column and column in self.dataset.columns and column != "All Fields":
            mask = self.dataset[column].astype(str).str.contains(query, case=False, na=False)
            combined_sims = combined_sims[mask.to_numpy()]
            filtered_dataset = self.dataset[mask].copy()
        else:
            filtered_dataset = self.dataset.copy()

        # Take the top_k most similar rows, highest score first.
        top_k = min(top_k, len(combined_sims))
        if top_k == 0:
            return []
        top_indices = np.argsort(combined_sims)[-top_k:][::-1]

        results = []
        for score, (_, row) in zip(combined_sims[top_indices],
                                   filtered_dataset.iloc[top_indices].iterrows()):
            result = {'relevance_score': float(score)}
            for col in filtered_dataset.columns:
                if col not in ['video_embed', 'description_embed', 'audio_embed']:
                    result[col] = row[col]
            results.append(result)
        return results
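
# Usage sketch: VideoSearch().search("ancient text", column=None, top_k=5)
# returns dicts holding 'relevance_score' plus every non-embedding column.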

def get_speech_model():
    # Thin wrapper around the edge-tts Communicate class.
    return edge_tts.Communicate

async def generate_speech(text, voice=None):
    """Generate an MP3 from text with edge-tts; return the file path or None."""
    if not text.strip():
        return None
    if not voice:
        voice = st.session_state['tts_voice']
    try:
        communicate = get_speech_model()(text, voice)
        audio_file = f"speech_{datetime.now().strftime('%Y%m%d_%H%M%S')}.mp3"
        await communicate.save(audio_file)
        return audio_file
    except Exception as e:
        st.error(f"Error generating speech: {e}")
        return None

def show_file_manager():
    """Display file manager interface."""
    st.subheader("📂 File Manager")
    col1, col2 = st.columns(2)
    with col1:
        uploaded_file = st.file_uploader("Upload File", type=['txt', 'md', 'mp3'])
        if uploaded_file:
            with open(uploaded_file.name, "wb") as f:
                f.write(uploaded_file.getvalue())
            st.success(f"Uploaded: {uploaded_file.name}")
            st.rerun()
    with col2:
        if st.button("🗑️ Clear All Files"):
            for f in glob.glob("*.txt") + glob.glob("*.md") + glob.glob("*.mp3"):
                os.remove(f)
            st.success("All files cleared!")
            st.rerun()
    files = glob.glob("*.txt") + glob.glob("*.md") + glob.glob("*.mp3")
    if files:
        st.write("### Existing Files")
        for f in files:
            with st.expander(f"📄 {os.path.basename(f)}"):
                if f.endswith('.mp3'):
                    st.audio(f)
                else:
                    with open(f, 'r', encoding='utf-8') as file:
                        st.text_area("Content", file.read(), height=100)
                if st.button(f"Delete {os.path.basename(f)}", key=f"del_{f}"):
                    os.remove(f)
                    st.rerun()

def arxiv_search(query, max_results=5):
    """Perform a simple arXiv search via its Atom API and return the top results."""
    base_url = "http://export.arxiv.org/api/query?"
    search_url = base_url + f"search_query={quote(query)}&start=0&max_results={max_results}"
    r = requests.get(search_url, timeout=30)
    if r.status_code == 200:
        root = ET.fromstring(r.text)
        ns = {'atom': 'http://www.w3.org/2005/Atom'}
        entries = root.findall('atom:entry', ns)
        results = []
        for entry in entries:
            title = entry.find('atom:title', ns).text.strip()
            summary = entry.find('atom:summary', ns).text.strip()
            link = None
            for l in entry.findall('atom:link', ns):
                if l.get('type') == 'text/html':
                    link = l.get('href')
                    break
            results.append((title, summary, link))
        return results
    return []
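
# Usage sketch (illustrative query; requires network access):
#   for title, summary, link in arxiv_search("multimodal embeddings", max_results=3):
#       print(title, link)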

def perform_arxiv_lookup(q, vocal_summary=True, titles_summary=True, full_audio=False):
    results = arxiv_search(q, max_results=5)
    if not results:
        st.write("No Arxiv results found.")
        return
    st.markdown(f"**Arxiv Search Results for '{q}':**")
    for i, (title, summary, link) in enumerate(results, start=1):
        st.markdown(f"**{i}. {title}**")
        st.write(summary)
        if link:
            st.markdown(f"[View Paper]({link})")
    # TTS options
    if vocal_summary:
        spoken_text = f"Here are some Arxiv results for {q}. "
        if titles_summary:
            spoken_text += " Titles: " + ", ".join([res[0] for res in results])
        else:
            spoken_text += " " + results[0][1][:200]
        audio_file = asyncio.run(generate_speech(spoken_text))
        if audio_file:
            st.audio(audio_file)
    if full_audio:
        full_text = ""
        for i, (title, summary, _) in enumerate(results, start=1):
            full_text += f"Result {i}: {title}. {summary} "
        audio_file_full = asyncio.run(generate_speech(full_text))
        if audio_file_full:
            st.write("### Full Audio")
            st.audio(audio_file_full)

def main():
    st.title("🎥 Video & Arxiv Search with Voice Input")
    search = VideoSearch()
    tab1, tab2, tab3, tab4 = st.tabs(["🔍 Search", "🎙️ Voice Input", "📚 Arxiv", "📂 Files"])

    # ---- Tab 1: Video Search ----
    with tab1:
        st.subheader("Search Videos")
        col1, col2 = st.columns([3, 1])
        with col1:
            query = st.text_input("Enter your search query:",
                                  value="ancient" if not st.session_state['initial_search_done'] else "")
        with col2:
            search_column = st.selectbox("Search in field:",
                                         ["All Fields"] + st.session_state['search_columns'])
        col3, col4 = st.columns(2)
        with col3:
            num_results = st.slider("Number of results:", 1, 100, 20)
        with col4:
            search_button = st.button("🔍 Search")

        if (search_button or not st.session_state['initial_search_done']) and query:
            st.session_state['initial_search_done'] = True
            selected_column = None if search_column == "All Fields" else search_column
            with st.spinner("Searching..."):
                results = search.search(query, selected_column, num_results)
            st.session_state['search_history'].append({
                'query': query,
                'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                'results': results[:5]
            })
            for i, result in enumerate(results, 1):
                # Highlight the query in the description
                highlighted_desc = highlight_text(result['description'], query)
                with st.expander(f"Result {i}: {result['description'][:100]}...", expanded=(i == 1)):
                    cols = st.columns([2, 1])
                    with cols[0]:
                        st.markdown("**Description:**")
                        st.write(highlighted_desc)
                        st.markdown(f"**Time Range:** {result['start_time']}s - {result['end_time']}s")
                        st.markdown(f"**Views:** {result['views']:,}")
                    with cols[1]:
                        st.markdown(f"**Relevance Score:** {result['relevance_score']:.2%}")
                        if result.get('youtube_id'):
                            st.video(f"https://youtube.com/watch?v={result['youtube_id']}&t={result['start_time']}")
                    if st.button(f"🔊 Audio Summary {i}", key=f"audio_{i}"):
                        summary = f"Video summary: {result['description'][:200]}"
                        audio_file = asyncio.run(generate_speech(summary))
                        if audio_file:
                            st.audio(audio_file)
    # ---- Tab 2: Voice Input ----
    with tab2:
        st.subheader("Voice Input (HTML Component)")
        # Declare the custom component used to accumulate voice input
        mycomponent = components.declare_component("mycomponent", path="mycomponent")
        val = mycomponent(my_input_value="Hello")
        if val:
            val_stripped = val.replace('\n', ' ')
            edited_input = st.text_area("✏️ Edit Input:", value=val_stripped, height=100)
            # Allow searching the videos from the edited input
            if st.button("🔍 Search from Edited Voice Input"):
                results = search.search(edited_input, None, 20)
                for i, result in enumerate(results, 1):
                    # Highlight query in description
                    highlighted_desc = highlight_text(result['description'], edited_input)
                    with st.expander(f"Result {i}", expanded=(i == 1)):
                        st.write(highlighted_desc)
                        if result.get('youtube_id'):
                            st.video(f"https://youtube.com/watch?v={result['youtube_id']}&t={result.get('start_time', 0)}")
        # Optionally let the user record audio via audio_recorder (not integrated with transcription)
        st.write("Or record audio (not ASR integrated):")
        audio_bytes = audio_recorder()
        if audio_bytes:
            audio_path = f"temp_audio_{datetime.now().strftime('%Y%m%d_%H%M%S')}.wav"
            with open(audio_path, "wb") as f:
                f.write(audio_bytes)
            st.success("Audio recorded successfully!")
            # No transcription is provided since no external ASR is included here.
            if os.path.exists(audio_path):
                os.remove(audio_path)
    # ---- Tab 3: Arxiv Search ----
    with tab3:
        st.subheader("Arxiv Search")
        q = st.text_input("Enter your Arxiv search query:", value=st.session_state['arxiv_last_query'])
        vocal_summary = st.checkbox("🔊 Short Audio Summary", value=True)
        titles_summary = st.checkbox("📝 Titles Only", value=True)
        full_audio = st.checkbox("📚 Full Audio Results", value=False)
        if st.button("🔍 Arxiv Search"):
            st.session_state['arxiv_last_query'] = q
            perform_arxiv_lookup(q, vocal_summary=vocal_summary,
                                 titles_summary=titles_summary, full_audio=full_audio)

    # ---- Tab 4: File Manager ----
    with tab4:
        show_file_manager()
    # Sidebar
    with st.sidebar:
        st.subheader("⚙️ Settings & History")
        if st.button("🗑️ Clear History"):
            st.session_state['search_history'] = []
            st.rerun()
        st.markdown("### Recent Searches")
        for entry in reversed(st.session_state['search_history'][-5:]):
            with st.expander(f"{entry['timestamp']}: {entry['query']}"):
                for i, result in enumerate(entry['results'], 1):
                    st.write(f"{i}. {result['description'][:100]}...")
        st.markdown("### Voice Settings")
        st.selectbox("TTS Voice:",
                     ["en-US-AriaNeural", "en-US-GuyNeural", "en-GB-SoniaNeural"],
                     key="tts_voice")

if __name__ == "__main__":
    main()