# Spaces: Runtime error (apparently hosting-page status text captured with this file; not program code)
import base64
import io
import os
import uuid

# Assuming google.generativeai is the correct import based on your description
import google.generativeai as genai
import streamlit as st
from PIL import Image
# Configure the API key. It is read from the environment so that no
# credential lives in the source file.
# NOTE: a literal key was previously embedded on this line; because it was
# exposed in source it should be treated as compromised and revoked.
api_key = os.environ.get("GOOGLE_API_KEY")
if not api_key:
    st.error(
        "No Gemini API key found. Set the GOOGLE_API_KEY environment "
        "variable and restart the app."
    )
    # Halt this script run: nothing below can work without a configured key.
    st.stop()
genai.configure(api_key=api_key)

# Generation settings shared by every model created in this script.
generation_config = genai.GenerationConfig(
    temperature=0.9,
    max_output_tokens=3000,
)
# No safety-setting overrides are specified.
safety_settings = []

# Initialize session state for chat history and file uploader key
if 'chat_history' not in st.session_state:
    st.session_state['chat_history'] = []
if 'file_uploader_key' not in st.session_state:
    st.session_state['file_uploader_key'] = str(uuid.uuid4())

# UI layout
st.title("Gemini Chatbot")
# Function to convert image to base64 | |
def get_image_base64(image):
    """Return *image* encoded as a base64 string of its JPEG bytes.

    The image is first converted to RGB (dropping any alpha channel, which
    JPEG cannot store), saved as JPEG into an in-memory buffer, and the
    resulting bytes are base64-encoded and decoded to ``str``.
    """
    rgb_image = image.convert("RGB")
    with io.BytesIO() as jpeg_buffer:
        rgb_image.save(jpeg_buffer, format="JPEG")
        jpeg_bytes = jpeg_buffer.getvalue()
    encoded = base64.b64encode(jpeg_bytes)
    return encoded.decode()
# Function to clear conversation | |
def clear_conversation():
    """Reset the chat: empty the history and rotate the uploader key.

    The file uploader widget is created with the key stored in session
    state, so storing a new random key resets the uploader as well.
    """
    fresh_uploader_key = str(uuid.uuid4())
    st.session_state.chat_history = []
    st.session_state.file_uploader_key = fresh_uploader_key
# Function to send message and clear input | |
def _extract_reply_text(response):
    """Return the text of a model response, or a fallback message.

    ``response.text`` is a computed accessor that raises ``ValueError`` when
    the response carries no text (for example when it was blocked), so a
    ``hasattr`` check cannot guard it; the exception is handled instead.
    """
    try:
        return response.text
    except (AttributeError, ValueError):
        return "No response text found."


def _generate_reply(model_name, contents):
    """Send *contents* to the named Gemini model and return the reply text."""
    model = genai.GenerativeModel(
        model_name=model_name,
        generation_config=generation_config,
        safety_settings=safety_settings,
    )
    return _extract_reply_text(model.generate_content(contents))


def _text_only_history(history):
    """Return the history entries whose first part is a text part."""
    return [entry for entry in history if 'text' in entry["parts"][0]]


def send_message():
    """Handle the "Send" button: query Gemini and record the exchange.

    Reads the pending text and uploaded images from session state, appends
    the user's turn(s) and the model's reply to ``chat_history``, then
    clears the text box and resets the file uploader.

    * With images: the text (if any) and all images are sent together as a
      single user turn to ``gemini-pro-vision``.
    * With text only: the text entries of the conversation so far are sent
      to ``gemini-pro``.
    * With neither: nothing is sent; the inputs are still reset.
    """
    user_input = st.session_state.user_input
    uploaded_files = st.session_state.uploaded_files
    history = st.session_state['chat_history']

    # Record the user's text turn first so it precedes any images/reply.
    if user_input:
        history.append({"role": "user", "parts": [{"text": user_input}]})

    if uploaded_files:
        # get_image_base64 always re-encodes to JPEG, so the declared MIME
        # type must be image/jpeg regardless of the uploaded file's type.
        image_parts = [
            {
                "mime_type": "image/jpeg",
                "data": get_image_base64(Image.open(uploaded_file)),
            }
            for uploaded_file in uploaded_files
        ]

        # One single-turn request: the user's text (when present) followed
        # by every image, so the model sees the question with the pictures.
        request_parts = list(image_parts)
        if user_input:
            request_parts.insert(0, {"text": user_input})
        response_text = _generate_reply(
            'gemini-pro-vision',
            [{"role": "user", "parts": request_parts}],
        )

        # History keeps one entry per image (one part each), which is the
        # shape the history display expects.
        history.extend(
            {"role": "user", "parts": [image_part]} for image_part in image_parts
        )
        history.append({"role": "model", "parts": [{"text": response_text}]})
    elif user_input:
        # The text model cannot take image parts, so image entries from
        # earlier turns are left out of the multi-turn request.
        # NOTE(review): after an image-only turn the filtered history can
        # begin with a model entry; confirm the API accepts that ordering.
        response_text = _generate_reply('gemini-pro', _text_only_history(history))
        history.append({"role": "model", "parts": [{"text": response_text}]})

    # Clear the user input and generate a new key for the file uploader
    # widget to reset it
    st.session_state.user_input = ''
    st.session_state.uploaded_files = []
    st.session_state.file_uploader_key = str(uuid.uuid4())
# Message composer: multiline text box bound to the "user_input" state key
user_input = st.text_area("Enter your message here:", key="user_input")

# Image picker; its widget key is taken from session state so that storing a
# new key there resets the uploader
uploader_options = {
    "type": ["png", "jpg", "jpeg"],
    "accept_multiple_files": True,
    "key": st.session_state.file_uploader_key,
}
uploaded_files = st.file_uploader("Upload images:", **uploader_options)

# Action buttons, each wired to its callback through on_click
send_button = st.button("Send", on_click=send_message)
clear_button = st.button("Clear Conversation", on_click=clear_conversation)

# Render the conversation so far, one line per history entry; only the first
# part of each entry is inspected
for message in st.session_state.chat_history:
    speaker = message["role"].title()
    first_part = message["parts"][0]
    if 'text' in first_part:
        st.markdown(f"{speaker}: {first_part['text']}")
    elif 'data' in first_part:
        st.markdown(f"{speaker}: (Image)")

# Keep the current uploads under a fixed state key so the send callback can
# read them even though the uploader widget's own key is random
st.session_state.uploaded_files = uploaded_files