File size: 7,744 Bytes
221a628
b557897
 
 
eabca2c
dfb90bd
b557897
2c4cf73
ab73386
7fd1c6d
 
 
23ed2c1
dfb90bd
2c4cf73
6e074fc
221a628
dfb90bd
dfdbfa8
 
d5a9ec5
dfdbfa8
221a628
dfb90bd
 
ccf48ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
837873a
dfb90bd
ce73371
 
eabca2c
 
7fd1c6d
 
 
 
b557897
d5a9ec5
b557897
d5a9ec5
221a628
2c4cf73
d5a9ec5
2c4cf73
ab73386
 
 
d5a9ec5
db9bdff
 
 
dfdbfa8
40e8df5
ce73371
40e8df5
7fd1c6d
40e8df5
dfb90bd
db9bdff
 
50b3dd7
 
db9bdff
50b3dd7
 
 
 
 
7fd1c6d
50b3dd7
 
 
 
 
 
 
db9bdff
 
50b3dd7
 
dfb90bd
7fd1c6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d5a9ec5
94fea6c
eabca2c
 
7fd1c6d
36e811b
dfb90bd
d5a9ec5
94fea6c
36e811b
d5a9ec5
94fea6c
 
 
 
ab73386
db9bdff
 
 
 
ab73386
7fd1c6d
 
 
 
 
 
 
 
 
d5a9ec5
db9bdff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7fd1c6d
dfb90bd
5837eff
9c3f46e
ce73371
50b3dd7
 
dfb90bd
665ff61
d5a9ec5
 
389cdce
d5a9ec5
 
 
 
 
 
 
 
 
 
50b3dd7
d5a9ec5
 
 
389cdce
db9bdff
 
ce73371
eabca2c
389cdce
 
7fd1c6d
 
 
 
 
 
 
 
 
 
d5a9ec5
 
665ff61
d5a9ec5
dfb90bd
 
d5a9ec5
e742fb1
 
 
 
 
 
 
dfb90bd
 
 
e742fb1
 
 
 
d5a9ec5
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
import base64
import io
import os
import uuid
from io import BytesIO

import google.generativeai as genai
import librosa
import numpy as np
import PyPDF2
import soundfile as sf
import streamlit as st
from gtts import gTTS
from PIL import Image

# --- Gemini API configuration ---
# SECURITY: a real API key was hard-coded in this file. Prefer the
# GOOGLE_API_KEY environment variable; the literal value is kept only as a
# backward-compatible fallback for local runs and should be rotated and
# removed before any deployment.
api_key = os.environ.get("GOOGLE_API_KEY", "AIzaSyAHD0FwX-Ds6Y3eI-i5Oz7IdbJqR6rN7pg")
genai.configure(api_key=api_key)

# Generation parameters shared by every request: fairly creative sampling
# (temperature 0.9) capped at 3000 output tokens.
generation_config = genai.GenerationConfig(
    temperature=0.9,
    max_output_tokens=3000
)

# Safety settings: disable all four content filters so responses are never
# blocked by category. Use with care outside of experimentation.
_UNFILTERED_CATEGORIES = (
    "HARM_CATEGORY_DANGEROUS_CONTENT",
    "HARM_CATEGORY_SEXUALLY_EXPLICIT",
    "HARM_CATEGORY_HATE_SPEECH",
    "HARM_CATEGORY_HARASSMENT",
)
safety_settings = [
    {"category": category, "threshold": "BLOCK_NONE"}
    for category in _UNFILTERED_CATEGORIES
]

# Seed session state on first run. Factories (not values) are stored so that
# uuid.uuid4() is only evaluated when the key is actually missing.
_SESSION_DEFAULTS = {
    'chat_history': list,
    'file_uploader_key': lambda: str(uuid.uuid4()),
    'recording_enabled': lambda: False,
    'recorded_audio': lambda: None,
}
for _key, _factory in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _factory()

# --- Streamlit UI ---
st.title("Gemini Chatbot")
st.write("Interact with the powerful Gemini 1.5 models.")

# Model selection dropdown; read as a module-level global by send_message().
selected_model = st.selectbox("Choose a Gemini 1.5 Model:", ["gemini-1.5-flash-latest", "gemini-1.5-pro-latest"])

# When checked, send_message() also speaks the model's reply via gTTS.
enable_tts = st.checkbox("Enable Text-to-Speech")

# --- Helper Functions ---
def get_file_base64(file_content, mime_type):
    """Wrap raw bytes as an inline-data part for the Gemini API.

    Args:
        file_content: raw file bytes.
        mime_type: the file's MIME type string.

    Returns:
        dict with 'mime_type' and base64-encoded 'data' keys.
    """
    return {
        "mime_type": mime_type,
        "data": base64.b64encode(file_content).decode(),
    }

def clear_conversation():
    """Reset the chat: wipe the history, drop any recorded audio, and rotate
    the uploader key so the file_uploader widget re-renders empty."""
    st.session_state['recorded_audio'] = None
    st.session_state['chat_history'] = []
    st.session_state['file_uploader_key'] = str(uuid.uuid4())

def display_chat_history():
    """Render every chat-history entry into a fresh placeholder container.

    Each entry holds a single part: either {'text': ...} for plain messages,
    or {'mime_type': ..., 'data': base64} for attachments, dispatched to the
    matching Streamlit media widget.
    """
    placeholder = st.empty()
    with placeholder.container():
        for entry in st.session_state['chat_history']:
            role = entry["role"]
            part = entry["parts"][0]
            if 'text' in part:
                st.markdown(f"**{role.title()}:** {part['text']}")
            elif 'data' in part:
                raw = base64.b64decode(part['data'])
                mime = part.get('mime_type', '')
                if mime.startswith('image'):
                    st.image(Image.open(io.BytesIO(raw)),
                             caption='Uploaded Image', use_column_width=True)
                elif mime == 'application/pdf':
                    # Re-extract and show the PDF's text on every rerun.
                    st.write("**PDF Content:**")
                    reader = PyPDF2.PdfReader(io.BytesIO(raw))
                    for page in reader.pages:
                        st.write(page.extract_text())
                elif mime.startswith('audio'):
                    st.audio(io.BytesIO(raw), format=mime)
                elif mime.startswith('video'):
                    st.video(io.BytesIO(raw))

# --- Audio Recording Functions ---
def start_recording():
    """Mark the session as recording and surface a visible status banner."""
    st.session_state.recording_enabled = True
    st.warning("Recording started. Click 'Stop Recording' to finish.")

def stop_recording():
    """Clear the recording flag and confirm to the user."""
    st.session_state.recording_enabled = False
    st.success("Recording stopped.")

def process_audio(audio_data):
    """Re-encode recorded audio as WAV bytes for sending to the model.

    Args:
        audio_data: path or file-like object that librosa can decode.

    Returns:
        tuple (wav_bytes, mime_type) where mime_type is always "audio/wav".
    """
    # Decode at the native sample rate (sr=None disables resampling).
    samples, samplerate = librosa.load(audio_data, sr=None)
    # Encode into an in-memory buffer. The previous implementation wrote to a
    # shared "temp.wav" in the working directory, which leaked a file on disk
    # and raced between concurrent sessions.
    buffer = io.BytesIO()
    sf.write(buffer, samples, samplerate, format="WAV")
    return buffer.getvalue(), "audio/wav"

# --- Send Message Function ---
def send_message():
    """Button callback: assemble a prompt from the text box, uploaded files,
    and any recorded audio; send it to the selected Gemini model; and append
    both sides of the exchange to the session chat history.

    Reads module-level globals: selected_model, generation_config,
    safety_settings, enable_tts. Mutates st.session_state widget values
    (allowed inside an on_click callback, before the rerun).
    """
    user_input = st.session_state.user_input
    uploaded_files = st.session_state.uploaded_files
    recorded_audio = st.session_state.recorded_audio
    prompt_parts = []

    # Typed message -> prompt and visible history.
    if user_input:
        prompt_parts.append({"text": user_input})
        st.session_state['chat_history'].append({"role": "user", "parts": [{"text": user_input}]})

    # Each uploaded file is attached as inline base64 data. Encode once and
    # reuse the part for both the prompt and the history (the original
    # base64-encoded every file twice).
    if uploaded_files:
        for uploaded_file in uploaded_files:
            file_part = get_file_base64(uploaded_file.read(), uploaded_file.type)
            prompt_parts.append(file_part)
            st.session_state['chat_history'].append({"role": "user", "parts": [file_part]})

    # Recorded audio is re-encoded to WAV, attached, then cleared so the same
    # clip is not re-sent with the next message.
    if recorded_audio:
        audio_content, audio_type = process_audio(recorded_audio)
        audio_part = get_file_base64(audio_content, audio_type)
        prompt_parts.append(audio_part)
        st.session_state['chat_history'].append({"role": "user", "parts": [audio_part]})
        st.session_state['recorded_audio'] = None

    # Guard: generate_content() with an empty parts list raises an API error,
    # so bail out early when the user submitted nothing at all.
    if not prompt_parts:
        st.warning("Please enter a message or attach a file before sending.")
        return

    # Generate a response with the UI-selected model.
    try:
        model = genai.GenerativeModel(
            model_name=selected_model,
            generation_config=generation_config,
            safety_settings=safety_settings
        )

        response = model.generate_content([{"role": "user", "parts": prompt_parts}])
        # NOTE(review): accessing .text raises when the response was blocked;
        # hasattr() only swallows AttributeError, so a blocked response falls
        # through to the generic error handler below.
        response_text = response.text if hasattr(response, "text") else "No response text found."

        if response_text:
            st.session_state['chat_history'].append({"role": "model", "parts": [{"text": response_text}]})
            if enable_tts:
                # Speak the reply: render gTTS output into memory and play it.
                tts = gTTS(text=response_text, lang='en')
                tts_file = BytesIO()
                tts.write_to_fp(tts_file)
                tts_file.seek(0)
                st.audio(tts_file, format='audio/mp3')

    except Exception as e:
        st.error(f"An error occurred: {e}")

    # Reset the input widgets; rotating the uploader key forces Streamlit to
    # build a fresh (empty) file_uploader on the next rerun.
    st.session_state.user_input = ''
    st.session_state.uploaded_files = []
    st.session_state.file_uploader_key = str(uuid.uuid4())

    # Re-render the conversation including this exchange.
    display_chat_history()

# --- User Input Area ---
# Message box (3/4 width) next to the Send button (1/4 width).
col1, col2 = st.columns([3, 1])

with col1:
    # key="user_input" makes the text available to send_message() via
    # st.session_state.user_input.
    user_input = st.text_area(
        "Enter your message:",
        value="",
        key="user_input"
    )
with col2:
    # send_message runs as a callback before the rerun triggered by the click.
    send_button = st.button(
        "Send",
        on_click=send_message,
        type="primary"
    )

# --- File Uploader ---
# The key is a session-stored UUID; send_message() and clear_conversation()
# rotate it to force a fresh, empty uploader widget on the next rerun.
uploaded_files = st.file_uploader(
    "Upload Files (Images, Videos, PDFs, MP3):",
    type=["png", "jpg", "jpeg", "mp4", "pdf", "mp3"],
    accept_multiple_files=True,
    key=st.session_state.file_uploader_key
)

# --- Audio Recording ---
# FIX: Streamlit has no `st.audio_recorder` widget (the original call would
# raise AttributeError at import time). The built-in microphone widget is
# `st.audio_input` (Streamlit >= 1.39); its value is stored under
# st.session_state["recorded_audio"] via the key, which send_message() reads.
st.audio_input("Record audio:", key="recorded_audio")
col3, col4 = st.columns([1, 1])
with col3:
    # These buttons only toggle the session 'recording_enabled' flag and show
    # a status banner; the actual capture is handled by the widget above.
    if st.button("Start Recording"):
        start_recording()
with col4:
    if st.button("Stop Recording"):
        stop_recording()

# --- Other Buttons ---
st.button("Clear Conversation", on_click=clear_conversation)

# --- Ensure file_uploader state ---
# Mirror the uploader's current value under a stable key so send_message()
# can read it regardless of the rotating widget key.
st.session_state.uploaded_files = uploaded_files

# --- JavaScript for Ctrl+Enter ---
# NOTE(review): Streamlit sanitizes HTML rendered via st.markdown, so
# injected <script> tags are generally not executed — this Ctrl+Enter
# shortcut likely has no effect. Verify, or use a components.html iframe.
st.markdown(
    """
    <script>
    document.addEventListener('DOMContentLoaded', (event) => {
        document.querySelector('.stTextArea textarea').addEventListener('keydown', function(e) {
            if (e.key === 'Enter' && e.ctrlKey) {
                document.querySelector('.stButton > button').click();
                e.preventDefault();
            }
        });
    });
    </script>
    """,
    unsafe_allow_html=True
)

# --- Display Chat History ---
# Render the conversation on every script rerun.
display_chat_history()