ziyadsuper2017 commited on
Commit
94fea6c
·
1 Parent(s): e742fb1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -24
app.py CHANGED
@@ -41,8 +41,6 @@ if 'chat_history' not in st.session_state:
41
  st.session_state['chat_history'] = []
42
  if 'file_uploader_key' not in st.session_state:
43
  st.session_state['file_uploader_key'] = str(uuid.uuid4())
44
- if 'use_vision_model' not in st.session_state:
45
- st.session_state['use_vision_model'] = False
46
 
47
  st.title("Gemini Chatbot")
48
 
@@ -57,7 +55,6 @@ def get_image_base64(image):
57
  def clear_conversation():
58
  st.session_state['chat_history'] = []
59
  st.session_state['file_uploader_key'] = str(uuid.uuid4())
60
- st.session_state['use_vision_model'] = False
61
 
62
  def display_chat_history():
63
  for entry in st.session_state['chat_history']:
@@ -78,38 +75,56 @@ def get_chat_history_str():
78
  return chat_history_str
79
 
80
  # Send message function with TTS integration
81
- def send_message(tts=False):
82
  user_input = st.session_state.user_input
83
  uploaded_files = st.session_state.uploaded_files
84
- # Your existing code that processes user input and uploaded files...
85
 
86
- model_name = 'gemini-pro-vision' if st.session_state['use_vision_model'] else 'gemini-pro'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  model = genai.GenerativeModel(
88
  model_name=model_name,
89
  generation_config=generation_config,
90
  safety_settings=safety_settings
91
  )
92
  chat_history_str = "\n".join(prompts)
93
- if st.session_state['use_vision_model']:
94
- prompt_parts = [{"text": chat_history_str}] + [
95
- {"data": part['data'], "mime_type": "image/jpeg"}
96
- for entry in st.session_state['chat_history'] for part in entry['parts']
97
- if 'data' in part
98
- ]
99
- else:
100
- prompt_parts = [{"text": chat_history_str}]
101
  response = model.generate_content([{"role": "user", "parts": prompt_parts}])
102
  response_text = response.text if hasattr(response, "text") else "No response text found."
 
 
103
  if response_text:
104
  st.session_state['chat_history'].append({"role": "model", "parts": [{"text": response_text}]})
105
 
106
- # If TTS is enabled, convert the response text to speech
107
- if tts:
108
- tts = gTTS(text=response_text, lang='en')
109
- tts_file = BytesIO()
110
- tts.write_to_fp(tts_file)
111
- tts_file.seek(0)
112
- st.audio(tts_file, format='audio/mp3')
113
 
114
  # Clear the input fields after sending the message
115
  st.session_state.user_input = ''
@@ -122,6 +137,7 @@ def send_message(tts=False):
122
  # User input text area
123
  user_input = st.text_area(
124
  "Enter your message here:",
 
125
  key="user_input"
126
  )
127
 
@@ -136,8 +152,7 @@ uploaded_files = st.file_uploader(
136
  # Send message button
137
  send_button = st.button(
138
  "Send",
139
- on_click=send_message,
140
- args=(False,) # TTS disabled by default when clicking the Send button
141
  )
142
 
143
  # Clear conversation button
@@ -163,7 +178,6 @@ st.session_state.uploaded_files = uploaded_files
163
  st.markdown(
164
  """
165
  <script>
166
- // Use jQuery to capture the Ctrl+Enter event and click the 'Send' button
167
  document.addEventListener('DOMContentLoaded', (event) => {
168
  document.querySelector('.stTextArea textarea').addEventListener('keydown', function(e) {
169
  if (e.key === 'Enter' && e.ctrlKey) {
 
41
  st.session_state['chat_history'] = []
42
  if 'file_uploader_key' not in st.session_state:
43
  st.session_state['file_uploader_key'] = str(uuid.uuid4())
 
 
44
 
45
  st.title("Gemini Chatbot")
46
 
 
55
  def clear_conversation():
56
  st.session_state['chat_history'] = []
57
  st.session_state['file_uploader_key'] = str(uuid.uuid4())
 
58
 
59
  def display_chat_history():
60
  for entry in st.session_state['chat_history']:
 
75
  return chat_history_str
76
 
77
  # Send message function with TTS integration
78
+ def send_message():
79
  user_input = st.session_state.user_input
80
  uploaded_files = st.session_state.uploaded_files
81
+ prompts = []
82
 
83
+ # Populate the prompts list with the existing chat history
84
+ for entry in st.session_state['chat_history']:
85
+ for part in entry['parts']:
86
+ if 'text' in part:
87
+ prompts.append(part['text'])
88
+ elif 'data' in part:
89
+ prompts.append("[Image]")
90
+
91
+ # Append the user input to the prompts list
92
+ if user_input:
93
+ prompts.append(user_input)
94
+ st.session_state['chat_history'].append({"role": "user", "parts": [{"text": user_input}]})
95
+
96
+ # Handle uploaded files
97
+ if uploaded_files:
98
+ for uploaded_file in uploaded_files:
99
+ base64_image = get_image_base64(Image.open(uploaded_file))
100
+ prompts.append("[Image]")
101
+ st.session_state['chat_history'].append({
102
+ "role": "user",
103
+ "parts": [{"mime_type": uploaded_file.type, "data": base64_image}]
104
+ })
105
+
106
+ # Set up the model and generate a response
107
+ model_name = 'gemini-pro-vision' if st.session_state.get('use_vision_model', False) else 'gemini-pro'
108
  model = genai.GenerativeModel(
109
  model_name=model_name,
110
  generation_config=generation_config,
111
  safety_settings=safety_settings
112
  )
113
  chat_history_str = "\n".join(prompts)
114
+ prompt_parts = [{"text": chat_history_str}]
 
 
 
 
 
 
 
115
  response = model.generate_content([{"role": "user", "parts": prompt_parts}])
116
  response_text = response.text if hasattr(response, "text") else "No response text found."
117
+
118
+ # After generating the response from the model, append it to the chat history
119
  if response_text:
120
  st.session_state['chat_history'].append({"role": "model", "parts": [{"text": response_text}]})
121
 
122
+ # Convert the response text to speech
123
+ tts = gTTS(text=response_text, lang='en')
124
+ tts_file = BytesIO()
125
+ tts.write_to_fp(tts_file)
126
+ tts_file.seek(0)
127
+ st.audio(tts_file, format='audio/mp3')
 
128
 
129
  # Clear the input fields after sending the message
130
  st.session_state.user_input = ''
 
137
  # User input text area
138
  user_input = st.text_area(
139
  "Enter your message here:",
140
+ value="",
141
  key="user_input"
142
  )
143
 
 
152
  # Send message button
153
  send_button = st.button(
154
  "Send",
155
+ on_click=send_message
 
156
  )
157
 
158
  # Clear conversation button
 
178
  st.markdown(
179
  """
180
  <script>
 
181
  document.addEventListener('DOMContentLoaded', (event) => {
182
  document.querySelector('.stTextArea textarea').addEventListener('keydown', function(e) {
183
  if (e.key === 'Enter' && e.ctrlKey) {