File size: 4,177 Bytes
221a628
b557897
 
 
eabca2c
221a628
40e8df5
b557897
221a628
40e8df5
e2815a3
6e074fc
221a628
dfdbfa8
 
 
 
221a628
6e074fc
837873a
9c3f46e
ce73371
 
eabca2c
 
ce73371
 
b557897
 
 
221a628
00bfc2f
 
389cdce
00bfc2f
 
 
 
dfdbfa8
40e8df5
 
ce73371
40e8df5
ce73371
40e8df5
fd4809b
ba4c612
eabca2c
 
6ae7b4c
ce73371
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5837eff
9c3f46e
ce73371
ba4c612
40e8df5
8aae6cc
389cdce
 
 
 
 
ce73371
eabca2c
389cdce
 
 
 
a76b0fb
6ae7b4c
ce73371
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import base64
import io
import os
import uuid

import streamlit as st
from PIL import Image

# Assuming google.generativeai is the correct import based on your description
import google.generativeai as genai

# The API key is read from the GOOGLE_API_KEY environment variable instead of
# being hard-coded: a key committed to source is exposed to anyone who can read
# the code. The key that used to be hard-coded here should be treated as leaked
# and revoked.
api_key = os.environ.get("GOOGLE_API_KEY")
if not api_key:
    # Halt the script run with a visible message rather than failing later
    # inside an API call with an authentication error.
    st.error("GOOGLE_API_KEY is not set; export it before starting the app.")
    st.stop()
genai.configure(api_key=api_key)

# Sampling settings shared by every request (text and vision models alike).
generation_config = genai.GenerationConfig(
    temperature=0.9,
    max_output_tokens=3000
)

# Empty list: no custom safety settings are passed to the model.
safety_settings = []

# On the first run of a browser session, seed the state the app relies on:
#   chat_history       - list of {"role", "parts"} entries shown below
#   file_uploader_key  - key of the current file-uploader widget; replacing it
#                        yields a fresh, empty uploader
#   last_model_used    - 'text' or 'vision', whichever model answered last
# Existing values are left alone so later reruns keep the conversation.
for _state_key, _make_default in (
    ('chat_history', list),
    ('file_uploader_key', lambda: str(uuid.uuid4())),
    ('last_model_used', lambda: 'text'),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()

# Page header.
st.title("Gemini Chatbot")

def get_image_base64(image):
    """Encode an image as a base64 JPEG string.

    The image is converted to RGB first because JPEG cannot store an alpha
    channel.

    Args:
        image: a PIL image in any mode.

    Returns:
        The JPEG-encoded bytes, base64-encoded and decoded to ``str``.
    """
    with io.BytesIO() as jpeg_buffer:
        image.convert("RGB").save(jpeg_buffer, format="JPEG")
        encoded = base64.b64encode(jpeg_buffer.getvalue())
    return encoded.decode()

def clear_conversation():
    """Reset the chat to its initial state.

    Used as the "Clear Conversation" button callback. It empties the history,
    swaps in a new file-uploader key (so the uploader is rebuilt empty) and
    marks the text model as the last one used.
    """
    st.session_state.update(
        chat_history=[],
        file_uploader_key=str(uuid.uuid4()),
        last_model_used='text',
    )

def _to_request_contents(entries, keep_images):
    """Build the ``contents`` list for ``generate_content`` from history-style entries.

    Chat history stores one entry per text prompt, per image and per model
    reply. A single user turn with text plus images therefore spans several
    consecutive ``"user"`` entries. This folds consecutive entries that share a
    role into one content, concatenating their parts in order. The stored
    history is never mutated: new dicts and part lists are built.

    Args:
        entries: dicts shaped ``{"role": str, "parts": [dict, ...]}``.
        keep_images: when False, parts carrying inline image ``"data"`` are
            dropped (used for the text-only model).

    Returns:
        A new list of ``{"role", "parts"}`` dicts that starts with a user turn
        (empty if there is none).
    """
    contents = []
    for entry in entries:
        parts = [part for part in entry["parts"] if keep_images or "data" not in part]
        if not parts:
            continue
        if contents and contents[-1]["role"] == entry["role"]:
            contents[-1]["parts"].extend(parts)
        else:
            contents.append({"role": entry["role"], "parts": parts})
    # Dropping image parts can leave a model reply at the front with no prompt
    # before it; discard such orphaned replies so the request opens with a user turn.
    while contents and contents[0]["role"] != "user":
        contents.pop(0)
    return contents


def _response_text(response):
    """Return the text of a ``generate_content`` response, or a fallback message.

    ``response.text`` raises ``ValueError`` (not ``AttributeError``) when the
    response holds no usable text part, for example when it was blocked. A
    ``hasattr`` check only swallows ``AttributeError``, so it cannot guard
    against that case.
    """
    try:
        return response.text
    except (AttributeError, ValueError):
        return "No response text found."


def send_message():
    """Send the pending text and/or images to Gemini and record the exchange.

    Used as the Send button's ``on_click`` callback.

    Reads the message from ``st.session_state.user_input`` and the images from
    ``st.session_state.uploaded_files``. Uses ``gemini-pro-vision`` when images
    are attached and ``gemini-pro`` otherwise.

    On success, appends the prompt entries and the model reply to
    ``chat_history``, clears the text box, and swaps the uploader key so the
    uploader is rebuilt empty. Does nothing when there is neither text nor an
    image. If the API call fails, shows an error and leaves the inputs and the
    history untouched, so the user can retry.
    """
    user_input = st.session_state.get("user_input") or ""
    uploaded_files = st.session_state.get("uploaded_files") or []

    # Without this guard, an empty Send would fire a request that has no new
    # user content.
    if not user_input.strip() and not uploaded_files:
        return

    use_vision = bool(uploaded_files)
    model_name = 'gemini-pro-vision' if use_vision else 'gemini-pro'
    st.session_state['last_model_used'] = 'vision' if use_vision else 'text'

    # One history entry per text prompt and per image; the display loop renders
    # the first part of each entry, so entries stay single-part.
    prompts = []
    if user_input.strip():
        prompts.append({"role": "user", "parts": [{"text": user_input}]})
    for uploaded_file in uploaded_files:
        # get_image_base64 always re-encodes to JPEG, so the data is labelled
        # image/jpeg whatever the uploaded file's own type was (e.g. image/png).
        image_data = get_image_base64(Image.open(uploaded_file))
        prompts.append({"role": "user", "parts": [{"mime_type": "image/jpeg", "data": image_data}]})

    if use_vision:
        # NOTE(review): gemini-pro-vision does not accept multi-turn requests,
        # so only the current turn is sent to it.
        contents = _to_request_contents(prompts, keep_images=True)
    else:
        # gemini-pro is text-only: send the history with image parts removed.
        contents = _to_request_contents(
            st.session_state['chat_history'] + prompts, keep_images=False
        )

    model = genai.GenerativeModel(
        model_name=model_name,
        generation_config=generation_config,
        safety_settings=safety_settings
    )
    try:
        response = model.generate_content(contents)
    except Exception as err:  # UI boundary: report the failure instead of crashing the app
        st.error(f"Request to {model_name} failed: {err}")
        return
    response_text = _response_text(response)

    # Record the turn: the user's entries first, then the model's reply.
    st.session_state['chat_history'].extend(prompts)
    st.session_state['chat_history'].append({"role": "model", "parts": [{"text": response_text}]})

    # Clear the text box; a new uploader key makes Streamlit build a new, empty uploader.
    st.session_state.user_input = ''
    st.session_state.uploaded_files = []
    st.session_state.file_uploader_key = str(uuid.uuid4())

# Message box. Its value lives in session state under "user_input", which is
# where send_message reads it and where it is reset to '' after a send.
user_input = st.text_area("Enter your message here:", key="user_input")

# Image uploader (PNG/JPEG, several files at once). Its key comes from session
# state; whenever that key changes, Streamlit builds a new, empty uploader.
uploaded_files = st.file_uploader(
    "Upload images:",
    accept_multiple_files=True,
    type=["png", "jpg", "jpeg"],
    key=st.session_state['file_uploader_key'],
)

# Action buttons; the callbacks run at the start of the next rerun, before
# the widgets above are rebuilt.
send_button = st.button("Send", on_click=send_message)
clear_button = st.button("Clear Conversation", on_click=clear_conversation)

# Render the conversation, oldest entry first. Only the first part of each
# entry is shown: a text part as "<Role>: <text>", or an inline image part
# (base64 in "data") as a picture.
for message in st.session_state['chat_history']:
    first_part = message["parts"][0]
    if 'text' in first_part:
        speaker = message["role"].title()
        st.markdown(f"{speaker}: {first_part['text']}")
    elif 'data' in first_part:
        raw_image = base64.b64decode(first_part['data'])
        st.image(Image.open(io.BytesIO(raw_image)), caption='Uploaded Image')

# The uploader's key changes after every send/clear, so its current files are
# copied to a fixed session-state name ("uploaded_files") where send_message
# can find them.
st.session_state.uploaded_files = uploaded_files