import streamlit as st
from PIL import Image
import io
import base64
import uuid
import os

# Google Generative AI SDK
import google.generativeai as genai

# Configure the API key from the environment; never hard-code secrets in source
api_key = os.environ.get("GOOGLE_API_KEY")
genai.configure(api_key=api_key)

generation_config = genai.GenerationConfig(
    temperature=0.9,
    max_output_tokens=3000
)

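# An empty list applies no overrides, so the API's default safety thresholds are used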
safety_settings = []

# Initialize session state for chat history and file uploader key
if 'chat_history' not in st.session_state:
    st.session_state['chat_history'] = []
if 'file_uploader_key' not in st.session_state:
    st.session_state['file_uploader_key'] = str(uuid.uuid4())

# UI layout
st.title("Gemini Chatbot")

# Function to convert image to base64
def get_image_base64(image):
    image = image.convert("RGB")  # Convert to RGB to remove alpha channel if present
    buffered = io.BytesIO()
    image.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return img_str

# Function to send message and clear input
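# (on_click callbacks run at the start of the rerun, so the widget values read from
# st.session_state here reflect what the user had entered when "Send" was clicked)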
def send_message():
    user_input = st.session_state.user_input
    uploaded_files = st.session_state.uploaded_files

    text_prompts = []
    image_history_entries = []
    pil_images = []

    # Process text input for the multi-turn conversation
    if user_input:
        text_prompts.append({"role": "user", "parts": [{"text": user_input}]})
        st.session_state['chat_history'].append({"role": "user", "parts": [{"text": user_input}]})

    # Process uploaded images for a single-turn vision request
    if uploaded_files:
        for uploaded_file in uploaded_files:
            image = Image.open(uploaded_file).convert("RGB")  # Ensure image is in RGB
            pil_images.append(image)
            # Keep a base64 copy in the history; get_image_base64 re-encodes as JPEG
            image_base64 = get_image_base64(image)
            image_history_entries.append({
                "role": "user",
                "parts": [{"inline_data": {"mime_type": "image/jpeg", "data": image_base64}}]
            })

    # Generate a text response if text input is provided
    if text_prompts:
        model = genai.GenerativeModel(
            model_name='gemini-pro',
            generation_config=generation_config,
            safety_settings=safety_settings
        )
        # gemini-pro accepts text only, so send just the text turns from the history
        text_history = [entry for entry in st.session_state['chat_history']
                        if 'text' in entry["parts"][0]]
        response = model.generate_content(text_history)
        response_text = response.text if hasattr(response, "text") else "No response text found."
        st.session_state['chat_history'].append({"role": "model", "parts": [{"text": response_text}]})

    # Generate a vision response if images are uploaded
    if pil_images:
        model = genai.GenerativeModel(
            model_name='gemini-pro-vision',
            generation_config=generation_config,
            safety_settings=safety_settings
        )
        # gemini-pro-vision is single-turn: pass the PIL images (plus any accompanying text) directly
        vision_prompt = ([user_input] if user_input else []) + pil_images
        response = model.generate_content(vision_prompt)
        response_text = response.text if hasattr(response, "text") else "No response text found."
        for entry in image_history_entries:
            st.session_state['chat_history'].append(entry)  # Record the images in the history
        st.session_state['chat_history'].append({"role": "model", "parts": [{"text": response_text}]})

    # Clear the user input and generate a new key for the file uploader widget to reset it
    st.session_state.user_input = ''
    st.session_state.uploaded_files = []
    st.session_state.file_uploader_key = str(uuid.uuid4())

# Function to clear conversation
def clear_conversation():
    st.session_state['chat_history'] = []
    st.session_state['file_uploader_key'] = str(uuid.uuid4())

# Multiline text input for the user to send messages
user_input = st.text_area("Enter your message here:", key="user_input")

# File uploader for images
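# (the key is replaced with a fresh uuid after each send/clear, which forces Streamlit
# to recreate the widget and drop previously selected files)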
uploaded_files = st.file_uploader(
    "Upload images:",
    type=["png", "jpg", "jpeg"],
    accept_multiple_files=True,
    key=st.session_state.file_uploader_key
)

# Button to send the message
send_button = st.button("Send", on_click=send_message)

# Button to clear the conversation
clear_button = st.button("Clear Conversation", on_click=clear_conversation)

# Display the chat history
for entry in st.session_state['chat_history']:
    role = entry["role"]
    parts = entry["parts"][0]
    if 'text' in parts:
        st.markdown(f"{role.title()}: {parts['text']}")
    elif 'inline_data' in parts:
        st.markdown(f"{role.title()}: (Image)")

# Store the uploader's current files in session state so the send_message callback can read them
st.session_state.uploaded_files = uploaded_files