# Gemini chatbot — Streamlit app (exported from a Hugging Face Space).
# NOTE: the original export carried page residue (commit hashes, a line-number
# gutter) above the code; it has been replaced by this comment so the file parses.
# Standard library
import base64
import io
import os
import uuid

# Third-party
import streamlit as st
from PIL import Image
# Assuming google.generativeai is correctly imported as genai and the API key is set
import google.generativeai as genai
# SECURITY FIX: the original file embedded a literal Google API key in source,
# which leaks it to anyone who can read the file.  Read it from the
# environment instead; set GOOGLE_API_KEY before launching the app.
api_key = os.getenv("GOOGLE_API_KEY", "")
genai.configure(api_key=api_key)

# Sampling settings shared by every request this app makes.
generation_config = genai.GenerationConfig(
    temperature=0.9,         # fairly creative responses
    max_output_tokens=3000,
)

# Disable all four content filters; the UI surfaces raw model output.
SAFETY_SETTINGS = [
    {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
    {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
    {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
    {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
]
# The original chained assignment created both names; keep the lowercase
# alias so later code (and any external reader) still resolves.
safety_settings = SAFETY_SETTINGS
# One-time initialisation of the session keys this app relies on.
if 'chat_history' not in st.session_state:
    # list of {"role": ..., "parts": [...]} dicts, oldest first
    st.session_state['chat_history'] = []
if 'file_uploader_key' not in st.session_state:
    # random widget key — regenerating it later force-resets the uploader
    st.session_state['file_uploader_key'] = str(uuid.uuid4())
if 'use_vision_model' not in st.session_state:
    # flips to True (for the session) once any image has been sent
    st.session_state['use_vision_model'] = False

st.title("Gemini Chatbot")
def get_image_base64(image):
    """Serialise a PIL image to a base64-encoded JPEG string."""
    rgb = image.convert("RGB")  # JPEG cannot carry an alpha channel
    buffer = io.BytesIO()
    rgb.save(buffer, format="JPEG")
    return base64.b64encode(buffer.getvalue()).decode()
def clear_conversation():
    """Reset the conversation: wipe history, drop back to the text model,
    and re-key the uploader so Streamlit recreates it empty."""
    state = st.session_state
    state['chat_history'] = []
    state['use_vision_model'] = False
    state['file_uploader_key'] = str(uuid.uuid4())
def send_message():
    """Streamlit ``on_click`` callback: send the current text and/or images
    to Gemini and append the model's reply to the session chat history.

    Side effects: mutates ``st.session_state`` (``chat_history``,
    ``use_vision_model``, ``user_input``, ``uploaded_files``,
    ``file_uploader_key``) and re-renders the transcript.
    """
    user_input = st.session_state.user_input
    uploaded_files = st.session_state.uploaded_files
    # Once any image has been sent, stay on the vision model for the session.
    if uploaded_files:
        st.session_state['use_vision_model'] = True

    # Flatten the existing history into plain-text prompt lines.
    prompts = []
    for entry in st.session_state['chat_history']:
        for part in entry['parts']:
            if 'text' in part:
                prompts.append(part['text'])
            elif 'data' in part:
                prompts.append("[Image]")

    if user_input:
        prompts.append(user_input)
        st.session_state['chat_history'].append(
            {"role": "user", "parts": [{"text": user_input}]}
        )

    if uploaded_files:
        for uploaded_file in uploaded_files:
            base64_image = get_image_base64(Image.open(uploaded_file))
            prompts.append("[Image]")
            # BUGFIX: get_image_base64 always re-encodes to JPEG, so store
            # "image/jpeg" here.  The original stored uploaded_file.type
            # (e.g. "image/png") for bytes that were actually JPEG, while
            # the prompt assembly below already assumed "image/jpeg".
            st.session_state['chat_history'].append({
                "role": "user",
                "parts": [{"mime_type": "image/jpeg", "data": base64_image}],
            })

    model_name = 'gemini-pro-vision' if st.session_state['use_vision_model'] else 'gemini-pro'
    model = genai.GenerativeModel(
        model_name=model_name,
        generation_config=generation_config,
        safety_settings=safety_settings,
    )

    chat_history_str = "\n".join(prompts)
    if st.session_state['use_vision_model']:
        # Vision requests carry the transcript text plus every stored image part.
        prompt_parts = [{"text": chat_history_str}] + [
            {"data": part['data'], "mime_type": part.get('mime_type', 'image/jpeg')}
            for entry in st.session_state['chat_history']
            for part in entry['parts']
            if 'data' in part
        ]
    else:
        prompt_parts = [{"text": chat_history_str}]

    response = model.generate_content([{"role": "user", "parts": prompt_parts}])
    response_text = response.text if hasattr(response, "text") else "No response text found."
    if response_text:
        st.session_state['chat_history'].append(
            {"role": "model", "parts": [{"text": response_text}]}
        )

    # Clear the input widgets; regenerating the uploader key empties it.
    st.session_state.user_input = ''
    st.session_state.uploaded_files = []
    st.session_state.file_uploader_key = str(uuid.uuid4())
    display_chat_history()
def display_chat_history():
    """Render the stored transcript: text parts as markdown, images inline."""
    for entry in st.session_state['chat_history']:
        speaker = entry["role"]
        first_part = entry["parts"][0]  # entries are appended with one part each
        if 'text' in first_part:
            st.markdown(f"{speaker.title()}: {first_part['text']}")
        elif 'data' in first_part:
            raw = base64.b64decode(first_part['data'])
            st.image(Image.open(io.BytesIO(raw)), caption='Uploaded Image')
def get_chat_history_str():
    """Return the transcript as plain text, one line per message part.

    Image parts are rendered as the placeholder ``(Image)``.
    """
    lines = []
    for entry in st.session_state['chat_history']:
        speaker = entry['role'].title()
        for part in entry['parts']:
            if 'text' in part:
                lines.append(f"{speaker}: {part['text']}")
            else:
                lines.append(f"{speaker}: (Image)")
    return "\n".join(lines)
# --- Input widgets ---------------------------------------------------------
user_input = st.text_area("Enter your message here:", key="user_input")

# The uploader's key comes from session state so send/clear can recreate
# (and therefore empty) the widget by swapping in a fresh random key.
uploaded_files = st.file_uploader(
    "Upload images:",
    key=st.session_state.file_uploader_key,
    type=["png", "jpg", "jpeg"],
    accept_multiple_files=True,
)

send_button = st.button("Send", on_click=send_message)
clear_button = st.button("Clear Conversation", on_click=clear_conversation)
def download_chat_history():
    """Provide the transcript text consumed by the download button below."""
    return get_chat_history_str()

# Offer the transcript as a plain-text file download.
download_button = st.download_button(
    label="Download Chat",
    data=download_chat_history(),
    file_name="chat_history.txt",
    mime="text/plain",
)
# Render the transcript on every rerun.
display_chat_history()
# Mirror the uploader's current value into session state so send_message can
# read it.  SYNTAX FIX: the original line ended with a stray "|" (a
# copy/paste artifact from the hosted page) that made the file unparseable.
st.session_state.uploaded_files = uploaded_files