# Source: app.py as published by ziyadsuper2017 (commit 2c4cf73, "Update app.py", 7.06 kB).
import base64
import io
import os
import uuid
from io import BytesIO

import google.generativeai as genai
import streamlit as st
import streamlit.components.v1 as components
from gtts import gTTS
from PIL import Image
# Read the Gemini API key from the environment instead of hard-coding it.
# A key that has ever been committed to source control must be treated as
# compromised and revoked.
api_key = os.environ.get("GOOGLE_API_KEY")
if not api_key:
    st.error("Set the GOOGLE_API_KEY environment variable to use this app.")
    st.stop()
genai.configure(api_key=api_key)
# Sampling temperature and response-length cap applied to every model request.
generation_config = genai.GenerationConfig(temperature=0.9, max_output_tokens=3000)
# Safety settings: each listed harm category gets the "BLOCK_NONE" threshold.
safety_settings = [
    {"category": harm_category, "threshold": "BLOCK_NONE"}
    for harm_category in (
        "HARM_CATEGORY_DANGEROUS_CONTENT",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_HARASSMENT",
    )
]
# Seed per-session defaults the first time this browser session runs the script.
_session_defaults = {
    'chat_history': list,
    'file_uploader_key': lambda: str(uuid.uuid4()),
}
for _state_key, _make_default in _session_defaults.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
st.title("Gemini Chatbot")
# Model picker; send_message reads `selected_model` when building the request.
selected_model = st.selectbox(
    "Select a Gemini 1.5 model:",
    ["gemini-1.5-flash-latest", "gemini-1.5-pro-latest"],
)
# Helper functions for image processing and chat history management
def get_image_base64(image):
    """Re-encode *image* as JPEG and return it as a base64 ``str``.

    The image is converted to RGB first so that modes which cannot be written
    as JPEG (e.g. with an alpha channel) are accepted.
    """
    rgb_image = image.convert("RGB")
    jpeg_buffer = io.BytesIO()
    rgb_image.save(jpeg_buffer, format="JPEG")
    jpeg_bytes = jpeg_buffer.getvalue()
    return base64.b64encode(jpeg_bytes).decode()
def clear_conversation():
    """Reset the chat: empty the stored history and rotate the uploader key.

    The file uploader is created with ``key=st.session_state.file_uploader_key``,
    so a new key yields a fresh, empty uploader widget on the next run.
    """
    state = st.session_state
    state['chat_history'] = []
    state['file_uploader_key'] = str(uuid.uuid4())
def display_chat_history():
    """Render the stored conversation.

    Text parts are shown as ``Role: text`` markdown. Image parts hold
    base64-encoded image data and are decoded and shown with
    ``st.image``. Every part of every entry is rendered, which matches
    get_chat_history_str. Entries without parts are skipped instead of
    raising IndexError.
    """
    for entry in st.session_state['chat_history']:
        speaker = entry["role"].title()
        for part in entry.get("parts", []):
            if 'text' in part:
                st.markdown(f"{speaker}: {part['text']}")
            elif 'data' in part:
                image_bytes = base64.b64decode(part['data'])
                st.image(Image.open(io.BytesIO(image_bytes)), caption='Uploaded Image')
def get_chat_history_str():
    """Return the conversation as plain text, one ``Role: content`` line per part.

    Parts without a ``'text'`` key (images) are written as ``(Image)``.
    """
    lines = []
    for entry in st.session_state['chat_history']:
        speaker = entry['role'].title()
        for part in entry['parts']:
            content = part['text'] if 'text' in part else "(Image)"
            lines.append(f"{speaker}: {content}")
    return "\n".join(lines)
# Send message function with TTS integration
def _history_to_contents(history):
    """Convert stored chat history into the multi-turn ``contents`` list for Gemini.

    Each entry becomes a ``{"role": ..., "parts": [...]}`` turn.

    Stored images are base64 JPEG strings produced by get_image_base64. They
    are decoded back to raw bytes for the inline-data part.

    Consecutive entries with the same role are merged into one turn, so that
    user and model turns alternate. Every uploaded image is stored as its own
    "user" entry, which is why merging is needed.
    """
    contents = []
    for entry in history:
        parts = []
        for part in entry.get("parts", []):
            if 'text' in part:
                parts.append({"text": part['text']})
            elif 'data' in part:
                parts.append({"mime_type": "image/jpeg", "data": base64.b64decode(part['data'])})
        if not parts:
            continue
        if contents and contents[-1]["role"] == entry["role"]:
            contents[-1]["parts"].extend(parts)
        else:
            contents.append({"role": entry["role"], "parts": parts})
    return contents


def _encode_uploads(uploaded_files):
    """Return base64 JPEG strings for the uploads; unreadable files are skipped with a warning."""
    encoded = []
    for uploaded_file in uploaded_files or []:
        try:
            encoded.append(get_image_base64(Image.open(uploaded_file)))
        except OSError as err:  # PIL's UnidentifiedImageError is an OSError subclass
            st.warning(f"Skipping {uploaded_file.name}: could not read it as an image ({err}).")
    return encoded


def _generate_reply(model_name, contents):
    """Send *contents* to *model_name* and return the reply text.

    Returns None when the request fails or the reply has no text. The failure
    has already been reported in the UI in that case.
    """
    model = genai.GenerativeModel(
        model_name=model_name,
        generation_config=generation_config,
        safety_settings=safety_settings
    )
    try:
        response = model.generate_content(contents)
    except Exception as err:  # UI boundary: network/quota/invalid-request errors from the client
        st.error(f"Request to {model_name} failed: {err}")
        return None
    try:
        return response.text
    except ValueError:
        # `.text` raises ValueError (not AttributeError) when the reply has no
        # text part, e.g. because it was blocked, so hasattr() cannot detect it.
        st.warning("No response text found.")
        return None


def _speak(text):
    """Play *text* as MP3 speech via gTTS; a TTS failure is shown as a warning, not raised."""
    audio = BytesIO()
    try:
        gTTS(text=text, lang='en').write_to_fp(audio)
    except Exception as err:  # gTTS needs network access; speech is best-effort
        st.warning(f"Text-to-speech failed: {err}")
        return
    audio.seek(0)
    st.audio(audio, format='audio/mp3')


def send_message():
    """Send-button callback: record the user's turn, query Gemini, and speak the reply.

    Reads the text area (``user_input``) and the uploaded images
    (``uploaded_files``) from session state. Both Gemini 1.5 models accept
    images, so the whole conversation goes to the selected model as
    multi-turn contents. The reply is appended to the history and played as
    audio, then the input widgets are reset.
    """
    user_input = st.session_state.get("user_input") or ""
    has_text = bool(user_input.strip())
    encoded_images = _encode_uploads(st.session_state.get("uploaded_files"))
    if not has_text and not encoded_images:
        st.warning("Please enter a message or upload an image.")
        return

    history = st.session_state['chat_history']
    if has_text:
        history.append({"role": "user", "parts": [{"text": user_input}]})
    for base64_image in encoded_images:
        # get_image_base64 always re-encodes as JPEG, whatever the upload's type.
        history.append({"role": "user", "parts": [{"mime_type": "image/jpeg", "data": base64_image}]})

    response_text = _generate_reply(selected_model, _history_to_contents(history))
    if response_text:
        history.append({"role": "model", "parts": [{"text": response_text}]})
        _speak(response_text)

    # Clear the input fields; a new uploader key yields an empty uploader widget.
    st.session_state.user_input = ''
    st.session_state.uploaded_files = []
    st.session_state.file_uploader_key = str(uuid.uuid4())
# Show the conversation so far, with the input widgets beneath it.
display_chat_history()

user_input = st.text_area("Enter your message here:", value="", key="user_input")

# The uploader's key lives in session state so the callbacks can swap it out,
# which replaces the widget with an empty one.
uploaded_files = st.file_uploader(
    "Upload images:",
    type=["png", "jpg", "jpeg"],
    accept_multiple_files=True,
    key=st.session_state.file_uploader_key,
)

# Action buttons; their on_click callbacks do the work.
send_button = st.button("Send", on_click=send_message)
clear_button = st.button("Clear Conversation", on_click=clear_conversation)
# Payload for the "Download Chat" button.
def download_chat_history():
    """Return the plain-text chat transcript produced by get_chat_history_str."""
    return get_chat_history_str()
# Offer the current transcript as a plain-text file.
transcript_text = download_chat_history()
download_button = st.download_button(
    label="Download Chat",
    data=transcript_text,
    file_name="chat_history.txt",
    mime="text/plain",
)

# Mirror the uploader's current selection into session state. The Send
# callback (send_message) reads it from there.
st.session_state.uploaded_files = uploaded_files
# Ctrl+Enter (or Cmd+Enter) in the message box clicks the "Send" button.
# st.markdown never executes <script> tags, and DOMContentLoaded has already
# fired by the time Streamlit renders, so the original handler never ran.
# components.html does run scripts. It renders inside an iframe, hence the
# window.parent lookup.
# The listener is delegated on the parent document because Streamlit re-renders
# the textarea on reruns. A flag on the document keeps it from being installed
# twice.
components.html(
    """
    <script>
    const doc = window.parent.document;
    if (!doc.ctrlEnterSendInstalled) {
        doc.ctrlEnterSendInstalled = true;
        doc.addEventListener('keydown', (e) => {
            if (e.key !== 'Enter' || !(e.ctrlKey || e.metaKey)) return;
            if (!e.target.matches || !e.target.matches('.stTextArea textarea')) return;
            const sendButton = Array.from(doc.querySelectorAll('.stButton button'))
                .find((b) => b.innerText.trim() === 'Send');
            if (sendButton) {
                e.preventDefault();
                sendButton.click();
            }
        });
    }
    </script>
    """,
    height=0,
)