ziyadsuper2017's picture
Update app.py
ccf48ce
raw
history blame
5.23 kB
import base64
import io
import os
import uuid

# Assuming google.generativeai is correctly imported as genai and the API key is set
import google.generativeai as genai
import streamlit as st
from PIL import Image
# Read the Gemini API key from the environment instead of embedding it in the
# source file, where anyone able to read the file could copy it.
# NOTE: a key that was previously committed here should be treated as
# compromised and rotated.
api_key = os.environ.get("GOOGLE_API_KEY", "")
if not api_key:
    # Fail with a visible message rather than an opaque authentication error
    # on the first model request.
    st.error("GOOGLE_API_KEY is not set. Add it to the environment to use the chatbot.")
    st.stop()
genai.configure(api_key=api_key)
# Generation settings shared by every model instance created in send_message.
generation_config = genai.GenerationConfig(temperature=0.9, max_output_tokens=3000)
# One entry per harm category, each with its threshold set to "BLOCK_NONE".
# Both names are bound to the same list object.
safety_settings = SAFETY_SETTINGS = [
    {"category": category, "threshold": "BLOCK_NONE"}
    for category in (
        "HARM_CATEGORY_DANGEROUS_CONTENT",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_HARASSMENT",
    )
]
# Seed per-session state only when a key is missing, so values already stored
# in the session survive script reruns.
for _state_key, _make_default in (
    ('chat_history', lambda: []),
    ('file_uploader_key', lambda: str(uuid.uuid4())),
    ('use_vision_model', lambda: False),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()

st.title("Gemini Chatbot")
def get_image_base64(image):
    """Return *image* re-encoded as JPEG and base64-encoded as a str.

    The image is converted to "RGB" mode, saved in JPEG format to an
    in-memory buffer, and the buffer contents are base64-encoded.
    """
    jpeg_buffer = io.BytesIO()
    image.convert("RGB").save(jpeg_buffer, format="JPEG")
    return base64.b64encode(jpeg_buffer.getvalue()).decode()
def clear_conversation():
    """Reset the session to an empty conversation.

    Empties the chat history, turns the vision-model flag off, and assigns a
    fresh random key for the file uploader widget.
    """
    state = st.session_state
    state['use_vision_model'] = False
    state['file_uploader_key'] = str(uuid.uuid4())
    state['chat_history'] = []
def send_message():
    """Send the pending text and images, with prior history, to Gemini.

    Used as the ``on_click`` callback of the "Send" button. Reads the pending
    text from ``st.session_state.user_input`` and the pending uploads from
    ``st.session_state.uploaded_files``, appends them to the chat history,
    requests a completion, stores the reply in the history, and finally
    clears the input widgets.

    Does nothing when there is neither text nor an uploaded file, so an empty
    click does not re-submit the old conversation to the API.
    """
    user_input = st.session_state.user_input
    uploaded_files = st.session_state.get('uploaded_files') or []

    # Nothing new to send.
    if not user_input and not uploaded_files:
        return

    if uploaded_files:
        # The flag stays set for later turns until clear_conversation resets it.
        st.session_state['use_vision_model'] = True

    # Flatten earlier turns into plain text lines; images become placeholders.
    prompts = []
    for entry in st.session_state['chat_history']:
        for part in entry['parts']:
            if 'text' in part:
                prompts.append(part['text'])
            elif 'data' in part:
                prompts.append("[Image]")

    if user_input:
        prompts.append(user_input)
        st.session_state['chat_history'].append(
            {"role": "user", "parts": [{"text": user_input}]}
        )

    for uploaded_file in uploaded_files:
        base64_image = get_image_base64(Image.open(uploaded_file))
        prompts.append("[Image]")
        st.session_state['chat_history'].append({
            "role": "user",
            # get_image_base64 always re-encodes as JPEG, so record that type
            # rather than the MIME type of the original upload.
            "parts": [{"mime_type": "image/jpeg", "data": base64_image}],
        })

    use_vision = st.session_state['use_vision_model']
    model = genai.GenerativeModel(
        model_name='gemini-pro-vision' if use_vision else 'gemini-pro',
        generation_config=generation_config,
        safety_settings=safety_settings,
    )

    # The whole conversation is sent as a single user turn: one text part
    # followed, in vision mode, by every image stored in the history.
    prompt_parts = [{"text": "\n".join(prompts)}]
    if use_vision:
        prompt_parts.extend(
            {"data": part['data'], "mime_type": "image/jpeg"}
            for entry in st.session_state['chat_history']
            for part in entry['parts']
            if 'data' in part
        )

    response = model.generate_content([{"role": "user", "parts": prompt_parts}])

    # hasattr() only guards AttributeError; the ``text`` accessor raises
    # ValueError when the response carries no usable text (for example when
    # it was blocked), so catch that explicitly.
    try:
        response_text = response.text
    except (ValueError, AttributeError):
        response_text = "No response text found."

    if response_text:
        st.session_state['chat_history'].append(
            {"role": "model", "parts": [{"text": response_text}]}
        )

    # Clear the inputs; a new uploader key makes the file widget start empty.
    st.session_state.user_input = ''
    st.session_state.uploaded_files = []
    st.session_state.file_uploader_key = str(uuid.uuid4())
    # The history is intentionally not rendered here: the main script renders
    # it on the rerun that follows this callback, and drawing it from the
    # callback as well would show the conversation twice.
def display_chat_history():
    """Render every stored chat turn in order.

    Text parts are written as markdown prefixed with the title-cased role;
    parts carrying base64 ``data`` are decoded and shown as images. All parts
    of an entry are rendered (not only the first), matching how
    get_chat_history_str walks the history, and an entry with no parts is
    skipped instead of raising IndexError.
    """
    for entry in st.session_state['chat_history']:
        role_label = entry["role"].title()
        for part in entry["parts"]:
            if 'text' in part:
                st.markdown(f"{role_label}: {part['text']}")
            elif 'data' in part:
                image_bytes = base64.b64decode(part['data'])
                st.image(Image.open(io.BytesIO(image_bytes)), caption='Uploaded Image')
def get_chat_history_str():
    """Return the chat history as newline-separated "Role: content" lines.

    Parts that contain text are written out verbatim; any other part is
    represented by the placeholder "(Image)".
    """
    lines = []
    for entry in st.session_state['chat_history']:
        for part in entry['parts']:
            if 'text' in part:
                lines.append(f"{entry['role'].title()}: {part['text']}")
            else:
                lines.append(f"{entry['role'].title()}: (Image)")
    return "\n".join(lines)
# Message box; its value is stored in session state under "user_input".
user_input = st.text_area(label="Enter your message here:", key="user_input")

# Image uploader keyed by the value kept in session state; assigning a new
# key value gives the widget a new identity.
uploaded_files = st.file_uploader(
    label="Upload images:",
    type=["png", "jpg", "jpeg"],
    accept_multiple_files=True,
    key=st.session_state['file_uploader_key'],
)

# Action buttons; the work is done by the on_click callbacks.
send_button = st.button(label="Send", on_click=send_message)
clear_button = st.button(label="Clear Conversation", on_click=clear_conversation)
# Builds the text offered by the "Download Chat" button.
def download_chat_history():
    """Return the chat transcript produced by get_chat_history_str."""
    return get_chat_history_str()
# Offer the current transcript as a plain-text file download.
_transcript_text = download_chat_history()
download_button = st.download_button(
    label="Download Chat",
    data=_transcript_text,
    file_name="chat_history.txt",
    mime="text/plain",
)
# Render the conversation below the controls.
display_chat_history()

# Keep this run's uploads in session state, where the send_message callback
# reads them from the "uploaded_files" entry.
st.session_state['uploaded_files'] = uploaded_files