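"""Voice-driven Streamlit app: browser speech input, edge-tts speech output,
arXiv search over the public Atom API, and paged semantic search across a
Hugging Face dataset."""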
import streamlit as st
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import os
import glob
import random
from datetime import datetime
import edge_tts
import asyncio
import requests
import streamlit.components.v1 as components
from xml.etree import ElementTree as ET
from datasets import load_dataset
import base64
import re
# -------------------- Configuration & Constants --------------------
USER_NAMES = [
    "Alex", "Jordan", "Taylor", "Morgan", "Rowan", "Avery", "Riley", "Quinn",
    "Casey", "Jesse", "Reese", "Skyler", "Ellis", "Devon", "Aubrey", "Kendall",
    "Parker", "Dakota", "Sage", "Finley"
]

ENGLISH_VOICES = [
    "en-US-AriaNeural", "en-US-GuyNeural", "en-GB-SoniaNeural", "en-GB-TonyNeural",
    "en-US-JennyNeural", "en-US-DavisNeural", "en-GB-LibbyNeural", "en-CA-ClaraNeural",
    "en-CA-LiamNeural", "en-AU-NatashaNeural", "en-AU-WilliamNeural"
]

ROWS_PER_PAGE = 100
MIN_SEARCH_SCORE = 0.3
EXACT_MATCH_BOOST = 2.0

SAVED_INPUTS_DIR = "saved_inputs"
os.makedirs(SAVED_INPUTS_DIR, exist_ok=True)
SESSION_VARS = {
    'search_history': [],
    'last_voice_input': "",
    'transcript_history': [],
    'should_rerun': False,
    'search_columns': [],
    'initial_search_done': False,
    'tts_voice': "en-US-AriaNeural",
    'arxiv_last_query': "",
    'dataset_loaded': False,
    'current_page': 0,
    'data_cache': None,
    'dataset_info': None,
    'nps_submitted': False,
    'nps_last_shown': None,
    'old_val': None,
    'voice_text': None,
    'user_name': random.choice(USER_NAMES),
    'max_items': 100,
    'global_voice': "en-US-AriaNeural"  # Default global voice
}

# Seed any missing session-state keys with their defaults.
for var, default in SESSION_VARS.items():
    if var not in st.session_state:
        st.session_state[var] = default
@st.cache_resource
def get_model():
    # Cache the sentence-transformer so it loads once per server process,
    # not on every Streamlit rerun.
    return SentenceTransformer('all-MiniLM-L6-v2')
def create_voice_component():
    # Custom voice-input component served from the local ./mycomponent directory.
    mycomponent = components.declare_component(
        "mycomponent",
        path="mycomponent"
    )
    return mycomponent
def clean_for_speech(text: str) -> str:
    text = text.replace("\n", " ")
    text = text.replace("</s>", " ")
    text = text.replace("#", "")
    text = re.sub(r"\(https?:\/\/[^\)]+\)", "", text)
    text = re.sub(r"\s+", " ", text).strip()
    return text
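# Example (illustrative):
#   clean_for_speech("Intro #1\nSee (https://arxiv.org/abs/1234.5678) here")
#   -> "Intro 1 See here"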
async def edge_tts_generate_audio(text, voice="en-US-AriaNeural", rate=0, pitch=0):
    text = clean_for_speech(text)
    if not text.strip():
        return None
    rate_str = f"{rate:+d}%"
    pitch_str = f"{pitch:+d}Hz"
    communicate = edge_tts.Communicate(text, voice, rate=rate_str, pitch=pitch_str)
    out_fn = f"speech_{datetime.now().strftime('%Y%m%d_%H%M%S')}.mp3"
    await communicate.save(out_fn)
    return out_fn

def speak_with_edge_tts(text, voice="en-US-AriaNeural"):
    return asyncio.run(edge_tts_generate_audio(text, voice, 0, 0))
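# Note: edge-tts synthesizes speech via Microsoft's online Edge TTS service,
# so audio generation requires network access.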
def play_and_download_audio(file_path):
    if file_path and os.path.exists(file_path):
        st.audio(file_path)
        with open(file_path, "rb") as f:
            audio_b64 = base64.b64encode(f.read()).decode()
        fname = os.path.basename(file_path)
        dl_link = f'<a href="data:audio/mpeg;base64,{audio_b64}" download="{fname}">Download {fname}</a>'
        st.markdown(dl_link, unsafe_allow_html=True)
def generate_filename(prefix, text):
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    safe_text = re.sub(r'[^\w\s-]', '', text[:50]).strip().lower()
    safe_text = re.sub(r'[-\s]+', '-', safe_text)
    return f"{prefix}_{timestamp}_{safe_text}.md"
def save_input_as_md(user_name, text, prefix="input"):
    if not text.strip():
        return None
    fn = generate_filename(prefix, text)
    full_path = os.path.join(SAVED_INPUTS_DIR, fn)
    with open(full_path, 'w', encoding='utf-8') as f:
        f.write(f"# User: {user_name}\n")
        f.write(f"**Timestamp:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
        f.write(text)
    return full_path

def save_response_as_md(user_name, text, prefix="response"):
    # Responses use the same markdown layout as inputs; only the prefix differs.
    return save_input_as_md(user_name, text, prefix)
def list_saved_inputs():
    files = sorted(glob.glob(os.path.join(SAVED_INPUTS_DIR, "*.md")))
    return files
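# Saved markdown files follow the layout written by save_input_as_md, e.g.:
#   # User: Alex
#   **Timestamp:** 2024-01-01 12:00:00
#
#   <free-form input or response text>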
def parse_md_file(fpath):
    # Extract user, timestamp, and body text from a saved markdown file
    user_line = ""
    ts_line = ""
    content_lines = []
    with open(fpath, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    for line in lines:
        if line.startswith("# User:"):
            user_line = line.replace("# User:", "").strip()
        elif line.startswith("**Timestamp:**"):
            ts_line = line.replace("**Timestamp:**", "").strip()
        else:
            content_lines.append(line.strip())
    content = "\n".join(content_lines).strip()
    return user_line, ts_line, content
def fetch_dataset_info(dataset_id, token):
    info_url = f"https://huggingface.co/api/datasets/{dataset_id}"
    # Pass the token so gated/private datasets can be queried too.
    headers = {"Authorization": f"Bearer {token}"} if token else {}
    try:
        response = requests.get(info_url, headers=headers, timeout=30)
        if response.status_code == 200:
            return response.json()
    except requests.RequestException:
        pass
    return None
def get_dataset_info(dataset_id, token):
    try:
        dataset = load_dataset(dataset_id, token=token, streaming=True)
        return dataset['train'].info
    except Exception:
        return None
def load_dataset_page(dataset_id, token, page, rows_per_page):
    try:
        start_idx = page * rows_per_page
        end_idx = start_idx + rows_per_page
        dataset = load_dataset(
            dataset_id,
            token=token,
            streaming=False,
            split=f'train[{start_idx}:{end_idx}]'
        )
        return pd.DataFrame(dataset)
    except Exception:
        return pd.DataFrame()
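# Note: `split='train[0:100]'` uses the `datasets` split-slicing syntax, so each
# call materializes only the requested slice of rows rather than the full split.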
class FastDatasetSearcher:
    def __init__(self, dataset_id="tomg-group-umd/cinepile"):
        self.dataset_id = dataset_id
        self.text_model = get_model()
        self.token = os.environ.get('DATASET_KEY')

    def load_page(self, page=0):
        return load_dataset_page(self.dataset_id, self.token, page, ROWS_PER_PAGE)
    def quick_search(self, query, df):
        if df.empty or not query.strip():
            return df
        try:
            # Only score columns whose values can be stringified meaningfully.
            searchable_cols = []
            if len(df) > 0:
                for col in df.columns:
                    sample_val = df[col].iloc[0]
                    if not isinstance(sample_val, (np.ndarray, bytes)):
                        searchable_cols.append(col)

            query_lower = query.lower()
            query_terms = set(query_lower.split())
            query_embedding = self.text_model.encode([query], show_progress_bar=False)[0]

            scores = []
            matched_any = []
            for _, row in df.iterrows():
                text_parts = []
                row_matched = False
                exact_match = False
                priority_fields = ['description', 'matched_text']
                other_fields = [col for col in searchable_cols if col not in priority_fields]

                # Priority fields are checked first so their text leads the document.
                for col in priority_fields + other_fields:
                    if col not in row:
                        continue
                    val = row[col]
                    if val is None:
                        continue
                    val_str = str(val).lower()
                    # Exact match: the whole query appears as a phrase.
                    if query_lower in val_str:
                        exact_match = True
                    # Partial match: any individual query term appears as a token.
                    if any(term in val_str.split() for term in query_terms):
                        row_matched = True
                    text_parts.append(str(val))

                text = ' '.join(text_parts)
                if text.strip():
                    text_tokens = set(text.lower().split())
                    matching_terms = query_terms.intersection(text_tokens)
                    keyword_score = len(matching_terms) / len(query_terms) if query_terms else 0.0
                    text_embedding = self.text_model.encode([text], show_progress_bar=False)[0]
                    semantic_score = float(cosine_similarity([query_embedding], [text_embedding])[0][0])
                    # Weighted blend of keyword overlap and embedding similarity.
                    combined_score = 0.7 * keyword_score + 0.3 * semantic_score
                    if exact_match:
                        combined_score *= EXACT_MATCH_BOOST
                    elif row_matched:
                        combined_score *= 1.2
                else:
                    combined_score = 0.0
                    row_matched = False

                scores.append(combined_score)
                matched_any.append(row_matched)

            results_df = df.copy()
            results_df['score'] = scores
            results_df['matched'] = matched_any
            filtered_df = results_df[
                (results_df['matched']) |
                (results_df['score'] > MIN_SEARCH_SCORE)
            ]
            return filtered_df.sort_values('score', ascending=False)
        except Exception:
            return df
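# Usage sketch (illustrative query; assumes DATASET_KEY grants access):
#   searcher = FastDatasetSearcher()
#   page_df = searcher.load_page(page=0)
#   hits = searcher.quick_search("storm at sea", page_df)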
def play_text(text):
    voice = st.session_state.get('global_voice', "en-US-AriaNeural")
    audio_file = speak_with_edge_tts(text, voice=voice)
    if audio_file:
        play_and_download_audio(audio_file)
def arxiv_search(query, max_results=3):
    # Minimal arXiv search against the public Atom API (for demonstration).
    # In production, prefer an arXiv client library with rate limiting.
    base_url = "http://export.arxiv.org/api/query"
    params = {
        'search_query': query,  # requests URL-encodes parameters itself
        'start': 0,
        'max_results': max_results
    }
    response = requests.get(base_url, params=params, timeout=30)
    if response.status_code == 200:
        root = ET.fromstring(response.text)
        ns = {"a": "http://www.w3.org/2005/Atom"}
        entries = root.findall('a:entry', ns)
        results = []
        for entry in entries:
            title = entry.find('a:title', ns).text.strip()
            summary = entry.find('a:summary', ns).text.strip()
            # Truncate long abstracts for display and TTS.
            summary_short = summary[:300] + ("..." if len(summary) > 300 else "")
            results.append((title, summary_short))
        return results
    return []
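# Example request produced above (illustrative):
#   GET http://export.arxiv.org/api/query?search_query=quantum+error+correction&start=0&max_results=3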
def summarize_arxiv_results(results):
    # Combine titles and short summaries into a speakable digest
    lines = []
    for i, (title, summary) in enumerate(results, 1):
        lines.append(f"Result {i}: {title}\n{summary}\n")
    return "\n\n".join(lines)
def main():
    st.title("🎙️ Voice Chat & Search")

    # Sidebar
    with st.sidebar:
        # Editable user name
        st.session_state['user_name'] = st.text_input("Current User:", value=st.session_state['user_name'])

        # Global voice selection (index keeps the choice stable across reruns)
        st.session_state['global_voice'] = st.selectbox(
            "Select Global Voice:",
            ENGLISH_VOICES,
            index=ENGLISH_VOICES.index(st.session_state['global_voice'])
        )

        st.session_state['max_items'] = st.number_input(
            "Max Items per search iteration:",
            min_value=1, max_value=1000,
            value=st.session_state['max_items']
        )

        st.subheader("📂 Saved Inputs & Responses")
        saved_files = list_saved_inputs()
        for fpath in saved_files:
            user, ts, content = parse_md_file(fpath)
            fname = os.path.basename(fpath)
            st.write(f"- {fname} (User: {user})")

    # Create voice component for input
    voice_component = create_voice_component()
    voice_val = voice_component(my_input_value="Start speaking...")

    # Tabs: Voice Chat History, ArXiv Search, Dataset Search, Settings
    tab1, tab2, tab3, tab4 = st.tabs(["🗣️ Voice Chat History", "🔍 ArXiv Search", "📊 Dataset Search", "⚙️ Settings"])
    # ------------------ Voice Chat History -------------------------
    with tab1:
        st.subheader("Voice Chat History")
        # List saved inputs and responses and allow playing them
        files = list_saved_inputs()
        for fpath in reversed(files):
            user, ts, content = parse_md_file(fpath)
            with st.expander(f"{ts} - {user}", expanded=False):
                st.write(content)
                if st.button("🔊 Read Aloud", key=f"read_{fpath}"):
                    play_text(content)
    # ------------------ ArXiv Search -------------------------
    with tab2:
        st.subheader("ArXiv Search")

        # Seed the query from the voice component, if it produced anything.
        edited_input = st.text_area("Enter or Edit Search Query:", value=(voice_val.strip() if voice_val else ""), height=100)
        autorun = st.checkbox("⚡ Auto-Run", value=True)
        run_arxiv = st.button("🔍 ArXiv Search")

        input_changed = (edited_input != st.session_state.get('old_val'))
        # Trigger on the button, or automatically when auto-run is on and the input changed.
        if edited_input.strip() and (run_arxiv or (autorun and input_changed)):
            st.session_state['old_val'] = edited_input
            # Save user input
            save_input_as_md(st.session_state['user_name'], edited_input, prefix="input")
            with st.spinner("Searching ArXiv..."):
                results = arxiv_search(edited_input)
            if results:
                summary = summarize_arxiv_results(results)
                # Save response
                save_response_as_md(st.session_state['user_name'], summary, prefix="response")
                st.write(summary)
                # Autoplay TTS
                play_text(summary)
            else:
                st.warning("No results found on ArXiv.")
    # ------------------ Dataset Search -------------------------
    with tab3:
        st.subheader("Dataset Search")
        search = FastDatasetSearcher()
        query = st.text_input("Enter dataset search query:")
        run_ds_search = st.button("Search Dataset")
        num_results = st.slider("Max results:", 1, 100, 20)

        if run_ds_search and query.strip():
            with st.spinner("Searching dataset..."):
                df = search.load_page()
                results = search.quick_search(query, df)
                if len(results) > 0:
                    st.write(f"Found {len(results)} results:")
                    shown = 0
                    for i, (_, result) in enumerate(results.iterrows(), 1):
                        if shown >= num_results:
                            break
                        with st.expander(f"Result {i}", expanded=(i == 1)):
                            # Just print result keys/values here
                            for k, v in result.items():
                                if k not in ['score', 'matched']:
                                    st.write(f"**{k}:** {v}")
                        shown += 1
                else:
                    st.warning("No matching results found.")
    # ------------------ Settings Tab -------------------------
    with tab4:
        st.subheader("Settings")
        st.write("Adjust voice and search parameters in the sidebar.")
        if st.button("🗑️ Clear Search History"):
            st.session_state['search_history'] = []
            # Optionally delete files:
            # for fpath in list_saved_inputs():
            #     os.remove(fpath)
            st.success("Search history cleared!")
if __name__ == "__main__": | |
main() | |