Spaces: Runtime error

Commit 665ff61 · Update app.py
Parent(s): fb00ecf

app.py CHANGED
@@ -4,10 +4,9 @@ import io
 import base64
 import uuid
 
-# Assuming google.generativeai is
+# Assuming google.generativeai is correctly imported as genai and the API key is set
 import google.generativeai as genai
 
-# Configure the API key (should be set as an environment variable or secure storage in production)
 api_key = "AIzaSyC70u1sN87IkoxOoIj4XCAPw97ae2LZwNM"  # Replace with your actual API key
 genai.configure(api_key=api_key)
 
@@ -18,7 +17,6 @@ generation_config = genai.GenerationConfig(
 
 safety_settings = []
 
-# Initialize session state for chat history and file uploader key
 if 'chat_history' not in st.session_state:
     st.session_state['chat_history'] = []
 if 'file_uploader_key' not in st.session_state:
@@ -26,33 +24,25 @@ if 'file_uploader_key' not in st.session_state:
 if 'use_vision_model' not in st.session_state:
     st.session_state['use_vision_model'] = False
 
-# UI layout
 st.title("Gemini Chatbot")
 
-# Function to convert image to base64
 def get_image_base64(image):
-    image = image.convert("RGB")
+    image = image.convert("RGB")
     buffered = io.BytesIO()
     image.save(buffered, format="JPEG")
     img_str = base64.b64encode(buffered.getvalue()).decode()
     return img_str
 
-# Function to clear conversation
 def clear_conversation():
     st.session_state['chat_history'] = []
     st.session_state['file_uploader_key'] = str(uuid.uuid4())
     st.session_state['use_vision_model'] = False
 
-# Function to send message and clear input
 def send_message():
     user_input = st.session_state.user_input
     uploaded_files = st.session_state.uploaded_files
-
-    # If images are uploaded, switch to using the vision model
     if uploaded_files:
         st.session_state['use_vision_model'] = True
-
-    # Prepare the prompt for the vision model
     prompts = []
     for entry in st.session_state['chat_history']:
         for part in entry['parts']:
@@ -60,13 +50,9 @@ def send_message():
                 prompts.append(part['text'])
             elif 'data' in part:
                 prompts.append("[Image]")
-
-    # If there is text input, add it to the list of prompts
     if user_input:
         prompts.append(user_input)
         st.session_state['chat_history'].append({"role": "user", "parts": [{"text": user_input}]})
-
-    # Add uploaded images to the list of prompts and session state
     if uploaded_files:
         for uploaded_file in uploaded_files:
             base64_image = get_image_base64(Image.open(uploaded_file))
@@ -75,23 +61,14 @@ def send_message():
                 "role": "user",
                 "parts": [{"mime_type": uploaded_file.type, "data": base64_image}]
             })
-
-    # Determine which model to use based on whether an image has been uploaded
     model_name = 'gemini-pro-vision' if st.session_state['use_vision_model'] else 'gemini-pro'
-
-    # Use the appropriate model for interaction
     model = genai.GenerativeModel(
         model_name=model_name,
         generation_config=generation_config,
         safety_settings=safety_settings
     )
-
-    # Create a single prompt string from the list of prompts
     chat_history_str = "\n".join(prompts)
-
-    # Generate content from the chat history or the latest prompt
     if st.session_state['use_vision_model']:
-        # When using vision model, include images as base64 strings
         prompt_parts = [{"text": chat_history_str}] + [
             {"data": part['data'], "mime_type": "image/jpeg"}
             for entry in st.session_state['chat_history'] for part in entry['parts']
@@ -99,23 +76,15 @@ def send_message():
         ]
     else:
         prompt_parts = [{"text": chat_history_str}]
-
     response = model.generate_content([{"role": "user", "parts": prompt_parts}])
    response_text = response.text if hasattr(response, "text") else "No response text found."
-
-    # Display the model response
    if response_text:
        st.session_state['chat_history'].append({"role": "model", "parts": [{"text": response_text}]})
-
-    # Clear the user input and reset the file uploader widget
    st.session_state.user_input = ''
    st.session_state.uploaded_files = []
    st.session_state.file_uploader_key = str(uuid.uuid4())
-
-    # Display chat history
    display_chat_history()
 
-# Function to display the chat history
 def display_chat_history():
    for entry in st.session_state['chat_history']:
        role = entry["role"]
@@ -123,13 +92,19 @@ def display_chat_history():
        if 'text' in parts:
            st.markdown(f"{role.title()}: {parts['text']}")
        elif 'data' in parts:
-            # Display the image
            st.image(Image.open(io.BytesIO(base64.b64decode(parts['data']))), caption='Uploaded Image')
 
-
+def get_chat_history_str():
+    chat_history_str = "\n".join(
+        f"{entry['role'].title()}: {part['text']}" if 'text' in part
+        else f"{entry['role'].title()}: (Image)"
+        for entry in st.session_state['chat_history']
+        for part in entry['parts']
+    )
+    return chat_history_str
+
 user_input = st.text_area("Enter your message here:", key="user_input")
 
-# File uploader for images
 uploaded_files = st.file_uploader(
     "Upload images:",
     type=["png", "jpg", "jpeg"],
@@ -137,11 +112,25 @@ uploaded_files = st.file_uploader(
     key=st.session_state.file_uploader_key
 )
 
-# Button to send the message
 send_button = st.button("Send", on_click=send_message)
 
-# Button to clear the conversation
 clear_button = st.button("Clear Conversation", on_click=clear_conversation)
 
+# Function to download the chat history
+def download_chat_history():
+    chat_history_str = get_chat_history_str()
+    return chat_history_str
+
+# Add a button to download the chat history as a text file
+download_button = st.download_button(
+    label="Download Chat",
+    data=download_chat_history(),
+    file_name="chat_history.txt",
+    mime="text/plain"
+)
+
+# Display the chat history
+display_chat_history()
+
 # Ensure the file_uploader widget state is tied to the randomly generated key
-st.session_state.uploaded_files = uploaded_files
+st.session_state.uploaded_files = uploaded_files
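
The functional addition in this commit is the "Download Chat" button: the chat history held in st.session_state is flattened to plain text by get_chat_history_str() and handed to st.download_button, which serves it as chat_history.txt. Below is a minimal, self-contained sketch of that pattern, assuming the same chat-history shape as the app (a list of dicts with "role" and "parts") and leaving out the Gemini calls; the seeded messages are purely illustrative.

import streamlit as st

# Seed an illustrative chat history in the same shape the app uses
# (each entry has a role and a list of parts holding either text or image data).
if 'chat_history' not in st.session_state:
    st.session_state['chat_history'] = [
        {"role": "user", "parts": [{"text": "Hello"}]},
        {"role": "model", "parts": [{"text": "Hi! How can I help?"}]},
    ]

def get_chat_history_str():
    # Flatten every part to "Role: text", with a placeholder for image parts.
    return "\n".join(
        f"{entry['role'].title()}: {part['text']}" if 'text' in part
        else f"{entry['role'].title()}: (Image)"
        for entry in st.session_state['chat_history']
        for part in entry['parts']
    )

# st.download_button serves the string as a text file; data= is evaluated on
# each script rerun, so the file reflects the history at the last render.
st.download_button(
    label="Download Chat",
    data=get_chat_history_str(),
    file_name="chat_history.txt",
    mime="text/plain",
)

In the committed code the same value is produced by download_chat_history(), which simply forwards get_chat_history_str() to the data= argument.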