import streamlit as st
from PIL import Image
import io
import base64
import uuid
from gtts import gTTS
import google.generativeai as genai
from io import BytesIO
import PyPDF2
# Set your API key
api_key = "AIzaSyAHD0FwX-Ds6Y3eI-i5Oz7IdbJqR6rN7pg" # Replace with your actual API key
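# NOTE: hard-coding a key in a public Space exposes it; reading it from st.secrets
# or an environment variable is the safer option (the literal above is assumed to be
# a placeholder to be replaced).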
genai.configure(api_key=api_key)
# Configure the generative AI model
generation_config = genai.GenerationConfig(
    temperature=0.9,
    max_output_tokens=4000
)
# Safety settings configuration
safety_settings = [
    {
        "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
        "threshold": "BLOCK_NONE",
    },
    {
        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "threshold": "BLOCK_NONE",
    },
    {
        "category": "HARM_CATEGORY_HATE_SPEECH",
        "threshold": "BLOCK_NONE",
    },
    {
        "category": "HARM_CATEGORY_HARASSMENT",
        "threshold": "BLOCK_NONE",
    },
]
# Initialize session state
if 'chat_history' not in st.session_state:
    st.session_state['chat_history'] = []
if 'file_uploader_key' not in st.session_state:
    st.session_state['file_uploader_key'] = str(uuid.uuid4())
st.title("Gemini Chatbot")
# Model Selection Dropdown
selected_model = st.selectbox("Select a Gemini 1.5 model:", ["gemini-1.5-flash-latest", "gemini-1.5-pro-latest"])
# TTS Option Checkbox
enable_tts = st.checkbox("Enable Text-to-Speech")
# Helper functions for image processing and chat history management
def get_image_base64(image):
    image = image.convert("RGB")
    buffered = io.BytesIO()
    image.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return img_str
def clear_conversation():
    st.session_state['chat_history'] = []
    st.session_state['file_uploader_key'] = str(uuid.uuid4())
def display_chat_history():
    for entry in st.session_state['chat_history']:
        role = entry["role"]
        parts = entry["parts"][0]
        if 'text' in parts:
            st.markdown(f"{role.title()}: {parts['text']}")
        elif 'data' in parts:
            mime_type = parts.get('mime_type', '')
            if mime_type.startswith('image'):
                st.image(Image.open(io.BytesIO(base64.b64decode(parts['data']))), caption='Uploaded Image')
            elif mime_type == 'application/pdf':
                st.write("PDF Content:")
                pdf_reader = PyPDF2.PdfReader(io.BytesIO(base64.b64decode(parts['data'])))
                for page_num in range(len(pdf_reader.pages)):
                    page = pdf_reader.pages[page_num]
                    st.write(page.extract_text())
            elif mime_type.startswith('video'):
                st.video(io.BytesIO(base64.b64decode(parts['data'])))
def get_chat_history_str():
    chat_history_str = "\n".join(
        f"{entry['role'].title()}: {part['text']}" if 'text' in part
        else f"{entry['role'].title()}: (File: {part.get('mime_type', '')})"
        for entry in st.session_state['chat_history']
        for part in entry['parts']
    )
    return chat_history_str
# Send message function with TTS integration
def send_message():
    user_input = st.session_state.user_input
    uploaded_files = st.session_state.uploaded_files
    prompts = []
    prompt_parts = []
    # Populate the prompts list with the existing chat history
    for entry in st.session_state['chat_history']:
        for part in entry['parts']:
            if 'text' in part:
                prompts.append(part['text'])
            elif 'data' in part:
                prompts.append(f"(File: {part.get('mime_type', '')})")
                prompt_parts.append(part)  # Add the entire part
    # Append the user input to the prompts list
    if user_input:
        prompts.append(user_input)
        st.session_state['chat_history'].append({"role": "user", "parts": [{"text": user_input}]})
        prompt_parts.append({"text": user_input})
    # Handle uploaded files
    if uploaded_files:
        for uploaded_file in uploaded_files:
            file_content = uploaded_file.read()
            base64_data = base64.b64encode(file_content).decode()
            prompts.append(f"(File: {uploaded_file.type})")
            prompt_parts.append({
                "mime_type": uploaded_file.type,
                "data": base64_data
            })
            st.session_state['chat_history'].append({
                "role": "user",
                "parts": [{"mime_type": uploaded_file.type, "data": base64_data}]
            })
    # Determine if the vision path should be used (any image part present, not only JPEG)
    use_vision_model = any(part.get('mime_type', '').startswith('image') for part in prompt_parts)
    # Use the selected model
    model_name = selected_model
    if use_vision_model and "pro" not in model_name:
        st.warning(f"The selected model ({model_name}) does not support image inputs. Choose a 'pro' model for image capabilities.")
        return
    model = genai.GenerativeModel(
        model_name=model_name,
        generation_config=generation_config,
        safety_settings=safety_settings
    )
    chat_history_str = "\n".join(prompts)
    if use_vision_model:
        # Include text and images for vision model
        generated_prompt = {"role": "user", "parts": prompt_parts}
    else:
        # Include text only for standard model
        generated_prompt = {"role": "user", "parts": [{"text": chat_history_str}]}
    response = model.generate_content([generated_prompt])
    response_text = response.text if hasattr(response, "text") else "No response text found."
    # After generating the response from the model, append it to the chat history
    if response_text:
        st.session_state['chat_history'].append({"role": "model", "parts": [{"text": response_text}]})
    # Convert the response text to speech if enabled
    if enable_tts:
        tts = gTTS(text=response_text, lang='en')
        tts_file = BytesIO()
        tts.write_to_fp(tts_file)
        tts_file.seek(0)
        st.audio(tts_file, format='audio/mp3')
    # Clear the input fields after sending the message
    st.session_state.user_input = ''
    st.session_state.uploaded_files = []
    st.session_state.file_uploader_key = str(uuid.uuid4())
# Display the updated chat history
display_chat_history()
# User input text area
user_input = st.text_area(
    "Enter your message here:",
    value="",
    key="user_input"
)
# File uploader for images
uploaded_files = st.file_uploader(
    "Upload files:",
    type=["png", "jpg", "jpeg", "mp4", "pdf"],  # Added mp4 and pdf
    accept_multiple_files=True,
    key=st.session_state.file_uploader_key
)
# Send message button
send_button = st.button(
    "Send",
    on_click=send_message
)
# Clear conversation button
clear_button = st.button("Clear Conversation", on_click=clear_conversation)
# Function to download the chat history as a text file
def download_chat_history():
    chat_history_str = get_chat_history_str()
    return chat_history_str
# Download button for the chat history
download_button = st.download_button(
    label="Download Chat",
    data=download_chat_history(),
    file_name="chat_history.txt",
    mime="text/plain"
)
# Ensure the file_uploader widget state is tied to the randomly generated key
st.session_state.uploaded_files = uploaded_files
# JavaScript to capture the Ctrl+Enter event and trigger a button click
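# NOTE: scripts injected via st.markdown(..., unsafe_allow_html=True) are generally not
# executed by the browser, so this shortcut may have no effect as written; rendering the
# snippet through streamlit.components.v1.html (and targeting window.parent.document) is
# a possible alternative, noted here as a suggestion rather than a verified fix.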
st.markdown(
    """
    <script>
    document.addEventListener('DOMContentLoaded', (event) => {
        document.querySelector('.stTextArea textarea').addEventListener('keydown', function(e) {
            if (e.key === 'Enter' && e.ctrlKey) {
                document.querySelector('.stButton > button').click();
                e.preventDefault();
            }
        });
    });
    </script>
    """,
    unsafe_allow_html=True
)