ziyadsuper2017 committed on
Commit
d5a9ec5
·
verified ·
1 Parent(s): 672cfd7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -96
app.py CHANGED
@@ -15,7 +15,7 @@ genai.configure(api_key=api_key)
15
  # Configure the generative AI model
16
  generation_config = genai.GenerationConfig(
17
  temperature=0.9,
18
- max_output_tokens=4000
19
  )
20
 
21
  # Safety settings configuration
@@ -44,15 +44,17 @@ if 'chat_history' not in st.session_state:
44
  if 'file_uploader_key' not in st.session_state:
45
  st.session_state['file_uploader_key'] = str(uuid.uuid4())
46
 
 
47
  st.title("Gemini Chatbot")
 
48
 
49
  # Model Selection Dropdown
50
- selected_model = st.selectbox("Select a Gemini 1.5 model:", ["gemini-1.5-flash-latest", "gemini-1.5-pro-latest"])
51
 
52
  # TTS Option Checkbox
53
  enable_tts = st.checkbox("Enable Text-to-Speech")
54
 
55
- # Helper functions for image processing and chat history management
56
  def get_image_base64(image):
57
  image = image.convert("RGB")
58
  buffered = io.BytesIO()
@@ -69,13 +71,13 @@ def display_chat_history():
69
  role = entry["role"]
70
  parts = entry["parts"][0]
71
  if 'text' in parts:
72
- st.markdown(f"{role.title()}: {parts['text']}")
73
  elif 'data' in parts:
74
  mime_type = parts.get('mime_type', '')
75
  if mime_type.startswith('image'):
76
- st.image(Image.open(io.BytesIO(base64.b64decode(parts['data']))), caption='Uploaded Image')
77
  elif mime_type == 'application/pdf':
78
- st.write("PDF Content:")
79
  pdf_reader = PyPDF2.PdfReader(io.BytesIO(base64.b64decode(parts['data'])))
80
  for page_num in range(len(pdf_reader.pages)):
81
  page = pdf_reader.pages[page_num]
@@ -83,83 +85,37 @@ def display_chat_history():
83
  elif mime_type.startswith('video'):
84
  st.video(io.BytesIO(base64.b64decode(parts['data'])))
85
 
86
- def get_chat_history_str():
87
- chat_history_str = "\n".join(
88
- f"{entry['role'].title()}: {part['text']}" if 'text' in part
89
- else f"{entry['role'].title()}: (File: {part.get('mime_type', '')})"
90
- for entry in st.session_state['chat_history']
91
- for part in entry['parts']
92
- )
93
- return chat_history_str
94
-
95
- # Send message function with TTS integration
96
  def send_message():
97
  user_input = st.session_state.user_input
98
  uploaded_files = st.session_state.uploaded_files
99
- prompts = []
100
  prompt_parts = []
101
 
102
- # Populate the prompts list with the existing chat history
103
- for entry in st.session_state['chat_history']:
104
- for part in entry['parts']:
105
- if 'text' in part:
106
- prompts.append(part['text'])
107
- elif 'data' in part:
108
- prompts.append(f"(File: {part.get('mime_type', '')})")
109
- prompt_parts.append(part) # Add the entire part
110
-
111
- # Append the user input to the prompts list
112
  if user_input:
113
- prompts.append(user_input)
114
- st.session_state['chat_history'].append({"role": "user", "parts": [{"text": user_input}]})
115
  prompt_parts.append({"text": user_input})
 
116
 
117
  # Handle uploaded files
118
  if uploaded_files:
119
  for uploaded_file in uploaded_files:
120
  file_content = uploaded_file.read()
121
  base64_data = base64.b64encode(file_content).decode()
122
- prompts.append(f"(File: {uploaded_file.type})")
123
- prompt_parts.append({
124
- "mime_type": uploaded_file.type,
125
- "data": base64_data
126
- })
127
- st.session_state['chat_history'].append({
128
- "role": "user",
129
- "parts": [{"mime_type": uploaded_file.type, "data": base64_data}]
130
- })
131
-
132
- # Determine if vision model should be used
133
- use_vision_model = any(part.get('mime_type') == 'image/jpeg' for part in prompt_parts)
134
-
135
- # Use the selected model
136
- model_name = selected_model
137
- if use_vision_model and "pro" not in model_name:
138
- st.warning(f"The selected model ({model_name}) does not support image inputs. Choose a 'pro' model for image capabilities.")
139
- return
140
 
 
141
  model = genai.GenerativeModel(
142
- model_name=model_name,
143
  generation_config=generation_config,
144
  safety_settings=safety_settings
145
  )
146
- chat_history_str = "\n".join(prompts)
147
 
148
- if use_vision_model:
149
- # Include text and images for vision model
150
- generated_prompt = {"role": "user", "parts": prompt_parts}
151
- else:
152
- # Include text only for standard model
153
- generated_prompt = {"role": "user", "parts": [{"text": chat_history_str}]}
154
-
155
- response = model.generate_content([generated_prompt])
156
  response_text = response.text if hasattr(response, "text") else "No response text found."
157
 
158
- # After generating the response from the model, append it to the chat history
159
  if response_text:
160
  st.session_state['chat_history'].append({"role": "model", "parts": [{"text": response_text}]})
161
-
162
- # Convert the response text to speech if enabled
163
  if enable_tts:
164
  tts = gTTS(text=response_text, lang='en')
165
  tts_file = BytesIO()
@@ -167,55 +123,42 @@ def send_message():
167
  tts_file.seek(0)
168
  st.audio(tts_file, format='audio/mp3')
169
 
170
- # Clear the input fields after sending the message
171
  st.session_state.user_input = ''
172
  st.session_state.uploaded_files = []
173
  st.session_state.file_uploader_key = str(uuid.uuid4())
174
-
175
- # Display the updated chat history
176
  display_chat_history()
177
 
178
- # User input text area
179
- user_input = st.text_area(
180
- "Enter your message here:",
181
- value="",
182
- key="user_input"
183
- )
184
 
185
- # File uploader for images
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  uploaded_files = st.file_uploader(
187
- "Upload files:",
188
- type=["png", "jpg", "jpeg", "mp4", "pdf"], # Added mp4 and pdf
189
  accept_multiple_files=True,
190
  key=st.session_state.file_uploader_key
191
  )
192
 
193
- # Send message button
194
- send_button = st.button(
195
- "Send",
196
- on_click=send_message
197
- )
198
-
199
- # Clear conversation button
200
- clear_button = st.button("Clear Conversation", on_click=clear_conversation)
201
-
202
- # Function to download the chat history as a text file
203
- def download_chat_history():
204
- chat_history_str = get_chat_history_str()
205
- return chat_history_str
206
 
207
- # Download button for the chat history
208
- download_button = st.download_button(
209
- label="Download Chat",
210
- data=download_chat_history(),
211
- file_name="chat_history.txt",
212
- mime="text/plain"
213
- )
214
-
215
- # Ensure the file_uploader widget state is tied to the randomly generated key
216
  st.session_state.uploaded_files = uploaded_files
217
 
218
- # JavaScript to capture the Ctrl+Enter event and trigger a button click
219
  st.markdown(
220
  """
221
  <script>
@@ -230,4 +173,7 @@ st.markdown(
230
  </script>
231
  """,
232
  unsafe_allow_html=True
233
- )
 
 
 
 
15
  # Configure the generative AI model
16
  generation_config = genai.GenerationConfig(
17
  temperature=0.9,
18
+ max_output_tokens=3000
19
  )
20
 
21
  # Safety settings configuration
 
44
  if 'file_uploader_key' not in st.session_state:
45
  st.session_state['file_uploader_key'] = str(uuid.uuid4())
46
 
47
+ # --- Streamlit UI ---
48
  st.title("Gemini Chatbot")
49
+ st.write("Interact with the powerful Gemini 1.5 models.")
50
 
51
  # Model Selection Dropdown
52
+ selected_model = st.selectbox("Choose a Gemini 1.5 Model:", ["gemini-1.5-flash-latest", "gemini-1.5-pro-latest"])
53
 
54
  # TTS Option Checkbox
55
  enable_tts = st.checkbox("Enable Text-to-Speech")
56
 
57
+ # --- Helper Functions ---
58
  def get_image_base64(image):
59
  image = image.convert("RGB")
60
  buffered = io.BytesIO()
 
71
  role = entry["role"]
72
  parts = entry["parts"][0]
73
  if 'text' in parts:
74
+ st.markdown(f"**{role.title()}:** {parts['text']}")
75
  elif 'data' in parts:
76
  mime_type = parts.get('mime_type', '')
77
  if mime_type.startswith('image'):
78
+ st.image(Image.open(io.BytesIO(base64.b64decode(parts['data']))), caption='Uploaded Image', use_column_width=True)
79
  elif mime_type == 'application/pdf':
80
+ st.write("**PDF Content:**")
81
  pdf_reader = PyPDF2.PdfReader(io.BytesIO(base64.b64decode(parts['data'])))
82
  for page_num in range(len(pdf_reader.pages)):
83
  page = pdf_reader.pages[page_num]
 
85
  elif mime_type.startswith('video'):
86
  st.video(io.BytesIO(base64.b64decode(parts['data'])))
87
 
88
+ # --- Send Message Function ---
 
 
 
 
 
 
 
 
 
89
  def send_message():
90
  user_input = st.session_state.user_input
91
  uploaded_files = st.session_state.uploaded_files
 
92
  prompt_parts = []
93
 
94
+ # Add user input to the prompt
 
 
 
 
 
 
 
 
 
95
  if user_input:
 
 
96
  prompt_parts.append({"text": user_input})
97
+ st.session_state['chat_history'].append({"role": "user", "parts": [{"text": user_input}]})
98
 
99
  # Handle uploaded files
100
  if uploaded_files:
101
  for uploaded_file in uploaded_files:
102
  file_content = uploaded_file.read()
103
  base64_data = base64.b64encode(file_content).decode()
104
+ prompt_parts.append({"mime_type": uploaded_file.type, "data": base64_data})
105
+ st.session_state['chat_history'].append({"role": "user", "parts": [{"mime_type": uploaded_file.type, "data": base64_data}]})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
+ # Generate response using the selected model
108
  model = genai.GenerativeModel(
109
+ model_name=selected_model,
110
  generation_config=generation_config,
111
  safety_settings=safety_settings
112
  )
 
113
 
114
+ response = model.generate_content([{"role": "user", "parts": prompt_parts}])
 
 
 
 
 
 
 
115
  response_text = response.text if hasattr(response, "text") else "No response text found."
116
 
 
117
  if response_text:
118
  st.session_state['chat_history'].append({"role": "model", "parts": [{"text": response_text}]})
 
 
119
  if enable_tts:
120
  tts = gTTS(text=response_text, lang='en')
121
  tts_file = BytesIO()
 
123
  tts_file.seek(0)
124
  st.audio(tts_file, format='audio/mp3')
125
 
 
126
  st.session_state.user_input = ''
127
  st.session_state.uploaded_files = []
128
  st.session_state.file_uploader_key = str(uuid.uuid4())
 
 
129
  display_chat_history()
130
 
131
+ # --- User Input Area ---
132
+ col1, col2 = st.columns([3, 1])
 
 
 
 
133
 
134
+ with col1:
135
+ user_input = st.text_area(
136
+ "Enter your message:",
137
+ value="",
138
+ key="user_input"
139
+ )
140
+ with col2:
141
+ send_button = st.button(
142
+ "Send",
143
+ on_click=send_message,
144
+ type="primary" # Makes the Send button prominent
145
+ )
146
+
147
+ # --- File Uploader ---
148
  uploaded_files = st.file_uploader(
149
+ "Upload Files (Images, Videos, PDFs):",
150
+ type=["png", "jpg", "jpeg", "mp4", "pdf"],
151
  accept_multiple_files=True,
152
  key=st.session_state.file_uploader_key
153
  )
154
 
155
+ # --- Other Buttons ---
156
+ st.button("Clear Conversation", on_click=clear_conversation)
 
 
 
 
 
 
 
 
 
 
 
157
 
158
+ # --- Ensure file_uploader state ---
 
 
 
 
 
 
 
 
159
  st.session_state.uploaded_files = uploaded_files
160
 
161
+ # --- JavaScript for Ctrl+Enter ---
162
  st.markdown(
163
  """
164
  <script>
 
173
  </script>
174
  """,
175
  unsafe_allow_html=True
176
+ )
177
+
178
+ # --- Display Chat History ---
179
+ display_chat_history()