File size: 7,808 Bytes
221a628
b557897
 
 
eabca2c
dfb90bd
b557897
2c4cf73
ab73386
23ed2c1
dfb90bd
2c4cf73
6e074fc
221a628
dfb90bd
dfdbfa8
 
ab73386
dfdbfa8
221a628
dfb90bd
 
ccf48ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
837873a
dfb90bd
ce73371
 
eabca2c
 
b557897
 
221a628
2c4cf73
 
 
ab73386
 
 
dfb90bd
00bfc2f
665ff61
00bfc2f
 
 
 
dfdbfa8
40e8df5
ce73371
40e8df5
 
dfb90bd
 
 
 
 
 
 
ab73386
 
 
 
 
 
 
 
 
 
 
dfb90bd
 
 
 
ab73386
dfb90bd
 
 
 
 
 
94fea6c
eabca2c
 
94fea6c
36e811b
dfb90bd
94fea6c
 
 
 
 
 
ab73386
 
94fea6c
 
 
 
 
36e811b
94fea6c
 
 
 
ab73386
 
 
 
 
 
 
94fea6c
 
ab73386
94fea6c
 
ab73386
36e811b
 
ab73386
2c4cf73
 
 
 
ab73386
ce73371
 
 
 
 
fb00ecf
2c4cf73
36e811b
 
 
 
 
 
 
 
ce73371
94fea6c
 
326cdbe
ab73386
e742fb1
ab73386
 
 
 
 
 
 
dfb90bd
 
5837eff
9c3f46e
ce73371
326cdbe
dfb90bd
 
665ff61
dfb90bd
 
 
94fea6c
dfb90bd
 
389cdce
dfb90bd
389cdce
ab73386
 
ce73371
eabca2c
389cdce
 
dfb90bd
 
 
94fea6c
dfb90bd
a76b0fb
dfb90bd
ce73371
 
dfb90bd
665ff61
 
 
 
dfb90bd
665ff61
 
 
 
 
 
 
ce73371
dfb90bd
 
 
e742fb1
 
 
 
 
 
 
dfb90bd
 
 
e742fb1
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
import base64
import io
import os
import uuid
from io import BytesIO

import google.generativeai as genai
import PyPDF2
import streamlit as st
from gtts import gTTS
from PIL import Image

# Read the Gemini API key from the environment rather than hard-coding it:
# a key committed to source is exposed to anyone who can read the file.
api_key = os.environ.get("GOOGLE_API_KEY")
if not api_key:
    st.error("Set the GOOGLE_API_KEY environment variable to your Gemini API key.")
    st.stop()
genai.configure(api_key=api_key)

# Configure the generative AI model (sampling temperature and reply length cap).
generation_config = genai.GenerationConfig(
    temperature=0.9,
    max_output_tokens=4000
)

# Safety settings configuration: every harm category listed below is given
# the "BLOCK_NONE" threshold.
_UNBLOCKED_HARM_CATEGORIES = (
    "HARM_CATEGORY_DANGEROUS_CONTENT",
    "HARM_CATEGORY_SEXUALLY_EXPLICIT",
    "HARM_CATEGORY_HATE_SPEECH",
    "HARM_CATEGORY_HARASSMENT",
)
safety_settings = [
    {"category": harm_category, "threshold": "BLOCK_NONE"}
    for harm_category in _UNBLOCKED_HARM_CATEGORIES
]

# Initialize session state on the first run only; later reruns keep the
# existing values. Each default is produced by a zero-argument factory.
_SESSION_DEFAULTS = {
    'chat_history': list,
    'file_uploader_key': lambda: str(uuid.uuid4()),
}
for _state_name, _make_default in _SESSION_DEFAULTS.items():
    if _state_name not in st.session_state:
        st.session_state[_state_name] = _make_default()

st.title("Gemini Chatbot")

# Model selection dropdown (its value is read by send_message()).
selected_model = st.selectbox(
    "Select a Gemini 1.5 model:",
    ["gemini-1.5-flash-latest", "gemini-1.5-pro-latest"],
)

# Text-to-speech toggle for model replies.
enable_tts = st.checkbox("Enable Text-to-Speech")

# Helper functions for image processing and chat history management
def get_image_base64(image):
    """Encode an image as a base64 string of its JPEG bytes.

    The image is converted to RGB first because JPEG cannot store an alpha
    channel.

    Args:
        image: A PIL ``Image`` (anything with ``convert`` and ``save``).

    Returns:
        The ASCII base64 text of the JPEG-encoded image.
    """
    rgb_image = image.convert("RGB")
    with io.BytesIO() as jpeg_buffer:
        rgb_image.save(jpeg_buffer, format="JPEG")
        jpeg_bytes = jpeg_buffer.getvalue()
    return base64.b64encode(jpeg_bytes).decode()

def clear_conversation():
    """Reset the chat: drop every stored message and rotate the uploader key.

    The file uploader widget is keyed by ``file_uploader_key``, so giving it a
    new value makes the uploader be recreated empty on the next run.
    """
    state = st.session_state
    state['file_uploader_key'] = str(uuid.uuid4())
    state['chat_history'] = []

def display_chat_history():
    """Render every part of every entry in ``st.session_state['chat_history']``.

    Text parts are shown as ``"Role: text"`` markdown. Inline file parts
    (base64 ``data`` with a ``mime_type``) are rendered by type: images via
    ``st.image``, PDFs as the extracted text of each page, and videos via
    ``st.video``. File parts with any other MIME type are skipped.
    """
    for entry in st.session_state['chat_history']:
        role_label = entry["role"].title()
        # An entry may carry several parts; render all of them, not just the first.
        for part in entry["parts"]:
            if 'text' in part:
                st.markdown(f"{role_label}: {part['text']}")
                continue
            if 'data' not in part:
                continue

            mime_type = part.get('mime_type', '')
            if mime_type.startswith('image'):
                image_bytes = io.BytesIO(base64.b64decode(part['data']))
                st.image(Image.open(image_bytes), caption='Uploaded Image')
            elif mime_type == 'application/pdf':
                st.write("PDF Content:")
                pdf_reader = PyPDF2.PdfReader(io.BytesIO(base64.b64decode(part['data'])))
                for page in pdf_reader.pages:
                    st.write(page.extract_text())
            elif mime_type.startswith('video'):
                st.video(io.BytesIO(base64.b64decode(part['data'])))

def get_chat_history_str(history=None):
    """Return the conversation as plain text, one line per message part.

    Text parts become ``"Role: text"`` and file parts become
    ``"Role: (File: <mime_type>)"`` (empty MIME type if none is recorded).

    Args:
        history: List of ``{"role": ..., "parts": [...]}`` entries. Defaults to
            ``st.session_state['chat_history']``. Passing it explicitly lets the
            transcript be built without a running Streamlit session.

    Returns:
        The lines joined with ``"\\n"``; an empty string for an empty history.
    """
    if history is None:
        history = st.session_state['chat_history']

    lines = []
    for entry in history:
        role_label = entry['role'].title()
        for part in entry['parts']:
            if 'text' in part:
                lines.append(f"{role_label}: {part['text']}")
            else:
                lines.append(f"{role_label}: (File: {part.get('mime_type', '')})")
    return "\n".join(lines)

def _merge_turns(entries):
    """Return ``entries`` as Gemini contents with consecutive same-role turns merged.

    The chat history stores the typed text and each uploaded file as separate
    ``"user"`` entries, so several user entries can follow one another.
    Merging them gives one user turn per exchange, so the request alternates
    between user and model turns. New dicts and ``parts`` lists are built;
    the input entries are not mutated.
    """
    contents = []
    for entry in entries:
        if contents and contents[-1]["role"] == entry["role"]:
            contents[-1]["parts"].extend(entry["parts"])
        else:
            contents.append({"role": entry["role"], "parts": list(entry["parts"])})
    return contents


# Send message function with TTS integration
def send_message():
    """Send the pending message and uploaded files to the selected Gemini model.

    Used as the Send button's ``on_click`` callback. It reads the text area
    (``st.session_state.user_input``) and the uploaded files
    (``st.session_state.uploaded_files``). The whole conversation is sent as
    role-tagged, multi-turn contents, including inline image/PDF/video parts.
    On success, the new user entries and the model reply are appended to
    ``st.session_state['chat_history']``. The reply is optionally played via
    gTTS. Finally the inputs are reset and the history is rendered.

    Nothing is sent when there is neither text nor a file. If the request
    fails, an error is shown, and the history and inputs are left untouched
    so the user can retry.
    """
    user_input = st.session_state.user_input
    uploaded_files = st.session_state.uploaded_files

    # New user entries: typed text first, then one entry per uploaded file.
    new_entries = []
    if user_input:
        new_entries.append({"role": "user", "parts": [{"text": user_input}]})
    for uploaded_file in uploaded_files or []:
        # getvalue() returns the whole payload regardless of the read position.
        base64_data = base64.b64encode(uploaded_file.getvalue()).decode()
        new_entries.append({
            "role": "user",
            "parts": [{"mime_type": uploaded_file.type, "data": base64_data}]
        })

    if not new_entries:
        # Empty submit: don't call the API with an empty prompt.
        return

    # Every uploaded type (images, PDF, MP4) is forwarded as an inline part;
    # both Gemini 1.5 models offered in the dropdown accept multimodal input.
    model = genai.GenerativeModel(
        model_name=selected_model,
        generation_config=generation_config,
        safety_settings=safety_settings
    )
    contents = _merge_turns(st.session_state['chat_history'] + new_entries)

    try:
        response = model.generate_content(contents)
    except Exception as exc:  # UI boundary: report API/network errors instead of crashing
        st.error(f"Gemini request failed: {exc}")
        return

    try:
        response_text = response.text
    except ValueError:
        # `.text` raises ValueError (which hasattr() would not swallow) when the
        # candidate has no text part, e.g. a blocked response.
        response_text = "No response text found."

    # Commit the user turn only now that the request has succeeded.
    st.session_state['chat_history'].extend(new_entries)
    if response_text:
        st.session_state['chat_history'].append({"role": "model", "parts": [{"text": response_text}]})

        # Convert the response text to speech if enabled
        if enable_tts:
            tts = gTTS(text=response_text, lang='en')
            tts_file = BytesIO()
            tts.write_to_fp(tts_file)
            tts_file.seek(0)
            st.audio(tts_file, format='audio/mp3')

    # Clear the input fields after sending the message
    st.session_state.user_input = ''
    st.session_state.uploaded_files = []
    st.session_state.file_uploader_key = str(uuid.uuid4())

    # Display the updated chat history
    display_chat_history()

# File extensions the uploader accepts: images, MP4 video and PDF documents.
_ACCEPTED_UPLOAD_TYPES = ["png", "jpg", "jpeg", "mp4", "pdf"]

# Message text area; its value lives in st.session_state.user_input.
user_input = st.text_area("Enter your message here:", value="", key="user_input")

# Multi-file uploader keyed by file_uploader_key, so rotating that key
# recreates the widget without files.
uploaded_files = st.file_uploader(
    "Upload files:",
    type=_ACCEPTED_UPLOAD_TYPES,
    accept_multiple_files=True,
    key=st.session_state.file_uploader_key,
)

# Action buttons: send the pending message, or wipe the conversation.
send_button = st.button("Send", on_click=send_message)
clear_button = st.button("Clear Conversation", on_click=clear_conversation)

# Function to download the chat history as a text file
def download_chat_history():
    """Return the plain-text transcript served by the "Download Chat" button."""
    return get_chat_history_str()

# Download button offering the current transcript as chat_history.txt.
download_button = st.download_button(
    label="Download Chat",
    data=download_chat_history(),
    file_name="chat_history.txt",
    mime="text/plain",
)

# Remember this run's uploaded files in session state. The Send callback runs
# before the next script rerun and reads them from there.
st.session_state.uploaded_files = uploaded_files

# Intended to make Ctrl+Enter in the text area click the first `.stButton`
# button (Send is the first st.button created).
# NOTE(review): Streamlit may not execute <script> tags rendered through
# st.markdown, so this handler might never run; confirm in the browser.
_CTRL_ENTER_SCRIPT = """
    <script>
    document.addEventListener('DOMContentLoaded', (event) => {
        document.querySelector('.stTextArea textarea').addEventListener('keydown', function(e) {
            if (e.key === 'Enter' && e.ctrlKey) {
                document.querySelector('.stButton > button').click();
                e.preventDefault();
            }
        });
    });
    </script>
    """
st.markdown(_CTRL_ENTER_SCRIPT, unsafe_allow_html=True)