ziyadsuper2017 committed on
Commit
ab73386
·
verified ·
1 Parent(s): 2c4cf73

Trying to make it multimodal

Browse files
Files changed (1) hide show
  1. app.py +40 -24
app.py CHANGED
@@ -6,6 +6,7 @@ import uuid
6
  from gtts import gTTS
7
  import google.generativeai as genai
8
  from io import BytesIO
 
9
 
10
  # Set your API key
11
  api_key = "AIzaSyAHD0FwX-Ds6Y3eI-i5Oz7IdbJqR6rN7pg" # Replace with your actual API key
@@ -14,7 +15,7 @@ genai.configure(api_key=api_key)
14
  # Configure the generative AI model
15
  generation_config = genai.GenerationConfig(
16
  temperature=0.9,
17
- max_output_tokens=3000
18
  )
19
 
20
  # Safety settings configuration
@@ -48,6 +49,9 @@ st.title("Gemini Chatbot")
48
  # Model Selection Dropdown
49
  selected_model = st.selectbox("Select a Gemini 1.5 model:", ["gemini-1.5-flash-latest", "gemini-1.5-pro-latest"])
50
 
 
 
 
51
  # Helper functions for image processing and chat history management
52
  def get_image_base64(image):
53
  image = image.convert("RGB")
@@ -67,12 +71,22 @@ def display_chat_history():
67
  if 'text' in parts:
68
  st.markdown(f"{role.title()}: {parts['text']}")
69
  elif 'data' in parts:
70
- st.image(Image.open(io.BytesIO(base64.b64decode(parts['data']))), caption='Uploaded Image')
 
 
 
 
 
 
 
 
 
 
71
 
72
  def get_chat_history_str():
73
  chat_history_str = "\n".join(
74
  f"{entry['role'].title()}: {part['text']}" if 'text' in part
75
- else f"{entry['role'].title()}: (Image)"
76
  for entry in st.session_state['chat_history']
77
  for part in entry['parts']
78
  )
@@ -91,43 +105,44 @@ def send_message():
91
  if 'text' in part:
92
  prompts.append(part['text'])
93
  elif 'data' in part:
94
- # Add the image in base64 format to prompt_parts for vision model
95
- prompt_parts.append({"data": part['data'], "mime_type": "image/jpeg"})
96
- prompts.append("[Image]")
97
 
98
  # Append the user input to the prompts list
99
  if user_input:
100
  prompts.append(user_input)
101
  st.session_state['chat_history'].append({"role": "user", "parts": [{"text": user_input}]})
102
- # Also add the user text input to prompt_parts
103
  prompt_parts.append({"text": user_input})
104
 
105
  # Handle uploaded files
106
  if uploaded_files:
107
  for uploaded_file in uploaded_files:
108
- base64_image = get_image_base64(Image.open(uploaded_file))
109
- prompts.append("[Image]")
110
- prompt_parts.append({"data": base64_image, "mime_type": "image/jpeg"})
 
 
 
 
111
  st.session_state['chat_history'].append({
112
  "role": "user",
113
- "parts": [{"mime_type": uploaded_file.type, "data": base64_image}]
114
  })
115
 
116
- # Determine if vision model should be used
117
  use_vision_model = any(part.get('mime_type') == 'image/jpeg' for part in prompt_parts)
118
 
119
- # Use the selected model
120
  model_name = selected_model
121
  if use_vision_model and "pro" not in model_name:
122
  st.warning(f"The selected model ({model_name}) does not support image inputs. Choose a 'pro' model for image capabilities.")
123
  return
124
-
125
  model = genai.GenerativeModel(
126
  model_name=model_name,
127
  generation_config=generation_config,
128
  safety_settings=safety_settings
129
  )
130
-
131
  chat_history_str = "\n".join(prompts)
132
 
133
  if use_vision_model:
@@ -142,14 +157,15 @@ def send_message():
142
 
143
  # After generating the response from the model, append it to the chat history
144
  if response_text:
145
- st.session_state['chat_history'].append({"role": "model", "parts":[{"text": response_text}]})
146
 
147
- # Convert the response text to speech
148
- tts = gTTS(text=response_text, lang='en')
149
- tts_file = BytesIO()
150
- tts.write_to_fp(tts_file)
151
- tts_file.seek(0)
152
- st.audio(tts_file, format='audio/mp3')
 
153
 
154
  # Clear the input fields after sending the message
155
  st.session_state.user_input = ''
@@ -168,8 +184,8 @@ user_input = st.text_area(
168
 
169
  # File uploader for images
170
  uploaded_files = st.file_uploader(
171
- "Upload images:",
172
- type=["png", "jpg", "jpeg"],
173
  accept_multiple_files=True,
174
  key=st.session_state.file_uploader_key
175
  )
 
6
  from gtts import gTTS
7
  import google.generativeai as genai
8
  from io import BytesIO
9
+ import PyPDF2
10
 
11
  # Set your API key
12
  api_key = "AIzaSyAHD0FwX-Ds6Y3eI-i5Oz7IdbJqR6rN7pg" # Replace with your actual API key
 
15
  # Configure the generative AI model
16
  generation_config = genai.GenerationConfig(
17
  temperature=0.9,
18
+ max_output_tokens=4000
19
  )
20
 
21
  # Safety settings configuration
 
49
  # Model Selection Dropdown
50
  selected_model = st.selectbox("Select a Gemini 1.5 model:", ["gemini-1.5-flash-latest", "gemini-1.5-pro-latest"])
51
 
52
+ # TTS Option Checkbox
53
+ enable_tts = st.checkbox("Enable Text-to-Speech")
54
+
55
  # Helper functions for image processing and chat history management
56
  def get_image_base64(image):
57
  image = image.convert("RGB")
 
71
  if 'text' in parts:
72
  st.markdown(f"{role.title()}: {parts['text']}")
73
  elif 'data' in parts:
74
+ mime_type = parts.get('mime_type', '')
75
+ if mime_type.startswith('image'):
76
+ st.image(Image.open(io.BytesIO(base64.b64decode(parts['data']))), caption='Uploaded Image')
77
+ elif mime_type == 'application/pdf':
78
+ st.write("PDF Content:")
79
+ pdf_reader = PyPDF2.PdfReader(io.BytesIO(base64.b64decode(parts['data'])))
80
+ for page_num in range(len(pdf_reader.pages)):
81
+ page = pdf_reader.pages[page_num]
82
+ st.write(page.extract_text())
83
+ elif mime_type.startswith('video'):
84
+ st.video(io.BytesIO(base64.b64decode(parts['data'])))
85
 
86
  def get_chat_history_str():
87
  chat_history_str = "\n".join(
88
  f"{entry['role'].title()}: {part['text']}" if 'text' in part
89
+ else f"{entry['role'].title()}: (File: {part.get('mime_type', '')})"
90
  for entry in st.session_state['chat_history']
91
  for part in entry['parts']
92
  )
 
105
  if 'text' in part:
106
  prompts.append(part['text'])
107
  elif 'data' in part:
108
+ prompts.append(f"(File: {part.get('mime_type', '')})")
109
+ prompt_parts.append(part) # Add the entire part
 
110
 
111
  # Append the user input to the prompts list
112
  if user_input:
113
  prompts.append(user_input)
114
  st.session_state['chat_history'].append({"role": "user", "parts": [{"text": user_input}]})
 
115
  prompt_parts.append({"text": user_input})
116
 
117
  # Handle uploaded files
118
  if uploaded_files:
119
  for uploaded_file in uploaded_files:
120
+ file_content = uploaded_file.read()
121
+ base64_data = base64.b64encode(file_content).decode()
122
+ prompts.append(f"(File: {uploaded_file.type})")
123
+ prompt_parts.append({
124
+ "mime_type": uploaded_file.type,
125
+ "data": base64_data
126
+ })
127
  st.session_state['chat_history'].append({
128
  "role": "user",
129
+ "parts": [{"mime_type": uploaded_file.type, "data": base64_data}]
130
  })
131
 
132
+ # Determine if vision model should be used
133
  use_vision_model = any(part.get('mime_type') == 'image/jpeg' for part in prompt_parts)
134
 
135
+ # Use the selected model
136
  model_name = selected_model
137
  if use_vision_model and "pro" not in model_name:
138
  st.warning(f"The selected model ({model_name}) does not support image inputs. Choose a 'pro' model for image capabilities.")
139
  return
140
+
141
  model = genai.GenerativeModel(
142
  model_name=model_name,
143
  generation_config=generation_config,
144
  safety_settings=safety_settings
145
  )
 
146
  chat_history_str = "\n".join(prompts)
147
 
148
  if use_vision_model:
 
157
 
158
  # After generating the response from the model, append it to the chat history
159
  if response_text:
160
+ st.session_state['chat_history'].append({"role": "model", "parts": [{"text": response_text}]})
161
 
162
+ # Convert the response text to speech if enabled
163
+ if enable_tts:
164
+ tts = gTTS(text=response_text, lang='en')
165
+ tts_file = BytesIO()
166
+ tts.write_to_fp(tts_file)
167
+ tts_file.seek(0)
168
+ st.audio(tts_file, format='audio/mp3')
169
 
170
  # Clear the input fields after sending the message
171
  st.session_state.user_input = ''
 
184
 
185
  # File uploader for images
186
  uploaded_files = st.file_uploader(
187
+ "Upload files:",
188
+ type=["png", "jpg", "jpeg", "mp4", "pdf"], # Added mp4 and pdf
189
  accept_multiple_files=True,
190
  key=st.session_state.file_uploader_key
191
  )