import streamlit as st
import pandas as pd
import numpy as np
import json
import os
import glob
import random
from pathlib import Path
from datetime import datetime
import edge_tts
import asyncio
import requests
import streamlit.components.v1 as components
import base64
import re
from xml.etree import ElementTree as ET
from datasets import load_dataset
# -------------------- Configuration & Constants --------------------
USER_NAMES = [
    "Aria", "Guy", "Sonia", "Tony", "Jenny", "Davis", "Libby", "Clara", "Liam", "Natasha", "William"
]

ENGLISH_VOICES = [
    "en-US-AriaNeural", "en-US-GuyNeural", "en-GB-SoniaNeural", "en-GB-TonyNeural",
    "en-US-JennyNeural", "en-US-DavisNeural", "en-GB-LibbyNeural", "en-CA-ClaraNeural",
    "en-CA-LiamNeural", "en-AU-NatashaNeural", "en-AU-WilliamNeural"
]

# Map each user to a corresponding voice
USER_VOICES = dict(zip(USER_NAMES, ENGLISH_VOICES))
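# e.g. USER_VOICES["Aria"] == "en-US-AriaNeural" and USER_VOICES["Sonia"] == "en-GB-SoniaNeural",
# since the two lists above are aligned by position.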
ROWS_PER_PAGE = 100
SAVED_INPUTS_DIR = "saved_inputs"
os.makedirs(SAVED_INPUTS_DIR, exist_ok=True)

SESSION_VARS = {
    'search_history': [],
    'last_voice_input': "",
    'transcript_history': [],
    'should_rerun': False,
    'search_columns': [],
    'initial_search_done': False,
    'arxiv_last_query': "",
    'dataset_loaded': False,
    'current_page': 0,
    'data_cache': None,
    'dataset_info': None,
    'nps_submitted': False,
    'nps_last_shown': None,
    'old_val': None,
    'voice_text': None,
    'user_name': random.choice(USER_NAMES),
    'max_items': 100,
    'global_voice': "en-US-AriaNeural",
    'last_arxiv_input': None
}

for var, default in SESSION_VARS.items():
    if var not in st.session_state:
        st.session_state[var] = default

def create_voice_component():
    mycomponent = components.declare_component(
        "mycomponent",
        path="mycomponent"
    )
    return mycomponent
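# Note: components.declare_component(path=...) serves the component's frontend from that
# directory, so a "mycomponent" folder with an index.html entry point is assumed to be
# shipped alongside this script.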

def clean_for_speech(text: str) -> str:
    text = text.replace("\n", " ")
    text = text.replace("</s>", " ")
    text = text.replace("#", "")
    text = re.sub(r"\(https?:\/\/[^\)]+\)", "", text)
    text = re.sub(r"\s+", " ", text).strip()
    return text
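# e.g. clean_for_speech("See (https://example.com)\n# Heading") -> "See Heading"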

async def edge_tts_generate_audio(text, voice="en-US-AriaNeural"):
    text = clean_for_speech(text)
    if not text.strip():
        return None
    communicate = edge_tts.Communicate(text, voice)
    out_fn = f"speech_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}.mp3"
    await communicate.save(out_fn)
    return out_fn


def speak_with_edge_tts(text, voice="en-US-AriaNeural"):
    return asyncio.run(edge_tts_generate_audio(text, voice))
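# Note: asyncio.run() raises RuntimeError if an event loop is already running in the
# current thread. That is not normally the case in a Streamlit script run, but a more
# defensive variant (a sketch only, not used above) could manage its own loop:
#
# def speak_with_edge_tts_on_own_loop(text, voice="en-US-AriaNeural"):
#     loop = asyncio.new_event_loop()
#     try:
#         return loop.run_until_complete(edge_tts_generate_audio(text, voice))
#     finally:
#         loop.close()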

def play_and_download_audio(file_path):
    if file_path and os.path.exists(file_path):
        st.audio(file_path)
        with open(file_path, "rb") as f:
            audio_b64 = base64.b64encode(f.read()).decode()
        fname = os.path.basename(file_path)
        dl_link = f'<a href="data:audio/mpeg;base64,{audio_b64}" download="{fname}">Download {fname}</a>'
        st.markdown(dl_link, unsafe_allow_html=True)

def generate_filename(prefix, text):
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    safe_text = re.sub(r'[^\w\s-]', '', text[:50]).strip().lower()
    safe_text = re.sub(r'[-\s]+', '-', safe_text)
    return f"{prefix}_{timestamp}_{safe_text}.md"

def save_input_as_md(user_name, text, prefix="input"):
    if not text.strip():
        return
    fn = generate_filename(prefix, text)
    full_path = os.path.join(SAVED_INPUTS_DIR, fn)
    with open(full_path, 'w', encoding='utf-8') as f:
        f.write(f"# User: {user_name}\n")
        f.write(f"**Timestamp:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
        f.write(text)
    return full_path


def save_response_as_md(user_name, text, prefix="response"):
    # Responses use the same markdown layout as inputs; only the filename prefix differs.
    return save_input_as_md(user_name, text, prefix=prefix)

def list_saved_inputs():
    files = sorted(glob.glob(os.path.join(SAVED_INPUTS_DIR, "*.md")))
    return files


def parse_md_file(fpath):
    user_line = ""
    ts_line = ""
    content_lines = []
    with open(fpath, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    for line in lines:
        if line.startswith("# User:"):
            user_line = line.replace("# User:", "").strip()
        elif line.startswith("**Timestamp:**"):
            ts_line = line.replace("**Timestamp:**", "").strip()
        else:
            content_lines.append(line.strip())
    content = "\n".join(content_lines).strip()
    return user_line, ts_line, content

def arxiv_search(query, max_results=3):
    base_url = "http://export.arxiv.org/api/query"
    params = {
        'search_query': query.replace(' ', '+'),
        'start': 0,
        'max_results': max_results
    }
    response = requests.get(base_url, params=params, timeout=30)
    if response.status_code == 200:
        root = ET.fromstring(response.text)
        ns = {"a": "http://www.w3.org/2005/Atom"}
        entries = root.findall('a:entry', ns)
        results = []
        for entry in entries:
            title = entry.find('a:title', ns).text.strip()
            summary = entry.find('a:summary', ns).text.strip()
            # Only append an ellipsis when the abstract is actually truncated
            summary_short = summary[:300] + ("..." if len(summary) > 300 else "")
            results.append((title, summary_short))
        return results
    return []

def summarize_arxiv_results(results):
    lines = []
    for i, (title, summary) in enumerate(results, 1):
        lines.append(f"Result {i}: {title}\n{summary}\n")
    return "\n\n".join(lines)

def simple_dataset_search(query, df):
    if df.empty or not query.strip():
        return pd.DataFrame()
    query_terms = query.lower().split()
    matches = []
    for idx, row in df.iterrows():
        text_parts = []
        for col in df.columns:
            val = row[col]
            if isinstance(val, str):
                text_parts.append(val.lower())
            elif isinstance(val, (int, float)):
                text_parts.append(str(val))
        full_text = " ".join(text_parts)
        if any(qt in full_text for qt in query_terms):
            matches.append(row)
    if matches:
        return pd.DataFrame(matches)
    return pd.DataFrame()
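# Note: matching is OR-style substring search -- a row is returned if ANY lowercased
# query term appears anywhere in its concatenated string/numeric values.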

def load_dataset_page(dataset_id, token, page, rows_per_page):
    try:
        start_idx = page * rows_per_page
        end_idx = start_idx + rows_per_page
        # Slice the train split directly so only one page of rows is materialized
        dataset = load_dataset(
            dataset_id,
            token=token,
            streaming=False,
            split=f'train[{start_idx}:{end_idx}]'
        )
        return pd.DataFrame(dataset)
    except Exception:
        return pd.DataFrame()

class SimpleDatasetSearcher:
    def __init__(self, dataset_id="tomg-group-umd/cinepile"):
        self.dataset_id = dataset_id
        self.token = os.environ.get('DATASET_KEY')

    def load_page(self, page=0):
        return load_dataset_page(self.dataset_id, self.token, page, ROWS_PER_PAGE)
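# Usage sketch: page 0 maps to rows [0, ROWS_PER_PAGE) of the dataset's train split.
# The DATASET_KEY environment variable is assumed to hold a Hugging Face token with
# read access to the dataset.
#
# searcher = SimpleDatasetSearcher()
# first_page = searcher.load_page(0)   # DataFrame with up to ROWS_PER_PAGE rows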

def concatenate_mp3(files, output_file):
    # Naive binary concatenation of MP3 files
    with open(output_file, 'wb') as outfile:
        for f in files:
            with open(f, 'rb') as infile:
                outfile.write(infile.read())
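# Note: naive byte concatenation generally only plays cleanly when every clip shares the
# same codec parameters, which holds here because all clips come from edge_tts with the
# same settings. A more robust sketch would re-encode with pydub (an extra dependency,
# not required by this app, and it needs ffmpeg installed):
#
# from pydub import AudioSegment
#
# def concatenate_mp3_reencoded(files, output_file):
#     combined = AudioSegment.empty()
#     for f in files:
#         combined += AudioSegment.from_mp3(f)
#     combined.export(output_file, format="mp3")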

def main():
    st.title("🎙️ Voice Chat & Search")

    # Sidebar
    with st.sidebar:
        # Editable user name; default to the user chosen at session start
        st.session_state['user_name'] = st.selectbox(
            "Current User:", USER_NAMES,
            index=USER_NAMES.index(st.session_state['user_name'])
        )
        st.session_state['max_items'] = st.number_input(
            "Max Items per search iteration:",
            min_value=1, max_value=1000,
            value=st.session_state['max_items']
        )

        st.subheader("📂 Saved Inputs & Responses")
        saved_files = list_saved_inputs()
        for fpath in saved_files:
            user, ts, content = parse_md_file(fpath)
            fname = os.path.basename(fpath)
            st.write(f"- {fname} (User: {user})")

    # Create voice component for input
    voice_component = create_voice_component()
    voice_val = voice_component(my_input_value="Start speaking...")

    # Tabs
    tab1, tab2, tab3, tab4 = st.tabs(["🗣️ Voice Chat History", "🔍 ArXiv Search", "📊 Dataset Search", "⚙️ Settings"])

    # ------------------ Voice Chat History -------------------------
    with tab1:
        st.subheader("Voice Chat History")
        files = list_saved_inputs()
        conversation = []
        for fpath in files:
            user, ts, content = parse_md_file(fpath)
            conversation.append((user, ts, content, fpath))

        # Enumerate to ensure unique keys
        for i, (user, ts, content, fpath) in enumerate(reversed(conversation), start=1):
            with st.expander(f"{ts} - {user}", expanded=False):
                st.write(content)
                # Make button key unique by including i
                if st.button(f"🔊 Read Aloud {ts}-{user}", key=f"read_{i}_{fpath}"):
                    voice = USER_VOICES.get(user, "en-US-AriaNeural")
                    audio_file = speak_with_edge_tts(content, voice=voice)
                    if audio_file:
                        play_and_download_audio(audio_file)

        # Read entire conversation
        if st.button("🔊 Read Conversation", key="read_conversation_all"):
            # `conversation` keeps the sorted file order from list_saved_inputs();
            # the reversed() above only changes the on-screen display order.
            mp3_files = []
            for user, ts, content, fpath in conversation:
                voice = USER_VOICES.get(user, "en-US-AriaNeural")
                audio_file = speak_with_edge_tts(content, voice=voice)
                if audio_file:
                    mp3_files.append(audio_file)
                    st.write(f"**{user} ({ts}):**")
                    play_and_download_audio(audio_file)
            if mp3_files:
                combined_file = f"full_conversation_{datetime.now().strftime('%Y%m%d_%H%M%S')}.mp3"
                concatenate_mp3(mp3_files, combined_file)
                st.write("**Full Conversation Audio:**")
                play_and_download_audio(combined_file)

    # ------------------ ArXiv Search -------------------------
    with tab2:
        st.subheader("ArXiv Search")
        edited_input = st.text_area(
            "Enter or Edit Search Query:",
            value=(voice_val.strip() if voice_val else ""),
            height=100
        )
        autorun = st.checkbox("⚡ Auto-Run", value=True)
        run_arxiv = st.button("🔍 ArXiv Search", key="run_arxiv_button")

        input_changed = (edited_input != st.session_state.get('old_val'))
        should_run_arxiv = False
        if autorun and input_changed and edited_input.strip():
            should_run_arxiv = True
        if run_arxiv and edited_input.strip():
            should_run_arxiv = True

        if should_run_arxiv and st.session_state['last_arxiv_input'] != edited_input:
            st.session_state['old_val'] = edited_input
            st.session_state['last_arxiv_input'] = edited_input
            save_input_as_md(st.session_state['user_name'], edited_input, prefix="input")
            with st.spinner("Searching ArXiv..."):
                results = arxiv_search(edited_input)
            if results:
                summary = summarize_arxiv_results(results)
                save_response_as_md(st.session_state['user_name'], summary, prefix="response")
                st.write(summary)
                # Play summary aloud
                voice = USER_VOICES.get(st.session_state['user_name'], "en-US-AriaNeural")
                audio_file = speak_with_edge_tts(summary, voice=voice)
                if audio_file:
                    play_and_download_audio(audio_file)
            else:
                st.warning("No results found on ArXiv.")

    # ------------------ Dataset Search -------------------------
    with tab3:
        st.subheader("Dataset Search")
        ds_searcher = SimpleDatasetSearcher()
        query = st.text_input("Enter dataset search query:")
        run_ds_search = st.button("Search Dataset", key="ds_search_button")
        num_results = st.slider("Max results:", 1, 100, 20, key="ds_max_results")

        if run_ds_search and query.strip():
            with st.spinner("Searching dataset..."):
                df = ds_searcher.load_page(0)
                results = simple_dataset_search(query, df)
            if not results.empty:
                st.write(f"Found {len(results)} results:")
                shown = 0
                for i, (_, row) in enumerate(results.iterrows(), 1):
                    if shown >= num_results:
                        break
                    with st.expander(f"Result {i}", expanded=(i == 1)):
                        for k, v in row.items():
                            st.write(f"**{k}:** {v}")
                    shown += 1
            else:
                st.warning("No matching results found.")

    # ------------------ Settings Tab -------------------------
    with tab4:
        st.subheader("Settings")
        if st.button("🗑️ Clear Search History", key="clear_history"):
            # Delete all files
            for fpath in list_saved_inputs():
                os.remove(fpath)
            st.session_state['search_history'] = []
            st.success("Search history cleared for everyone!")
            st.rerun()

if __name__ == "__main__":
    main()