File size: 3,705 Bytes
221a628
b557897
 
 
eabca2c
221a628
389cdce
b557897
221a628
b557897
dfdbfa8
6e074fc
221a628
dfdbfa8
 
 
 
221a628
6e074fc
837873a
eabca2c
dfdbfa8
 
eabca2c
 
b557897
 
 
221a628
00bfc2f
 
389cdce
00bfc2f
 
 
 
dfdbfa8
fd4809b
ba4c612
eabca2c
 
389cdce
a76b0fb
 
 
fd4809b
 
 
 
c72bfe4
 
389cdce
c72bfe4
 
fd4809b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c72bfe4
eabca2c
 
 
ba4c612
389cdce
 
 
 
 
 
 
 
eabca2c
389cdce
 
 
 
a76b0fb
389cdce
fd4809b
 
 
 
 
 
eabca2c
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import base64
import io
import os
import uuid

# Assuming google.generativeai as genai is the correct import based on your description
import google.generativeai as genai
import streamlit as st
from PIL import Image

# Configure the Gemini client. The API key is read from the environment so that
# no credential lives in source control; GOOGLE_API_KEY is the variable name the
# google-generativeai SDK itself looks for.
# NOTE: a key was previously hard-coded on this line. It has been exposed in the
# repository history and should be revoked and replaced.
api_key = os.environ.get("GOOGLE_API_KEY")
if not api_key:
    # Fail early with a readable message instead of an opaque API error later.
    st.error(
        "GOOGLE_API_KEY is not set. Provide it as an environment variable "
        "(or as a secret of the hosting platform) to use this app."
    )
    st.stop()
genai.configure(api_key=api_key)

# Sampling settings passed to every model created by this script.
generation_config = genai.GenerationConfig(
    temperature=0.9,
    max_output_tokens=3000,
)

# Empty list: this app supplies no safety-setting overrides of its own.
safety_settings = []

# Seed the per-session state on the first run only: the conversation so far and
# the key currently assigned to the file-uploader widget. Existing values are
# left untouched on reruns.
for state_key, make_default in (
    ('chat_history', list),
    ('file_uploader_key', lambda: str(uuid.uuid4())),
):
    if state_key not in st.session_state:
        st.session_state[state_key] = make_default()

# Page heading
st.title("Gemini Chatbot")

# Helper: serialise an image as base64-encoded JPEG text
def get_image_base64(image):
    """Return *image* re-encoded as JPEG and wrapped in a base64 string.

    The image is first converted to RGB, which drops an alpha channel if one
    is present, then written to an in-memory buffer in JPEG format. The
    buffer's bytes are base64-encoded and returned as ``str``.
    """
    rgb_image = image.convert("RGB")
    with io.BytesIO() as jpeg_buffer:
        rgb_image.save(jpeg_buffer, format="JPEG")
        jpeg_bytes = jpeg_buffer.getvalue()
    return base64.b64encode(jpeg_bytes).decode()

# Function to send message and clear input
def send_message():
    """Record the user's turn, query the model, and store its reply.

    Intended to run as the ``on_click`` callback of the "Send" button. It reads
    the message text from ``st.session_state.user_input`` and the uploaded
    images from ``st.session_state.uploaded_files``, appends them to
    ``st.session_state['chat_history']``, sends the whole history to the model
    and appends the model's answer. Afterwards the text box is emptied and the
    file uploader is reset by assigning it a fresh widget key.

    Does nothing when there is neither text nor an uploaded image.

    Raises:
        Whatever image decoding or the model request raises. In that case the
        turns added by this call are removed from the history again, so a
        retry does not submit them twice.
    """
    user_input = st.session_state.user_input
    uploaded_files = st.session_state.uploaded_files

    if not (user_input or uploaded_files):
        return

    history = st.session_state['chat_history']
    turns_before = len(history)  # rollback point if the request fails

    # Save user input to the chat history
    if user_input:
        history.append({"role": "user", "parts": [{"text": user_input}]})

    # Choose the appropriate model based on the input type.
    # NOTE(review): only the current upload is considered; a history that
    # already contains image parts is still sent to 'gemini-pro' on a later
    # text-only turn - confirm that this is accepted by the API.
    model_name = 'gemini-pro-vision' if uploaded_files else 'gemini-pro'
    model = genai.GenerativeModel(
        model_name=model_name,
        generation_config=generation_config,
        safety_settings=safety_settings
    )

    try:
        # Each uploaded image becomes its own user turn. get_image_base64()
        # always re-encodes to JPEG (and converts to RGB itself), so the part
        # is labelled image/jpeg rather than with the upload's original type.
        for uploaded_file in uploaded_files or []:
            image_base64 = get_image_base64(Image.open(uploaded_file))
            history.append(
                {"role": "user", "parts": [{"mime_type": "image/jpeg", "data": image_base64}]}
            )

        # Generate the response
        response = model.generate_content(history)
    except Exception:
        # Keep the history consistent: drop the unanswered turns, then let the
        # error surface to Streamlit unchanged.
        del history[turns_before:]
        raise

    # The ``text`` accessor raises ValueError when the reply carries no text
    # part (for example a blocked response); hasattr() would not catch that.
    try:
        response_text = response.text
    except (AttributeError, ValueError):
        response_text = "No response text found."
    history.append({"role": "model", "parts": [{"text": response_text}]})

    # Clear the user input and generate a new key for the file uploader widget to reset it
    st.session_state.user_input = ''
    st.session_state.file_uploader_key = str(uuid.uuid4())

# Multi-line prompt box; its content is stored in session state as "user_input"
user_input = st.text_area("Enter your message here:", value="", key="user_input")

# Image uploader. Its widget key comes from session state, so assigning a new
# key there makes Streamlit create a fresh, empty uploader on the next run.
uploaded_files = st.file_uploader(
    "Upload images:",
    key=st.session_state.file_uploader_key,
    accept_multiple_files=True,
    type=["png", "jpg", "jpeg"],
)

# "Send" triggers send_message() as a callback
send_button = st.button("Send", on_click=send_message)

# Render the conversation: one line per history entry, based on its first part
for message in st.session_state['chat_history']:
    speaker = message["role"]
    first_part = message["parts"][0]
    if 'text' in first_part:
        st.markdown(f"**{speaker.title()}**: {first_part['text']}")
    elif 'data' in first_part:
        st.markdown(f"**{speaker.title()}**: (Image)")

# Stash the current uploads in session state, where send_message() reads them
st.session_state.uploaded_files = uploaded_files