import streamlit as st
from PIL import Image
import io
import base64
import uuid
from gtts import gTTS
import google.generativeai as genai
from io import BytesIO
# Set your API key
api_key = "YOUR_API_KEY"  # Replace with your actual API key
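# A safer alternative to hardcoding the key (a sketch; assumes a GOOGLE_API_KEY entry exists in
# .streamlit/secrets.toml or in the Space's secrets):
# api_key = st.secrets["GOOGLE_API_KEY"]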
genai.configure(api_key=api_key)

# Configure the generative AI model
generation_config = genai.GenerationConfig(
    temperature=0.9,
    max_output_tokens=3000
)

# Safety settings configuration
safety_settings = [
    {
        "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
        "threshold": "BLOCK_NONE",
    },
    {
        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "threshold": "BLOCK_NONE",
    },
    {
        "category": "HARM_CATEGORY_HATE_SPEECH",
        "threshold": "BLOCK_NONE",
    },
    {
        "category": "HARM_CATEGORY_HARASSMENT",
        "threshold": "BLOCK_NONE",
    },
]

# Initialize session state
if 'chat_history' not in st.session_state:
    st.session_state['chat_history'] = []
if 'file_uploader_key' not in st.session_state:
    st.session_state['file_uploader_key'] = str(uuid.uuid4())

st.title("Gemini Chatbot")

# Model Selection Dropdown
selected_model = st.selectbox("Select a Gemini 1.5 model:", ["gemini-1.5-flash-latest", "gemini-1.5-pro-latest"])
# Helper functions for image processing and chat history management
def get_image_base64(image):
    image = image.convert("RGB")
    buffered = io.BytesIO()
    image.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return img_str

def clear_conversation():
    st.session_state['chat_history'] = []
    st.session_state['file_uploader_key'] = str(uuid.uuid4())

def display_chat_history():
    for entry in st.session_state['chat_history']:
        role = entry["role"]
        parts = entry["parts"][0]
        if 'text' in parts:
            st.markdown(f"{role.title()}: {parts['text']}")
        elif 'data' in parts:
            st.image(Image.open(io.BytesIO(base64.b64decode(parts['data']))), caption='Uploaded Image')

def get_chat_history_str():
    chat_history_str = "\n".join(
        f"{entry['role'].title()}: {part['text']}" if 'text' in part
        else f"{entry['role'].title()}: (Image)"
        for entry in st.session_state['chat_history']
        for part in entry['parts']
    )
    return chat_history_str

# Send message function with TTS integration
def send_message():
    user_input = st.session_state.user_input
    uploaded_files = st.session_state.uploaded_files
    prompts = []
    prompt_parts = []

    # Populate the prompts list with the existing chat history
    for entry in st.session_state['chat_history']:
        for part in entry['parts']:
            if 'text' in part:
                prompts.append(part['text'])
            elif 'data' in part:
                # Add the image in base64 format to prompt_parts for vision model
                prompt_parts.append({"data": part['data'], "mime_type": "image/jpeg"})
                prompts.append("[Image]")

    # Append the user input to the prompts list
    if user_input:
        prompts.append(user_input)
        st.session_state['chat_history'].append({"role": "user", "parts": [{"text": user_input}]})
        # Also add the user text input to prompt_parts
        prompt_parts.append({"text": user_input})

    # Handle uploaded files
    if uploaded_files:
        for uploaded_file in uploaded_files:
            base64_image = get_image_base64(Image.open(uploaded_file))
            prompts.append("[Image]")
            prompt_parts.append({"data": base64_image, "mime_type": "image/jpeg"})
            st.session_state['chat_history'].append({
                "role": "user",
                # The image was re-encoded as JPEG above, so record the matching MIME type
                "parts": [{"mime_type": "image/jpeg", "data": base64_image}]
            })
    # Determine whether any image parts are present
    use_vision_model = any(part.get('mime_type') == 'image/jpeg' for part in prompt_parts)

    # Use the selected model; both Gemini 1.5 Flash and Pro accept image inputs
    model_name = selected_model
    model = genai.GenerativeModel(
        model_name=model_name,
        generation_config=generation_config,
        safety_settings=safety_settings
    )

    chat_history_str = "\n".join(prompts)

    if use_vision_model:
        # Include text and images for vision model
        generated_prompt = {"role": "user", "parts": prompt_parts}
    else:
        # Include text only for standard model
        generated_prompt = {"role": "user", "parts": [{"text": chat_history_str}]}

    response = model.generate_content([generated_prompt])
    response_text = response.text if hasattr(response, "text") else "No response text found."

    # After generating the response from the model, append it to the chat history
    if response_text:
        st.session_state['chat_history'].append({"role": "model", "parts": [{"text": response_text}]})

        # Convert the response text to speech
        tts = gTTS(text=response_text, lang='en')
        tts_file = BytesIO()
        tts.write_to_fp(tts_file)
        tts_file.seek(0)
        st.audio(tts_file, format='audio/mp3')

    # Clear the input fields after sending the message
    st.session_state.user_input = ''
    st.session_state.uploaded_files = []
    st.session_state.file_uploader_key = str(uuid.uuid4())

    # Display the updated chat history
    display_chat_history()

# User input text area
user_input = st.text_area(
    "Enter your message here:",
    value="",
    key="user_input"
)

# File uploader for images
uploaded_files = st.file_uploader(
    "Upload images:",
    type=["png", "jpg", "jpeg"],
    accept_multiple_files=True,
    key=st.session_state.file_uploader_key
)

# Send message button
send_button = st.button(
    "Send",
    on_click=send_message
)

# Clear conversation button
clear_button = st.button("Clear Conversation", on_click=clear_conversation)

# Function to download the chat history as a text file
def download_chat_history():
    chat_history_str = get_chat_history_str()
    return chat_history_str

# Download button for the chat history
download_button = st.download_button(
    label="Download Chat",
    data=download_chat_history(),
    file_name="chat_history.txt",
    mime="text/plain"
)

# Ensure the file_uploader widget state is tied to the randomly generated key
st.session_state.uploaded_files = uploaded_files
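# Note: st.markdown may not execute <script> tags even with unsafe_allow_html=True. If the
# Ctrl+Enter shortcut below does not fire, one workaround (a sketch, not part of the original
# code) is to inject the script via st.components.v1.html instead:
#   import streamlit.components.v1 as components
#   components.html("<script>/* same keydown handler as below */</script>", height=0)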
# JavaScript to capture the Ctrl+Enter event and trigger a button click
st.markdown(
    """
    <script>
    document.addEventListener('DOMContentLoaded', (event) => {
        document.querySelector('.stTextArea textarea').addEventListener('keydown', function(e) {
            if (e.key === 'Enter' && e.ctrlKey) {
                document.querySelector('.stButton > button').click();
                e.preventDefault();
            }
        });
    });
    </script>
    """,
    unsafe_allow_html=True
)
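# To run locally (a sketch; assumes this file is saved as app.py):
#   pip install streamlit pillow gtts google-generativeai
#   streamlit run app.py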