awacke1's picture
Update app.py
88675e3 verified
raw
history blame
14.5 kB
import streamlit as st
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import torch
import json
import os
import glob
import random
from pathlib import Path
from datetime import datetime
import edge_tts
import asyncio
import requests
import streamlit.components.v1 as components
import base64
import re
from xml.etree import ElementTree as ET
from datasets import load_dataset
# -------------------- Configuration & Constants --------------------
# Exactly 11 user names and 11 voices: entry i of USER_NAMES pairs with entry i
# of ENGLISH_VOICES.
USER_NAMES = [
    "Aria", "Guy", "Sonia", "Tony", "Jenny", "Davis",
    "Libby", "Clara", "Liam", "Natasha", "William",
]
ENGLISH_VOICES = [
    "en-US-AriaNeural",
    "en-US-GuyNeural",
    "en-GB-SoniaNeural",
    "en-GB-TonyNeural",
    "en-US-JennyNeural",
    "en-US-DavisNeural",
    "en-GB-LibbyNeural",
    "en-CA-ClaraNeural",
    "en-CA-LiamNeural",
    "en-AU-NatashaNeural",
    "en-AU-WilliamNeural",
]
# user name -> edge-tts voice, paired by position.
USER_VOICES = {name: voice for name, voice in zip(USER_NAMES, ENGLISH_VOICES)}

ROWS_PER_PAGE = 100
MIN_SEARCH_SCORE = 0.3
EXACT_MATCH_BOOST = 2.0

# Folder holding the Markdown transcripts; created eagerly at import time.
SAVED_INPUTS_DIR = "saved_inputs"
os.makedirs(SAVED_INPUTS_DIR, exist_ok=True)

# Default for every st.session_state key the app reads. Keys that are already
# present keep their value; only missing ones are seeded below.
SESSION_VARS = {
    'search_history': [],
    'last_voice_input': "",
    'transcript_history': [],
    'should_rerun': False,
    'search_columns': [],
    'initial_search_done': False,
    'arxiv_last_query': "",
    'dataset_loaded': False,
    'current_page': 0,
    'data_cache': None,
    'dataset_info': None,
    'nps_submitted': False,
    'nps_last_shown': None,
    'old_val': None,
    'voice_text': None,
    'user_name': random.choice(USER_NAMES),
    'max_items': 100,
    'global_voice': "en-US-AriaNeural",
    'last_arxiv_input': None,  # To avoid double-running ArXiv search
}

for var, default in SESSION_VARS.items():
    if var in st.session_state:
        continue
    st.session_state[var] = default
@st.cache_resource
def get_model():
    """Return the 'all-MiniLM-L6-v2' SentenceTransformer, memoized via st.cache_resource."""
    model_name = 'all-MiniLM-L6-v2'
    return SentenceTransformer(model_name)
def create_voice_component():
    """Declare and return the custom ``mycomponent`` Streamlit component.

    The component's frontend is loaded from the local ``mycomponent`` directory.
    """
    return components.declare_component("mycomponent", path="mycomponent")
def clean_for_speech(text: str) -> str:
    """Normalize *text* before handing it to the TTS engine.

    Newlines and ``</s>`` markers become spaces, ``#`` characters are dropped,
    parenthesized http(s) links such as ``(https://...)`` are removed, and all
    whitespace runs are collapsed to single spaces with the ends trimmed.
    """
    # Order matters: the replacements run before link removal, as before.
    for needle, substitute in (("\n", " "), ("</s>", " "), ("#", "")):
        text = text.replace(needle, substitute)
    without_links = re.sub(r"\(https?://[^)]+\)", "", text)
    return " ".join(without_links.split())
async def edge_tts_generate_audio(text, voice="en-US-AriaNeural"):
    """Synthesize *text* with edge-tts into a timestamped MP3 in the working directory.

    Returns the MP3 filename, or None when nothing speakable remains after
    clean_for_speech().
    """
    spoken = clean_for_speech(text)
    if not spoken.strip():
        return None
    tts = edge_tts.Communicate(spoken, voice)
    # Microsecond precision keeps filenames distinct between rapid calls.
    mp3_name = "speech_{}.mp3".format(datetime.now().strftime('%Y%m%d_%H%M%S_%f'))
    await tts.save(mp3_name)
    return mp3_name
def speak_with_edge_tts(text, voice="en-US-AriaNeural"):
    """Blocking wrapper around edge_tts_generate_audio; returns its MP3 path or None."""
    pending = edge_tts_generate_audio(text, voice)
    return asyncio.run(pending)
def play_and_download_audio(file_path):
    """Render an audio player plus a base64 data-URI download link for *file_path*.

    Does nothing when *file_path* is falsy or does not exist on disk.
    """
    if not file_path or not os.path.exists(file_path):
        return
    st.audio(file_path)
    # Read inside a context manager so the handle is closed right away; the
    # previous inline open(...).read() left it open until garbage collection.
    with open(file_path, "rb") as audio:
        encoded = base64.b64encode(audio.read()).decode()
    name = os.path.basename(file_path)
    dl_link = f'<a href="data:audio/mpeg;base64,{encoded}" download="{name}">Download {name}</a>'
    st.markdown(dl_link, unsafe_allow_html=True)
def generate_filename(prefix, text):
    """Build ``<prefix>_<YYYYmmdd_HHMMSS>_<slug>.md``.

    The slug is the first 50 characters of *text* with punctuation removed,
    lower-cased, and runs of whitespace/hyphens collapsed to single hyphens.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    slug = re.sub(r'[^\w\s-]', '', text[:50])
    slug = re.sub(r'[-\s]+', '-', slug.strip().lower())
    return "{}_{}_{}.md".format(prefix, stamp, slug)
def save_input_as_md(user_name, text, prefix="input"):
    """Persist *text* as a new Markdown file in SAVED_INPUTS_DIR.

    The file starts with a ``# User:`` line and a ``**Timestamp:**`` line,
    followed by a blank line and the raw text. Returns the file path, or None
    (writing nothing) when *text* is blank.
    """
    if text.strip():
        target = os.path.join(SAVED_INPUTS_DIR, generate_filename(prefix, text))
        with open(target, 'w', encoding='utf-8') as handle:
            handle.write(
                f"# User: {user_name}\n"
                f"**Timestamp:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
                f"{text}"
            )
        return target
    return None
def save_response_as_md(user_name, text, prefix="response"):
    """Persist a response as Markdown; returns the file path, or None for blank text.

    This is save_input_as_md with a different default prefix. It delegates so
    the on-disk format (header lines read back by parse_md_file) is defined in
    exactly one place instead of two verbatim copies.
    """
    return save_input_as_md(user_name, text, prefix=prefix)
# Matches the ``YYYYmmdd_HHMMSS`` stamp generate_filename places after the prefix.
_SAVED_FILE_STAMP_RE = re.compile(r"_(\d{8}_\d{6})_")


def list_saved_inputs(directory=None):
    """Return the saved ``*.md`` files in chronological order.

    Files are ordered by the timestamp embedded in their name, then by the
    filename itself. Sorting by filename alone put every ``input_*`` file before
    every ``response_*`` file, so the chat history was not chronological. For
    files stamped in the same second, ``input_`` sorts before ``response_``.
    Files without a recognizable stamp sort first.

    Args:
        directory: folder to scan; defaults to SAVED_INPUTS_DIR.
    """
    folder = SAVED_INPUTS_DIR if directory is None else directory

    def sort_key(path):
        name = os.path.basename(path)
        match = _SAVED_FILE_STAMP_RE.search(name)
        return (match.group(1) if match else "", name)

    return sorted(glob.glob(os.path.join(folder, "*.md")), key=sort_key)
def parse_md_file(fpath):
    """Parse a transcript written by save_input_as_md / save_response_as_md.

    Returns ``(user, timestamp, content)``; missing header fields come back
    as "". Only the FIRST ``# User:`` line and the FIRST ``**Timestamp:**`` line
    are header fields. Later lines with those prefixes belong to the content
    instead of silently overwriting the header. Content lines keep their leading
    indentation (Markdown code blocks and nested lists depend on it); only
    trailing whitespace is removed, and the whole content is trimmed at the ends.
    """
    user_prefix = "# User:"
    ts_prefix = "**Timestamp:**"
    user_line = None
    ts_line = None
    content_lines = []
    with open(fpath, 'r', encoding='utf-8') as f:
        for raw in f:
            line = raw.rstrip()
            if user_line is None and line.startswith(user_prefix):
                user_line = line[len(user_prefix):].strip()
            elif ts_line is None and line.startswith(ts_prefix):
                ts_line = line[len(ts_prefix):].strip()
            else:
                content_lines.append(line)
    content = "\n".join(content_lines).strip()
    return user_line or "", ts_line or "", content
def arxiv_search(query, max_results=3):
    """Query the arXiv Atom API and return up to *max_results* ``(title, summary)`` pairs.

    Whitespace inside titles and summaries (the feed hard-wraps them) is
    collapsed. Summaries longer than 300 characters are cut and suffixed with
    "...". Returns [] on a non-200 response, a network error or an unparseable
    feed, so callers treat every failure like "no results".
    """
    base_url = "http://export.arxiv.org/api/query"
    params = {
        # Pass the raw query: requests form-encodes spaces as '+'. Replacing
        # spaces with '+' beforehand made requests send '%2B' (a literal plus).
        'search_query': query,
        'start': 0,
        'max_results': max_results,
    }
    try:
        response = requests.get(base_url, params=params, timeout=30)
    except requests.RequestException:
        return []
    if response.status_code != 200:
        return []
    try:
        # Parse the raw bytes so the XML declaration decides the encoding.
        root = ET.fromstring(response.content)
    except ET.ParseError:
        return []
    ns = {"a": "http://www.w3.org/2005/Atom"}
    results = []
    for entry in root.findall('a:entry', ns):
        # findtext tolerates a missing/empty element instead of raising on None.
        title = " ".join(entry.findtext('a:title', default='', namespaces=ns).split())
        summary = " ".join(entry.findtext('a:summary', default='', namespaces=ns).split())
        summary_short = summary if len(summary) <= 300 else summary[:300] + "..."
        results.append((title, summary_short))
    return results
def summarize_arxiv_results(results):
    """Format ``(title, summary)`` pairs as numbered "Result N" paragraphs.

    Each entry is ``"Result N: <title>\\n<summary>\\n"``; entries are separated
    by a blank line pair. An empty input gives "".
    """
    return "\n\n".join(
        f"Result {number}: {title}\n{summary}\n"
        for number, (title, summary) in enumerate(results, start=1)
    )
# Simple dataset search: text-based substring search
def simple_dataset_search(query, df):
    """Return the rows of *df* whose text contains ANY whitespace-separated query term.

    Matching is case-insensitive substring search over string cells and the
    ``str()`` of numeric cells (Python or numpy numbers). Returns an empty
    DataFrame when *df* is empty, the query is blank, or nothing matches;
    otherwise the matching rows with their original index and dtypes.
    """
    if df.empty or not query.strip():
        return pd.DataFrame()
    query_terms = query.lower().split()

    def _row_text(values):
        # Join the searchable cells of one row into a single lower-case string.
        parts = []
        for val in values:
            if isinstance(val, str):
                parts.append(val.lower())
            elif isinstance(val, (int, float, np.number)):
                # np.number covers numpy ints (np.int64 is not an int subclass),
                # which were previously skipped and therefore unsearchable.
                parts.append(str(val))
        return " ".join(parts)

    mask = []
    for row in df.itertuples(index=False, name=None):
        full_text = _row_text(row)
        mask.append(any(term in full_text for term in query_terms))
    if not any(mask):
        return pd.DataFrame()
    # Boolean selection keeps the original dtypes, unlike rebuilding from row Series.
    return df.loc[mask]
@st.cache_data
def _load_dataset_page_cached(dataset_id, token, page, rows_per_page):
    """Load rows ``[page*rows_per_page, (page+1)*rows_per_page)`` of the train split.

    Raises whatever ``load_dataset`` raises. Exceptions are not stored by
    st.cache_data, so a failed download is retried on the next call.
    """
    start_idx = page * rows_per_page
    end_idx = start_idx + rows_per_page
    dataset = load_dataset(
        dataset_id,
        token=token,
        streaming=False,
        split=f'train[{start_idx}:{end_idx}]'
    )
    return pd.DataFrame(dataset)


def load_dataset_page(dataset_id, token, page, rows_per_page):
    """Return one page of *dataset_id* as a DataFrame, or an empty DataFrame on failure.

    The old version swallowed every error (bare ``except:``) inside the
    cached function. An empty frame from a transient failure was therefore
    cached for good and shown as "no results". Failures are now shown as a
    warning and are not cached.
    """
    try:
        return _load_dataset_page_cached(dataset_id, token, page, rows_per_page)
    except Exception as err:  # UI boundary: report instead of crashing the app
        st.warning(f"Could not load dataset page {page}: {err}")
        return pd.DataFrame()
class SimpleDatasetSearcher:
    """Fetch fixed-size pages of one dataset through load_dataset_page."""

    def __init__(self, dataset_id="tomg-group-umd/cinepile"):
        self.dataset_id = dataset_id
        # Credential forwarded as load_dataset's ``token``; None when DATASET_KEY is unset.
        self.token = os.getenv('DATASET_KEY')

    def load_page(self, page=0):
        """Return page *page* (ROWS_PER_PAGE rows per page) as a DataFrame."""
        dataset_id, token = self.dataset_id, self.token
        return load_dataset_page(dataset_id, token, page, ROWS_PER_PAGE)
def concatenate_mp3(files, output_file):
    """Write the bytes of every file in *files*, in order, into *output_file*.

    This is a naive byte-level join: no re-encoding or header fix-up is done.
    """
    with open(output_file, 'wb') as merged:
        for part in files:
            with open(part, 'rb') as source:
                for chunk in iter(lambda: source.read(1 << 16), b""):
                    merged.write(chunk)
def _speak_and_play(text, user):
    """Synthesize *text* in *user*'s voice (Aria if unknown) and render the player."""
    audio_file = speak_with_edge_tts(text, voice=USER_VOICES.get(user, "en-US-AriaNeural"))
    if audio_file:
        play_and_download_audio(audio_file)
    return audio_file


def _render_sidebar():
    """Sidebar: current-user picker, per-search item limit, and the saved-file list."""
    with st.sidebar:
        # Start the picker on the session's user (initially a random choice);
        # the old hard-coded index=0 silently discarded that choice.
        current_user = st.session_state.get('user_name')
        user_index = USER_NAMES.index(current_user) if current_user in USER_NAMES else 0
        st.session_state['user_name'] = st.selectbox("Current User:", USER_NAMES, index=user_index)
        st.session_state['max_items'] = st.number_input(
            "Max Items per search iteration:",
            min_value=1,
            max_value=1000,
            value=st.session_state['max_items'],
        )
        st.subheader("📝 Saved Inputs & Responses")
        for fpath in list_saved_inputs():
            user, _ts, _content = parse_md_file(fpath)
            st.write(f"- {os.path.basename(fpath)} (User: {user})")


def _render_history_tab():
    """Show saved entries newest-first, with per-entry and whole-conversation read-aloud."""
    st.subheader("Voice Chat History")
    # Keep each entry's path so widget keys are unique. The old code reused the
    # leftover loop variable, giving every button the same key, which Streamlit
    # rejects as soon as two files exist.
    conversation = [(fpath, *parse_md_file(fpath)) for fpath in list_saved_inputs()]
    for fpath, user, ts, content in reversed(conversation):
        with st.expander(f"{ts} - {user}", expanded=False):
            st.write(content)
            if st.button(f"🔊 Read Aloud {ts}-{user}", key=f"read_{fpath}"):
                _speak_and_play(content, user)
    if st.button("📜 Read Conversation"):
        # Entries are voiced in the order list_saved_inputs() returned them.
        mp3_files = []
        for _fpath, user, ts, content in conversation:
            audio_file = speak_with_edge_tts(content, voice=USER_VOICES.get(user, "en-US-AriaNeural"))
            if audio_file:
                mp3_files.append(audio_file)
                st.write(f"**{user} ({ts}):**")
                play_and_download_audio(audio_file)
        if mp3_files:
            combined_file = f"full_conversation_{datetime.now().strftime('%Y%m%d_%H%M%S')}.mp3"
            concatenate_mp3(mp3_files, combined_file)
            st.write("**Full Conversation Audio:**")
            play_and_download_audio(combined_file)


def _render_arxiv_tab(voice_val):
    """Search arXiv for the (voice-prefilled) query, save input/summary, read it aloud."""
    st.subheader("ArXiv Search")
    edited_input = st.text_area(
        "Enter or Edit Search Query:",
        value=(voice_val.strip() if voice_val else ""),
        height=100,
    )
    autorun = st.checkbox("⚡ Auto-Run", value=True)
    run_arxiv = st.button("🔍 ArXiv Search")
    input_changed = edited_input != st.session_state.get('old_val')
    # Run when auto-run sees a changed query or the button is pressed; never for blank input.
    should_run_arxiv = bool(edited_input.strip()) and (run_arxiv or (autorun and input_changed))
    if not should_run_arxiv:
        return
    st.session_state['old_val'] = edited_input
    # Avoid double-running the query that was just searched.
    if st.session_state['last_arxiv_input'] == edited_input:
        return
    st.session_state['last_arxiv_input'] = edited_input
    user_name = st.session_state['user_name']
    save_input_as_md(user_name, edited_input, prefix="input")
    with st.spinner("Searching ArXiv..."):
        results = arxiv_search(edited_input)
    if not results:
        st.warning("No results found on ArXiv.")
        return
    summary = summarize_arxiv_results(results)
    save_response_as_md(user_name, summary, prefix="response")
    st.write(summary)
    _speak_and_play(summary, user_name)


def _render_dataset_tab():
    """Substring-search the first dataset page and list up to the chosen number of rows."""
    st.subheader("Dataset Search")
    ds_searcher = SimpleDatasetSearcher()
    query = st.text_input("Enter dataset search query:")
    run_ds_search = st.button("Search Dataset")
    num_results = st.slider("Max results:", 1, 100, 20)
    if not (run_ds_search and query.strip()):
        return
    with st.spinner("Searching dataset..."):
        # For simplicity, only the first page is searched.
        df = ds_searcher.load_page(0)
        results = simple_dataset_search(query, df)
    if results.empty:
        st.warning("No matching results found.")
        return
    st.write(f"Found {len(results)} results:")
    for i, (_, row) in enumerate(results.head(num_results).iterrows(), 1):
        with st.expander(f"Result {i}", expanded=(i == 1)):
            for k, v in row.items():
                st.write(f"**{k}:** {v}")


def _render_settings_tab():
    """Settings: a button that deletes every saved file (for all users) and reruns."""
    st.subheader("Settings")
    if st.button("🗑️ Clear Search History"):
        for fpath in list_saved_inputs():
            try:
                os.remove(fpath)
            except FileNotFoundError:
                # Another session may have deleted it between listing and removal.
                pass
        st.session_state['search_history'] = []
        st.success("Search history cleared for everyone!")
        # st.experimental_rerun was removed in newer Streamlit; prefer st.rerun.
        rerun = getattr(st, "rerun", None) or st.experimental_rerun
        rerun()


def main():
    """Render the app: title, sidebar, voice input component and the four tabs."""
    st.title("🎙️ Voice Chat & Search")
    _render_sidebar()
    voice_component = create_voice_component()
    voice_val = voice_component(my_input_value="Start speaking...")
    tab1, tab2, tab3, tab4 = st.tabs(
        ["🗣️ Voice Chat History", "📚 ArXiv Search", "📊 Dataset Search", "⚙️ Settings"]
    )
    with tab1:
        _render_history_tab()
    with tab2:
        _render_arxiv_tab(voice_val)
    with tab3:
        _render_dataset_tab()
    with tab4:
        _render_settings_tab()


if __name__ == "__main__":
    main()