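# Source captured from a hosted "Spaces" page (apparently Hugging Face Spaces);
# the page reported the app's status as "Runtime error".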
import base64
import io
import os
import uuid

import google.generativeai as genai
import streamlit as st
from gtts import gTTS
from PIL import Image
# Read the Gemini API key from the environment instead of hard-coding it in
# source (hosting platforms such as Hugging Face Spaces expose configured
# secrets as environment variables).
# NOTE(review): a real key was previously committed in this file and has been
# published with it -- it should be revoked and rotated.
api_key = os.environ.get("GOOGLE_API_KEY")
if not api_key:
    # Stop rendering the app with a visible message rather than failing later
    # inside a model call with an authentication error.
    st.error("GOOGLE_API_KEY environment variable is not set.")
    st.stop()
genai.configure(api_key=api_key)

# Sampling parameters shared by every request: temperature 0.9 and at most
# 3000 output tokens per response.
generation_config = genai.GenerationConfig(
    temperature=0.9,
    max_output_tokens=3000,
)
# Safety settings: each harm category below is given the BLOCK_NONE threshold,
# i.e. the API is asked not to block content in that category.
_HARM_CATEGORIES = (
    "HARM_CATEGORY_DANGEROUS_CONTENT",
    "HARM_CATEGORY_SEXUALLY_EXPLICIT",
    "HARM_CATEGORY_HATE_SPEECH",
    "HARM_CATEGORY_HARASSMENT",
)
safety_settings = [
    {"category": harm_category, "threshold": "BLOCK_NONE"}
    for harm_category in _HARM_CATEGORIES
]
# First-run defaults for per-session state. Each entry maps a key to a factory,
# so a fresh value (empty list, new random uploader key) is created only when
# the key is missing; values already stored survive Streamlit reruns.
_SESSION_DEFAULTS = {
    "chat_history": list,
    "file_uploader_key": lambda: str(uuid.uuid4()),
}
for _state_key, _make_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()

st.title("Gemini Chatbot")
# Helper functions for image processing and chat history management
def get_image_base64(image):
    """Re-encode *image* as JPEG and return it as a base64 text string.

    The image is converted to RGB first, since JPEG cannot store alpha or
    palette modes.
    """
    with io.BytesIO() as jpeg_buffer:
        image.convert("RGB").save(jpeg_buffer, format="JPEG")
        jpeg_bytes = jpeg_buffer.getvalue()
    return base64.b64encode(jpeg_bytes).decode()
def clear_conversation():
    """Forget every chat turn and reset the image uploader.

    The uploader widget is keyed by ``file_uploader_key``, so storing a new
    random key makes Streamlit render it as a brand-new, empty widget.
    """
    replacement_key = str(uuid.uuid4())
    st.session_state["chat_history"] = []
    st.session_state["file_uploader_key"] = replacement_key
def display_chat_history():
    """Render each stored chat turn, looking only at its first part.

    A text part is shown as ``Role: text`` markdown. An image part (base64
    ``data``) is decoded and shown with the caption "Uploaded Image". A part
    with neither key is skipped.
    """
    for turn in st.session_state['chat_history']:
        speaker = turn["role"]
        first_part = turn["parts"][0]
        if 'text' in first_part:
            st.markdown(f"{speaker.title()}: {first_part['text']}")
            continue
        if 'data' in first_part:
            image_bytes = base64.b64decode(first_part['data'])
            st.image(Image.open(io.BytesIO(image_bytes)), caption='Uploaded Image')
def get_chat_history_str():
    """Flatten the chat history into newline-separated ``Role: ...`` lines.

    Every part of every turn yields one line. Text parts contribute their text;
    any other part (an uploaded image) is written as ``(Image)``.
    """
    transcript_lines = []
    for turn in st.session_state['chat_history']:
        for part in turn['parts']:
            content = part['text'] if 'text' in part else "(Image)"
            transcript_lines.append(f"{turn['role'].title()}: {content}")
    return "\n".join(transcript_lines)
# Send message function with TTS integration
def _history_to_prompt_lines(history):
    """Return one prompt line per stored part: its text, or "[Image]" for images."""
    prompt_lines = []
    for turn in history:
        for part in turn['parts']:
            if 'text' in part:
                prompt_lines.append(part['text'])
            elif 'data' in part:
                prompt_lines.append("[Image]")
    return prompt_lines


def _response_text(response):
    """Return the text of a Gemini *response*, or a placeholder when it has none.

    ``response.text`` raises ``ValueError`` (not ``AttributeError``) when the
    response carries no text part, e.g. because it was blocked, so a
    ``hasattr`` check is not enough on its own.
    """
    try:
        return response.text
    except (AttributeError, ValueError):
        return "No response text found."


def _play_tts(text):
    """Synthesize *text* to MP3 with gTTS and render it in an audio player."""
    audio_buffer = io.BytesIO()
    gTTS(text=text, lang='en').write_to_fp(audio_buffer)
    audio_buffer.seek(0)
    st.audio(audio_buffer, format='audio/mp3')


def send_message():
    """Send the pending text and images to Gemini and record the exchange.

    Used as the Send button's ``on_click`` callback. It works as follows:

    1. Read ``user_input`` and ``uploaded_files`` from session state.
    2. Append the user's turn(s) to ``chat_history``.
    3. Send the whole flattened transcript to the model as a single user
       message; images are represented only by the placeholder "[Image]".
    4. Append the model's reply and play it back via text-to-speech.
    5. Clear the inputs and remount the file uploader.

    Does nothing when there is neither text nor an image to send.
    """
    user_input = st.session_state.get("user_input", "")
    uploaded_files = st.session_state.get("uploaded_files") or []

    # Nothing new to send: avoid calling the model with an empty or repeated
    # prompt.
    if not user_input and not uploaded_files:
        return

    history = st.session_state['chat_history']
    # Prior turns first, so the model sees the conversation so far.
    prompts = _history_to_prompt_lines(history)

    if user_input:
        prompts.append(user_input)
        history.append({"role": "user", "parts": [{"text": user_input}]})

    for uploaded_file in uploaded_files:
        base64_image = get_image_base64(Image.open(uploaded_file))
        # NOTE(review): image content is not sent to the model, only this
        # placeholder text.
        prompts.append("[Image]")
        history.append({
            "role": "user",
            # get_image_base64 always re-encodes to JPEG, whatever the upload
            # format was.
            "parts": [{"mime_type": "image/jpeg", "data": base64_image}],
        })

    # NOTE(review): 'use_vision_model' is never set in the visible code, and the
    # 'gemini-pro' / 'gemini-pro-vision' model names may no longer be served by
    # the API -- confirm the model names.
    model_name = 'gemini-pro-vision' if st.session_state.get('use_vision_model', False) else 'gemini-pro'
    model = genai.GenerativeModel(
        model_name=model_name,
        generation_config=generation_config,
        safety_settings=safety_settings,
    )
    prompt_parts = [{"text": "\n".join(prompts)}]
    response = model.generate_content([{"role": "user", "parts": prompt_parts}])
    response_text = _response_text(response)

    if response_text:
        history.append({"role": "model", "parts": [{"text": response_text}]})
        _play_tts(response_text)

    # Clear the input fields after sending; a new uploader key remounts the
    # file uploader empty.
    st.session_state.user_input = ''
    st.session_state.uploaded_files = []
    st.session_state.file_uploader_key = str(uuid.uuid4())
# Render the conversation so far, then the input widgets beneath it.
display_chat_history()

# Free-text message box; its value is stored under session key "user_input".
user_input = st.text_area(
    label="Enter your message here:",
    value="",
    key="user_input",
)

# Image picker. Its key is regenerated after a send or clear, which remounts
# it with no files selected.
uploaded_files = st.file_uploader(
    label="Upload images:",
    type=["png", "jpg", "jpeg"],
    accept_multiple_files=True,
    key=st.session_state["file_uploader_key"],
)

# Action buttons; the work is done in their on_click callbacks.
send_button = st.button(label="Send", on_click=send_message)
clear_button = st.button(label="Clear Conversation", on_click=clear_conversation)
def download_chat_history():
    """Return the transcript text offered by the "Download Chat" button."""
    return get_chat_history_str()


# Offer the current transcript as a plain-text file download.
download_button = st.download_button(
    data=download_chat_history(),
    file_name="chat_history.txt",
    mime="text/plain",
    label="Download Chat",
)
# Mirror the uploader's current files into session state under a stable name,
# because the widget's own key changes whenever it is reset; send_message
# reads the files back from "uploaded_files".
st.session_state["uploaded_files"] = uploaded_files
# JavaScript to capture the Ctrl+Enter event and trigger a button click.
# Intent: pressing Ctrl+Enter inside the message text area clicks the first
# `.stButton > button` on the page. In this script that is "Send", the first
# st.button created.
# NOTE(review): Streamlit's st.markdown does not appear to execute injected
# <script> tags, even with unsafe_allow_html=True. DOMContentLoaded would also
# normally have fired before this markup is inserted. This handler therefore
# likely never runs -- confirm in the browser; streamlit.components.v1.html may
# be needed instead.
st.markdown(
"""
<script>
document.addEventListener('DOMContentLoaded', (event) => {
document.querySelector('.stTextArea textarea').addEventListener('keydown', function(e) {
if (e.key === 'Enter' && e.ctrlKey) {
document.querySelector('.stButton > button').click();
e.preventDefault();
}
});
});
</script>
""",
unsafe_allow_html=True
)