# Scraped file-viewer header (not code): author ziyadsuper2017,
# commit 7fd1c6d (verified) "trying to add audio recording feature",
# file size 7.74 kB. Page links: raw, history, blame.
import base64
import io
import os
import uuid
from io import BytesIO

import google.generativeai as genai
import librosa
import numpy as np
import PyPDF2
import soundfile as sf
import streamlit as st
import streamlit.components.v1 as components
from gtts import gTTS
from PIL import Image
# Set your API key
# The key is read from the environment instead of being written into the source:
# a literal key is exposed to anyone who can read this file or its history (the
# key previously committed here should be treated as leaked and revoked).
api_key = os.environ.get("GOOGLE_API_KEY")
if not api_key:
    st.error("Set the GOOGLE_API_KEY environment variable to use this app.")
    st.stop()
genai.configure(api_key=api_key)
# Configure the generative AI model: sampling temperature and the cap on reply
# length (in output tokens) shared by every request send_message() makes.
generation_config = genai.GenerationConfig(
    max_output_tokens=3000,
    temperature=0.9,
)
# Safety settings configuration: each listed harm category is given the
# "BLOCK_NONE" threshold. The list is passed to every GenerativeModel created
# in send_message().
safety_settings = [
    {"category": harm_category, "threshold": "BLOCK_NONE"}
    for harm_category in (
        "HARM_CATEGORY_DANGEROUS_CONTENT",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_HARASSMENT",
    )
]
# Initialize session state. Streamlit re-executes this script on every
# interaction, so each key is seeded only when it is absent; existing values
# carry over between reruns.
for _state_key, _make_default in (
    ('chat_history', list),
    ('file_uploader_key', lambda: str(uuid.uuid4())),
    ('recording_enabled', lambda: False),
    ('recorded_audio', lambda: None),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
del _state_key, _make_default
# --- Streamlit UI ---
st.title("Gemini Chatbot")
st.write("Interact with the powerful Gemini 1.5 models.")

# Model Selection Dropdown: the chosen name is passed as model_name by send_message().
selected_model = st.selectbox(
    label="Choose a Gemini 1.5 Model:",
    options=["gemini-1.5-flash-latest", "gemini-1.5-pro-latest"],
)

# TTS Option Checkbox: when ticked, send_message() also plays the reply via gTTS.
enable_tts = st.checkbox(label="Enable Text-to-Speech")
# --- Helper Functions ---
def get_file_base64(file_content, mime_type):
    """Package raw bytes as an inline-data part for the Gemini prompt/history.

    Args:
        file_content: the raw bytes to embed.
        mime_type: MIME type string stored alongside the data.

    Returns:
        dict: ``{"mime_type": mime_type, "data": <base64-encoded text>}``.
    """
    return {
        "mime_type": mime_type,
        "data": base64.b64encode(file_content).decode(),
    }
def clear_conversation():
    """Reset the conversation: empty the history and drop any pending recording.

    It also stores a new UUID in ``file_uploader_key``. The file uploader uses
    that value as its widget key, so the uploader is recreated without files.
    """
    new_uploader_key = str(uuid.uuid4())
    state = st.session_state
    state['chat_history'] = []
    state['file_uploader_key'] = new_uploader_key
    state['recorded_audio'] = None
def display_chat_history():
    """Render the whole conversation stored in ``st.session_state['chat_history']``.

    Each entry is a ``{"role": ..., "parts": [...]}`` dict as built by
    send_message(). Every part of every entry is rendered in order.
    """
    chat_container = st.empty()
    with chat_container.container():
        for entry in st.session_state['chat_history']:
            # Render all parts, not just parts[0]; an empty/missing "parts"
            # list is skipped instead of raising IndexError.
            for part in entry.get("parts") or []:
                _render_part(entry["role"], part)


def _render_part(role, part):
    """Render one message part: role-prefixed markdown text, or decoded inline data."""
    if 'text' in part:
        st.markdown(f"**{role.title()}:** {part['text']}")
        return
    if 'data' not in part:
        return  # neither text nor inline data: nothing to show

    # `or ''` also covers an explicit None mime type, which would break startswith().
    mime_type = part.get('mime_type') or ''
    content = base64.b64decode(part['data'])
    if mime_type.startswith('image'):
        # NOTE(review): newer Streamlit releases deprecate use_column_width;
        # check the pinned version before switching parameters.
        st.image(Image.open(io.BytesIO(content)),
                 caption='Uploaded Image', use_column_width=True)
    elif mime_type == 'application/pdf':
        st.write("**PDF Content:**")
        for page in PyPDF2.PdfReader(io.BytesIO(content)).pages:
            page_text = page.extract_text()
            if page_text:  # skip pages with no extractable text (e.g. scans)
                st.write(page_text)
    elif mime_type.startswith('audio'):
        st.audio(io.BytesIO(content), format=mime_type)
    elif mime_type.startswith('video'):
        st.video(io.BytesIO(content))
# --- Audio Recording Functions ---
def start_recording():
    """Set the recording flag and tell the user how to finish.

    NOTE(review): only this flag is toggled; no code visible here reads
    ``recording_enabled``.
    """
    st.session_state.recording_enabled = True
    st.warning("Recording started. Click 'Stop Recording' to finish.")
def stop_recording():
    """Clear the recording flag and confirm to the user."""
    st.session_state.recording_enabled = False
    st.success("Recording stopped.")
def process_audio(audio_data):
    """Decode recorded audio and re-encode it as WAV bytes.

    Args:
        audio_data: a path or binary file-like object that ``librosa.load`` can
            read (presumably the recorder widget's upload; confirm with caller).

    Returns:
        tuple[bytes, str]: the WAV-encoded audio and the MIME type ``"audio/wav"``.
    """
    # Rewind in case the file-like object has already been read.
    if hasattr(audio_data, "seek"):
        audio_data.seek(0)
    # sr=None keeps the source sample rate instead of resampling.
    wav_data, samplerate = librosa.load(audio_data, sr=None)
    # Encode in memory rather than through a fixed "temp.wav". That name lives in
    # the working directory shared by every concurrent session/run, so two users
    # could overwrite each other's file, and it was left behind on disk.
    wav_buffer = io.BytesIO()
    sf.write(wav_buffer, wav_data, samplerate, format="wav")
    return wav_buffer.getvalue(), "audio/wav"
# --- Send Message Function ---
def send_message():
    """Send the pending message, uploads and recorded audio to the selected model.

    Used as the Send button's ``on_click`` callback. It reads the inputs stored in
    ``st.session_state`` (``user_input``, ``uploaded_files``, ``recorded_audio``).
    Each user item is appended to ``chat_history`` as its own ``"user"`` entry, and
    the reply, if any, is appended as a ``"model"`` entry. Afterwards the text box
    is cleared and the file uploader gets a new key, so it starts empty.
    """
    prompt_parts = _collect_prompt_parts()
    if not prompt_parts:
        # Nothing to send: skip the API call instead of surfacing an API error.
        st.warning("Enter a message, upload a file, or record audio before sending.")
        return

    _generate_reply(prompt_parts)

    st.session_state.user_input = ''
    st.session_state.uploaded_files = []
    st.session_state.file_uploader_key = str(uuid.uuid4())
    # Rendering is left to the display_chat_history() call in the script body;
    # calling it here as well showed the chat twice.


def _collect_prompt_parts():
    """Build prompt parts from session state, logging each one to chat_history.

    Returns:
        list[dict]: ``{"text": ...}`` and ``{"mime_type": ..., "data": ...}`` parts
        in the order text, uploaded files, recorded audio. Empty if nothing is pending.
    """
    history = st.session_state['chat_history']
    prompt_parts = []

    def add_user_part(part):
        # The prompt and the history each get their own dict, as before.
        prompt_parts.append(part)
        history.append({"role": "user", "parts": [dict(part)]})

    user_input = st.session_state.get('user_input')
    if user_input:
        add_user_part({"text": user_input})

    for uploaded_file in st.session_state.get('uploaded_files') or []:
        # getvalue() returns the whole upload regardless of the stream position;
        # read() returns b"" for an object that has already been read.
        add_user_part(get_file_base64(uploaded_file.getvalue(), uploaded_file.type))

    recorded_audio = st.session_state.get('recorded_audio')
    if recorded_audio:
        audio_content, audio_type = process_audio(recorded_audio)
        add_user_part(get_file_base64(audio_content, audio_type))
        st.session_state['recorded_audio'] = None  # consume the clip: send it once

    return prompt_parts


def _generate_reply(prompt_parts):
    """Send prompt_parts to the selected model and record (optionally speak) the reply.

    Any failure is reported with ``st.error`` instead of being raised.
    """
    try:
        model = genai.GenerativeModel(
            model_name=selected_model,
            generation_config=generation_config,
            safety_settings=safety_settings,
        )
        response = model.generate_content([{"role": "user", "parts": prompt_parts}])
        try:
            response_text = response.text
        except (AttributeError, ValueError):
            # The .text accessor can raise (not merely be missing) when the reply
            # has no text part, e.g. a blocked candidate; hasattr() only guarded
            # against AttributeError.
            response_text = "No response text found."
        if response_text:
            st.session_state['chat_history'].append(
                {"role": "model", "parts": [{"text": response_text}]}
            )
            if enable_tts:
                # Synthesize the reply to MP3 in memory and hand it to st.audio.
                tts_file = BytesIO()
                gTTS(text=response_text, lang='en').write_to_fp(tts_file)
                tts_file.seek(0)
                st.audio(tts_file, format='audio/mp3')
    except Exception as e:  # UI boundary: show any failure to the user
        st.error(f"An error occurred: {e}")
# --- User Input Area ---
col1, col2 = st.columns([3, 1])
with col1:
    # No explicit value="" here. send_message() clears this widget through
    # st.session_state.user_input, and Streamlit warns when a keyed widget has both
    # a default value and a value set via Session State. The default is "" anyway.
    user_input = st.text_area(
        "Enter your message:",
        key="user_input"
    )
with col2:
    send_button = st.button(
        "Send",
        on_click=send_message,
        type="primary"
    )
# --- File Uploader ---
# The widget key comes from session state. send_message() and
# clear_conversation() replace it with a new UUID, which recreates the uploader
# with no files selected.
uploader_key = st.session_state['file_uploader_key']
uploaded_files = st.file_uploader(
    label="Upload Files (Images, Videos, PDFs, MP3):",
    type=["png", "jpg", "jpeg", "mp4", "pdf", "mp3"],
    accept_multiple_files=True,
    key=uploader_key,
)
del uploader_key
# --- Audio Recording ---
# Streamlit has no st.audio_recorder. Its built-in recorder is st.audio_input
# (st.experimental_audio_input in the release that introduced it).
_audio_input = getattr(st, "audio_input", None) or getattr(st, "experimental_audio_input", None)
if _audio_input is not None:
    # The key is derived from file_uploader_key, so the recorder is cleared
    # whenever the uploader is (after a send or "Clear Conversation"). Without
    # this, a clip left in the widget would be resent with every later message.
    # The key must also differ from 'recorded_audio', which this script assigns
    # directly.
    recorded_clip = _audio_input(
        "Record audio:", key=f"audio_input_{st.session_state['file_uploader_key']}"
    )
    # send_message() reads the clip from session state, like uploaded_files.
    st.session_state['recorded_audio'] = recorded_clip
else:
    st.info("Audio recording needs a newer Streamlit release (st.audio_input).")

# NOTE(review): these buttons only toggle the recording_enabled flag; the
# recorder widget above captures audio on its own.
col3, col4 = st.columns([1, 1])
with col3:
    if st.button("Start Recording"):
        start_recording()
with col4:
    if st.button("Stop Recording"):
        stop_recording()
# --- Other Buttons ---
st.button(label="Clear Conversation", on_click=clear_conversation)

# --- Ensure file_uploader state ---
# Store this run's uploads in session state, where the send_message() callback
# reads them.
st.session_state['uploaded_files'] = uploaded_files
# --- JavaScript for Ctrl+Enter ---
# st.markdown(..., unsafe_allow_html=True) does not execute <script> tags, so the
# previous snippet never ran. components.html() runs the script in a
# zero-height iframe, and the script reaches the app page via window.parent.document.
components.html(
    """
<script>
const doc = window.parent.document;
// Register once, even if this iframe is reloaded.
if (!doc.__geminiCtrlEnterBound) {
    doc.__geminiCtrlEnterBound = true;
    // Listen on the document so a re-created textarea is still handled.
    doc.addEventListener('keydown', function (e) {
        if (e.key === 'Enter' && e.ctrlKey && e.target.matches('.stTextArea textarea')) {
            // The first button on the page is "Send".
            const sendButton = doc.querySelector('.stButton > button');
            if (sendButton) {
                sendButton.click();
                e.preventDefault();
            }
        }
    });
}
</script>
""",
    height=0,
)

# --- Display Chat History ---
display_chat_history()