Update app.py
app.py
CHANGED
@@ -15,7 +15,7 @@ genai.configure(api_key=api_key)
 # Configure the generative AI model
 generation_config = genai.GenerationConfig(
     temperature=0.9,
-    max_output_tokens=
+    max_output_tokens=3000
 )

 # Safety settings configuration
@@ -44,15 +44,17 @@ if 'chat_history' not in st.session_state:
 if 'file_uploader_key' not in st.session_state:
     st.session_state['file_uploader_key'] = str(uuid.uuid4())

+# --- Streamlit UI ---
 st.title("Gemini Chatbot")
+st.write("Interact with the powerful Gemini 1.5 models.")

 # Model Selection Dropdown
-selected_model = st.selectbox("
+selected_model = st.selectbox("Choose a Gemini 1.5 Model:", ["gemini-1.5-flash-latest", "gemini-1.5-pro-latest"])

 # TTS Option Checkbox
 enable_tts = st.checkbox("Enable Text-to-Speech")

-# Helper
+# --- Helper Functions ---
 def get_image_base64(image):
     image = image.convert("RGB")
     buffered = io.BytesIO()
@@ -69,13 +71,13 @@ def display_chat_history():
         role = entry["role"]
         parts = entry["parts"][0]
         if 'text' in parts:
-            st.markdown(f"{role.title()}
+            st.markdown(f"**{role.title()}:** {parts['text']}")
         elif 'data' in parts:
             mime_type = parts.get('mime_type', '')
             if mime_type.startswith('image'):
-                st.image(Image.open(io.BytesIO(base64.b64decode(parts['data']))), caption='Uploaded Image')
+                st.image(Image.open(io.BytesIO(base64.b64decode(parts['data']))), caption='Uploaded Image', use_column_width=True)
             elif mime_type == 'application/pdf':
-                st.write("PDF Content
+                st.write("**PDF Content:**")
                 pdf_reader = PyPDF2.PdfReader(io.BytesIO(base64.b64decode(parts['data'])))
                 for page_num in range(len(pdf_reader.pages)):
                     page = pdf_reader.pages[page_num]
@@ -83,83 +85,37 @@ def display_chat_history():
             elif mime_type.startswith('video'):
                 st.video(io.BytesIO(base64.b64decode(parts['data'])))

-
-    chat_history_str = "\n".join(
-        f"{entry['role'].title()}: {part['text']}" if 'text' in part
-        else f"{entry['role'].title()}: (File: {part.get('mime_type', '')})"
-        for entry in st.session_state['chat_history']
-        for part in entry['parts']
-    )
-    return chat_history_str
-
-# Send message function with TTS integration
+# --- Send Message Function ---
 def send_message():
     user_input = st.session_state.user_input
     uploaded_files = st.session_state.uploaded_files
-    prompts = []
     prompt_parts = []

-    #
-    for entry in st.session_state['chat_history']:
-        for part in entry['parts']:
-            if 'text' in part:
-                prompts.append(part['text'])
-            elif 'data' in part:
-                prompts.append(f"(File: {part.get('mime_type', '')})")
-                prompt_parts.append(part)  # Add the entire part
-
-    # Append the user input to the prompts list
+    # Add user input to the prompt
     if user_input:
-        prompts.append(user_input)
-        st.session_state['chat_history'].append({"role": "user", "parts": [{"text": user_input}]})
         prompt_parts.append({"text": user_input})
+        st.session_state['chat_history'].append({"role": "user", "parts": [{"text": user_input}]})

     # Handle uploaded files
     if uploaded_files:
         for uploaded_file in uploaded_files:
             file_content = uploaded_file.read()
             base64_data = base64.b64encode(file_content).decode()
-
-
-                "mime_type": uploaded_file.type,
-                "data": base64_data
-            })
-            st.session_state['chat_history'].append({
-                "role": "user",
-                "parts": [{"mime_type": uploaded_file.type, "data": base64_data}]
-            })
-
-    # Determine if vision model should be used
-    use_vision_model = any(part.get('mime_type') == 'image/jpeg' for part in prompt_parts)
-
-    # Use the selected model
-    model_name = selected_model
-    if use_vision_model and "pro" not in model_name:
-        st.warning(f"The selected model ({model_name}) does not support image inputs. Choose a 'pro' model for image capabilities.")
-        return
+            prompt_parts.append({"mime_type": uploaded_file.type, "data": base64_data})
+            st.session_state['chat_history'].append({"role": "user", "parts": [{"mime_type": uploaded_file.type, "data": base64_data}]})

+    # Generate response using the selected model
     model = genai.GenerativeModel(
-        model_name=
+        model_name=selected_model,
         generation_config=generation_config,
         safety_settings=safety_settings
     )
-    chat_history_str = "\n".join(prompts)

-
-    # Include text and images for vision model
-        generated_prompt = {"role": "user", "parts": prompt_parts}
-    else:
-        # Include text only for standard model
-        generated_prompt = {"role": "user", "parts": [{"text": chat_history_str}]}
-
-    response = model.generate_content([generated_prompt])
+    response = model.generate_content([{"role": "user", "parts": prompt_parts}])
     response_text = response.text if hasattr(response, "text") else "No response text found."

-    # After generating the response from the model, append it to the chat history
     if response_text:
         st.session_state['chat_history'].append({"role": "model", "parts": [{"text": response_text}]})
-
-    # Convert the response text to speech if enabled
     if enable_tts:
         tts = gTTS(text=response_text, lang='en')
         tts_file = BytesIO()
@@ -167,55 +123,42 @@ def send_message():
         tts_file.seek(0)
         st.audio(tts_file, format='audio/mp3')

-    # Clear the input fields after sending the message
     st.session_state.user_input = ''
     st.session_state.uploaded_files = []
     st.session_state.file_uploader_key = str(uuid.uuid4())
-
-    # Display the updated chat history
     display_chat_history()

-# User
-
-    "Enter your message here:",
-    value="",
-    key="user_input"
-)
+# --- User Input Area ---
+col1, col2 = st.columns([3, 1])

-
+with col1:
+    user_input = st.text_area(
+        "Enter your message:",
+        value="",
+        key="user_input"
+    )
+with col2:
+    send_button = st.button(
+        "Send",
+        on_click=send_message,
+        type="primary"  # Makes the Send button prominent
+    )
+
+# --- File Uploader ---
 uploaded_files = st.file_uploader(
-    "Upload
-    type=["png", "jpg", "jpeg", "mp4", "pdf"],
+    "Upload Files (Images, Videos, PDFs):",
+    type=["png", "jpg", "jpeg", "mp4", "pdf"],
     accept_multiple_files=True,
     key=st.session_state.file_uploader_key
 )

-#
-
-    "Send",
-    on_click=send_message
-)
-
-# Clear conversation button
-clear_button = st.button("Clear Conversation", on_click=clear_conversation)
-
-# Function to download the chat history as a text file
-def download_chat_history():
-    chat_history_str = get_chat_history_str()
-    return chat_history_str
+# --- Other Buttons ---
+st.button("Clear Conversation", on_click=clear_conversation)

-#
-download_button = st.download_button(
-    label="Download Chat",
-    data=download_chat_history(),
-    file_name="chat_history.txt",
-    mime="text/plain"
-)
-
-# Ensure the file_uploader widget state is tied to the randomly generated key
+# --- Ensure file_uploader state ---
 st.session_state.uploaded_files = uploaded_files

-# JavaScript
+# --- JavaScript for Ctrl+Enter ---
 st.markdown(
     """
     <script>
@@ -230,4 +173,7 @@ st.markdown(
     </script>
     """,
     unsafe_allow_html=True
-)
+)
+
+# --- Display Chat History ---
+display_chat_history()
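For reference, here is a minimal standalone sketch of the request flow the updated send_message() builds: a GenerationConfig with the new max_output_tokens=3000, a GenerativeModel created from one of the dropdown model names, and a single user turn whose parts mix plain text with base64-encoded inline file data. The prompt text, API key placeholder, and local file name are illustrative examples, not part of the commit.

# Minimal sketch of the updated request flow; prompt text, API key, and file
# name are example values, not taken from the commit.
import base64
import google.generativeai as genai

genai.configure(api_key="YOUR_API_KEY")  # assumption: key supplied directly for brevity

generation_config = genai.GenerationConfig(
    temperature=0.9,
    max_output_tokens=3000,  # value introduced by this commit
)

model = genai.GenerativeModel(
    model_name="gemini-1.5-flash-latest",  # one of the two dropdown options
    generation_config=generation_config,
)

# Parts mix text and inline file data, mirroring prompt_parts in send_message().
with open("photo.jpg", "rb") as f:  # hypothetical local file
    image_b64 = base64.b64encode(f.read()).decode()

prompt_parts = [
    {"text": "Describe this image."},
    {"mime_type": "image/jpeg", "data": image_b64},
]

response = model.generate_content([{"role": "user", "parts": prompt_parts}])
print(response.text if hasattr(response, "text") else "No response text found.")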
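The commit keeps the existing pattern of clearing st.file_uploader by rotating its widget key with a fresh uuid after each send. A standalone sketch of that technique follows; the button label and callback name are illustrative, only the key-rotation pattern comes from app.py.

# Sketch: clearing a st.file_uploader by giving it a fresh key.
import uuid
import streamlit as st

if "file_uploader_key" not in st.session_state:
    st.session_state["file_uploader_key"] = str(uuid.uuid4())

def reset_uploader():
    # Assigning a new random key makes Streamlit treat the widget as brand new,
    # which drops any previously selected files on the next rerun.
    st.session_state["file_uploader_key"] = str(uuid.uuid4())

files = st.file_uploader(
    "Upload Files (Images, Videos, PDFs):",
    type=["png", "jpg", "jpeg", "mp4", "pdf"],
    accept_multiple_files=True,
    key=st.session_state["file_uploader_key"],
)

st.button("Clear files", on_click=reset_uploader)  # illustrative label/callback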
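The text-to-speech branch uses gTTS and st.audio; the line between the two send_message hunks (where the audio is written into the buffer) is not shown in this diff. A typical version of that pattern, with the write_to_fp() call as an assumption rather than the file's exact code:

# Typical gTTS-to-st.audio pattern; write_to_fp() is assumed, not shown in the diff.
from io import BytesIO
from gtts import gTTS
import streamlit as st

response_text = "Hello from Gemini."  # placeholder text

tts = gTTS(text=response_text, lang="en")
tts_file = BytesIO()
tts.write_to_fp(tts_file)  # standard gTTS API for writing MP3 bytes to a buffer
tts_file.seek(0)
st.audio(tts_file, format="audio/mp3")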