File size: 5,226 Bytes
221a628
b557897
 
 
eabca2c
221a628
665ff61
b557897
221a628
e2815a3
6e074fc
221a628
dfdbfa8
 
 
 
221a628
ccf48ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
837873a
ce73371
 
eabca2c
 
326cdbe
 
b557897
 
221a628
00bfc2f
665ff61
00bfc2f
 
 
 
dfdbfa8
40e8df5
ce73371
40e8df5
326cdbe
40e8df5
ba4c612
eabca2c
 
326cdbe
 
fb00ecf
 
 
 
 
 
 
ce73371
fb00ecf
326cdbe
ce73371
 
326cdbe
fb00ecf
326cdbe
 
 
 
 
ce73371
 
 
 
 
fb00ecf
 
 
 
 
 
 
 
 
 
ce73371
326cdbe
 
5837eff
9c3f46e
ce73371
326cdbe
 
 
 
 
 
 
 
 
 
 
665ff61
 
 
 
 
 
 
 
 
8aae6cc
389cdce
 
 
 
ce73371
eabca2c
389cdce
 
 
a76b0fb
ce73371
 
665ff61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce73371
665ff61
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import base64
import io
import os
import uuid

# Assuming google.generativeai is correctly imported as genai and the API key is set
import google.generativeai as genai
import streamlit as st
from PIL import Image

# The Gemini API key is read from the GOOGLE_API_KEY environment variable so
# that no credential lives in source control. Any key that was ever committed
# to this file should be treated as leaked and rotated.
api_key = os.environ.get("GOOGLE_API_KEY")
if not api_key:
    # Fail early with a visible message instead of a confusing API error later.
    st.error("GOOGLE_API_KEY environment variable is not set.")
    st.stop()
genai.configure(api_key=api_key)

# Sampling parameters shared by every model instance created in this script:
# a temperature of 0.9 and a cap of 3000 output tokens per response.
generation_config = genai.GenerationConfig(temperature=0.9, max_output_tokens=3000)

# One entry per harm category, each with its threshold set to "BLOCK_NONE".
# Both the upper-case and lower-case names refer to the same list object.
SAFETY_SETTINGS = [
    {"category": harm_category, "threshold": "BLOCK_NONE"}
    for harm_category in (
        "HARM_CATEGORY_DANGEROUS_CONTENT",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_HARASSMENT",
    )
]
safety_settings = SAFETY_SETTINGS

# Seed per-session state only for keys that are not present yet, so values
# already stored in st.session_state are left untouched.
for _state_key, _make_default in (
    ('chat_history', list),
    ('file_uploader_key', lambda: str(uuid.uuid4())),
    ('use_vision_model', bool),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()

st.title("Gemini Chatbot")

def get_image_base64(image):
    """Return *image* as a base64 text string of its JPEG encoding.

    The image is converted to RGB mode before being saved as JPEG into an
    in-memory buffer; the buffer's bytes are then base64-encoded and decoded
    to ``str``.
    """
    jpeg_buffer = io.BytesIO()
    image.convert("RGB").save(jpeg_buffer, format="JPEG")
    return base64.b64encode(jpeg_buffer.getvalue()).decode()

def clear_conversation():
    """Reset the session to an empty conversation.

    Empties the chat history, switches back to the non-vision model flag, and
    assigns a fresh random key for the file uploader widget.
    """
    state = st.session_state
    state['use_vision_model'] = False
    state['file_uploader_key'] = str(uuid.uuid4())
    state['chat_history'] = []

def send_message():
    """Send the pending text and/or images to Gemini and record the reply.

    Used as the ``on_click`` callback of the "Send" button. Reads the text
    area value and the uploaded files from ``st.session_state``, appends them
    to ``chat_history``, sends the flattened conversation (plus every stored
    image when the vision model is active) as a single user turn, and appends
    the model's reply to ``chat_history``. Finally clears the input widgets.

    Does nothing when there is neither text nor an uploaded file.
    """
    state = st.session_state
    user_input = state.user_input
    # `uploaded_files` is only written at the end of a script run, so it may
    # be missing (or None) if an earlier run did not get that far.
    uploaded_files = state.get('uploaded_files') or []

    # An empty submission would just re-send the old history; ignore it.
    if not user_input and not uploaded_files:
        return

    # Once any image has been sent, the session stays on the vision model.
    if uploaded_files:
        state['use_vision_model'] = True

    # Flatten the earlier turns into text lines; images become placeholders.
    prompts = [
        part['text'] if 'text' in part else "[Image]"
        for entry in state['chat_history']
        for part in entry['parts']
        if 'text' in part or 'data' in part
    ]

    if user_input:
        prompts.append(user_input)
        state['chat_history'].append({"role": "user", "parts": [{"text": user_input}]})

    for uploaded_file in uploaded_files:
        base64_image = get_image_base64(Image.open(uploaded_file))
        prompts.append("[Image]")
        state['chat_history'].append({
            "role": "user",
            # get_image_base64 always re-encodes to JPEG, so the stored type
            # must be JPEG regardless of what the user uploaded.
            "parts": [{"mime_type": "image/jpeg", "data": base64_image}],
        })

    model_name = 'gemini-pro-vision' if state['use_vision_model'] else 'gemini-pro'
    model = genai.GenerativeModel(
        model_name=model_name,
        generation_config=generation_config,
        safety_settings=safety_settings
    )

    prompt_parts = [{"text": "\n".join(prompts)}]
    if state['use_vision_model']:
        # Attach every image stored in the history after the text part.
        prompt_parts.extend(
            {"data": part['data'], "mime_type": "image/jpeg"}
            for entry in state['chat_history']
            for part in entry['parts']
            if 'data' in part
        )

    response = model.generate_content([{"role": "user", "parts": prompt_parts}])
    try:
        # `.text` raises ValueError when the response carries no text part
        # (for example when it was blocked); hasattr() would not catch that.
        response_text = response.text
    except (AttributeError, ValueError):
        response_text = "No response text found."
    if response_text:
        state['chat_history'].append({"role": "model", "parts": [{"text": response_text}]})

    # Clear the inputs; a new uploader key makes the file widget start empty.
    state.user_input = ''
    state.uploaded_files = []
    state.file_uploader_key = str(uuid.uuid4())
    # The history is not rendered here: the main script body runs right after
    # this callback and draws it, so drawing here would show it twice.

def display_chat_history():
    """Render every message stored in ``st.session_state['chat_history']``.

    Each part of each entry is drawn in order: text parts as
    ``"<Role>: <text>"`` markdown, and parts carrying base64 ``data`` as an
    image captioned 'Uploaded Image'. Entries with no parts render nothing.
    """
    for entry in st.session_state['chat_history']:
        speaker = entry["role"].title()
        # Walk all parts rather than only the first, matching how the history
        # is traversed when building prompts and the downloadable transcript.
        for part in entry["parts"]:
            if 'text' in part:
                st.markdown(f"{speaker}: {part['text']}")
            elif 'data' in part:
                image_bytes = base64.b64decode(part['data'])
                st.image(Image.open(io.BytesIO(image_bytes)), caption='Uploaded Image')

def get_chat_history_str():
    """Return the conversation as newline-separated ``"<Role>: <content>"`` lines.

    One line is produced per part of every history entry. Parts without a
    ``text`` key are written as ``"<Role>: (Image)"``.
    """
    transcript_lines = []
    for entry in st.session_state['chat_history']:
        for part in entry['parts']:
            if 'text' in part:
                transcript_lines.append(f"{entry['role'].title()}: {part['text']}")
            else:
                transcript_lines.append(f"{entry['role'].title()}: (Image)")
    return "\n".join(transcript_lines)

# Message box; its value is stored in st.session_state under "user_input",
# which send_message reads and then resets to an empty string.
user_input = st.text_area("Enter your message here:", key="user_input")

# Image uploader. Its widget key comes from session state; send_message and
# clear_conversation assign a new random key there.
uploaded_files = st.file_uploader(
    "Upload images:",
    type=["png", "jpg", "jpeg"],
    accept_multiple_files=True,
    key=st.session_state.file_uploader_key
)

# Clicking "Send" invokes send_message as the button's on_click callback.
send_button = st.button("Send", on_click=send_message)

# Clicking "Clear Conversation" invokes clear_conversation as its callback.
clear_button = st.button("Clear Conversation", on_click=clear_conversation)

# Function to download the chat history
def download_chat_history():
    """Return the chat transcript text used as the download payload."""
    return get_chat_history_str()

# Add a button to download the chat history as a text file. The payload is
# built on every script run; image parts appear in it as "(Image)" lines.
download_button = st.download_button(
    label="Download Chat",
    data=download_chat_history(),
    file_name="chat_history.txt",
    mime="text/plain"
)

# Display the chat history
display_chat_history()

# Copy the uploader's current value into session state under a fixed name,
# which is where send_message looks for the files to send.
st.session_state.uploaded_files = uploaded_files