ziyadsuper2017 commited on
Commit
665ff61
·
1 Parent(s): fb00ecf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -39
app.py CHANGED
@@ -4,10 +4,9 @@ import io
4
  import base64
5
  import uuid
6
 
7
- # Assuming google.generativeai is the correct import based on your description
8
  import google.generativeai as genai
9
 
10
- # Configure the API key (should be set as an environment variable or secure storage in production)
11
  api_key = "AIzaSyC70u1sN87IkoxOoIj4XCAPw97ae2LZwNM" # Replace with your actual API key
12
  genai.configure(api_key=api_key)
13
 
@@ -18,7 +17,6 @@ generation_config = genai.GenerationConfig(
18
 
19
  safety_settings = []
20
 
21
- # Initialize session state for chat history and file uploader key
22
  if 'chat_history' not in st.session_state:
23
  st.session_state['chat_history'] = []
24
  if 'file_uploader_key' not in st.session_state:
@@ -26,33 +24,25 @@ if 'file_uploader_key' not in st.session_state:
26
  if 'use_vision_model' not in st.session_state:
27
  st.session_state['use_vision_model'] = False
28
 
29
- # UI layout
30
  st.title("Gemini Chatbot")
31
 
32
- # Function to convert image to base64
33
  def get_image_base64(image):
34
- image = image.convert("RGB") # Convert to RGB to remove alpha channel if present
35
  buffered = io.BytesIO()
36
  image.save(buffered, format="JPEG")
37
  img_str = base64.b64encode(buffered.getvalue()).decode()
38
  return img_str
39
 
40
- # Function to clear conversation
41
  def clear_conversation():
42
  st.session_state['chat_history'] = []
43
  st.session_state['file_uploader_key'] = str(uuid.uuid4())
44
  st.session_state['use_vision_model'] = False
45
 
46
- # Function to send message and clear input
47
  def send_message():
48
  user_input = st.session_state.user_input
49
  uploaded_files = st.session_state.uploaded_files
50
-
51
- # If images are uploaded, switch to using the vision model
52
  if uploaded_files:
53
  st.session_state['use_vision_model'] = True
54
-
55
- # Prepare the prompt for the vision model
56
  prompts = []
57
  for entry in st.session_state['chat_history']:
58
  for part in entry['parts']:
@@ -60,13 +50,9 @@ def send_message():
60
  prompts.append(part['text'])
61
  elif 'data' in part:
62
  prompts.append("[Image]")
63
-
64
- # If there is text input, add it to the list of prompts
65
  if user_input:
66
  prompts.append(user_input)
67
  st.session_state['chat_history'].append({"role": "user", "parts": [{"text": user_input}]})
68
-
69
- # Add uploaded images to the list of prompts and session state
70
  if uploaded_files:
71
  for uploaded_file in uploaded_files:
72
  base64_image = get_image_base64(Image.open(uploaded_file))
@@ -75,23 +61,14 @@ def send_message():
75
  "role": "user",
76
  "parts": [{"mime_type": uploaded_file.type, "data": base64_image}]
77
  })
78
-
79
- # Determine which model to use based on whether an image has been uploaded
80
  model_name = 'gemini-pro-vision' if st.session_state['use_vision_model'] else 'gemini-pro'
81
-
82
- # Use the appropriate model for interaction
83
  model = genai.GenerativeModel(
84
  model_name=model_name,
85
  generation_config=generation_config,
86
  safety_settings=safety_settings
87
  )
88
-
89
- # Create a single prompt string from the list of prompts
90
  chat_history_str = "\n".join(prompts)
91
-
92
- # Generate content from the chat history or the latest prompt
93
  if st.session_state['use_vision_model']:
94
- # When using vision model, include images as base64 strings
95
  prompt_parts = [{"text": chat_history_str}] + [
96
  {"data": part['data'], "mime_type": "image/jpeg"}
97
  for entry in st.session_state['chat_history'] for part in entry['parts']
@@ -99,23 +76,15 @@ def send_message():
99
  ]
100
  else:
101
  prompt_parts = [{"text": chat_history_str}]
102
-
103
  response = model.generate_content([{"role": "user", "parts": prompt_parts}])
104
  response_text = response.text if hasattr(response, "text") else "No response text found."
105
-
106
- # Display the model response
107
  if response_text:
108
  st.session_state['chat_history'].append({"role": "model", "parts": [{"text": response_text}]})
109
-
110
- # Clear the user input and reset the file uploader widget
111
  st.session_state.user_input = ''
112
  st.session_state.uploaded_files = []
113
  st.session_state.file_uploader_key = str(uuid.uuid4())
114
-
115
- # Display chat history
116
  display_chat_history()
117
 
118
- # Function to display the chat history
119
  def display_chat_history():
120
  for entry in st.session_state['chat_history']:
121
  role = entry["role"]
@@ -123,13 +92,19 @@ def display_chat_history():
123
  if 'text' in parts:
124
  st.markdown(f"{role.title()}: {parts['text']}")
125
  elif 'data' in parts:
126
- # Display the image
127
  st.image(Image.open(io.BytesIO(base64.b64decode(parts['data']))), caption='Uploaded Image')
128
 
129
- # Multiline text input for the user to send messages
 
 
 
 
 
 
 
 
130
  user_input = st.text_area("Enter your message here:", key="user_input")
131
 
132
- # File uploader for images
133
  uploaded_files = st.file_uploader(
134
  "Upload images:",
135
  type=["png", "jpg", "jpeg"],
@@ -137,11 +112,25 @@ uploaded_files = st.file_uploader(
137
  key=st.session_state.file_uploader_key
138
  )
139
 
140
- # Button to send the message
141
  send_button = st.button("Send", on_click=send_message)
142
 
143
- # Button to clear the conversation
144
  clear_button = st.button("Clear Conversation", on_click=clear_conversation)
145
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  # Ensure the file_uploader widget state is tied to the randomly generated key
147
- st.session_state.uploaded_files = uploaded_files
 
4
  import base64
5
  import uuid
6
 
7
+ # Assuming google.generativeai is correctly imported as genai and the API key is set
8
  import google.generativeai as genai
9
 
 
10
  api_key = "AIzaSyC70u1sN87IkoxOoIj4XCAPw97ae2LZwNM" # Replace with your actual API key
11
  genai.configure(api_key=api_key)
12
 
 
17
 
18
  safety_settings = []
19
 
 
20
  if 'chat_history' not in st.session_state:
21
  st.session_state['chat_history'] = []
22
  if 'file_uploader_key' not in st.session_state:
 
24
  if 'use_vision_model' not in st.session_state:
25
  st.session_state['use_vision_model'] = False
26
 
 
27
  st.title("Gemini Chatbot")
28
 
 
29
  def get_image_base64(image):
30
+ image = image.convert("RGB")
31
  buffered = io.BytesIO()
32
  image.save(buffered, format="JPEG")
33
  img_str = base64.b64encode(buffered.getvalue()).decode()
34
  return img_str
35
 
 
36
  def clear_conversation():
37
  st.session_state['chat_history'] = []
38
  st.session_state['file_uploader_key'] = str(uuid.uuid4())
39
  st.session_state['use_vision_model'] = False
40
 
 
41
  def send_message():
42
  user_input = st.session_state.user_input
43
  uploaded_files = st.session_state.uploaded_files
 
 
44
  if uploaded_files:
45
  st.session_state['use_vision_model'] = True
 
 
46
  prompts = []
47
  for entry in st.session_state['chat_history']:
48
  for part in entry['parts']:
 
50
  prompts.append(part['text'])
51
  elif 'data' in part:
52
  prompts.append("[Image]")
 
 
53
  if user_input:
54
  prompts.append(user_input)
55
  st.session_state['chat_history'].append({"role": "user", "parts": [{"text": user_input}]})
 
 
56
  if uploaded_files:
57
  for uploaded_file in uploaded_files:
58
  base64_image = get_image_base64(Image.open(uploaded_file))
 
61
  "role": "user",
62
  "parts": [{"mime_type": uploaded_file.type, "data": base64_image}]
63
  })
 
 
64
  model_name = 'gemini-pro-vision' if st.session_state['use_vision_model'] else 'gemini-pro'
 
 
65
  model = genai.GenerativeModel(
66
  model_name=model_name,
67
  generation_config=generation_config,
68
  safety_settings=safety_settings
69
  )
 
 
70
  chat_history_str = "\n".join(prompts)
 
 
71
  if st.session_state['use_vision_model']:
 
72
  prompt_parts = [{"text": chat_history_str}] + [
73
  {"data": part['data'], "mime_type": "image/jpeg"}
74
  for entry in st.session_state['chat_history'] for part in entry['parts']
 
76
  ]
77
  else:
78
  prompt_parts = [{"text": chat_history_str}]
 
79
  response = model.generate_content([{"role": "user", "parts": prompt_parts}])
80
  response_text = response.text if hasattr(response, "text") else "No response text found."
 
 
81
  if response_text:
82
  st.session_state['chat_history'].append({"role": "model", "parts": [{"text": response_text}]})
 
 
83
  st.session_state.user_input = ''
84
  st.session_state.uploaded_files = []
85
  st.session_state.file_uploader_key = str(uuid.uuid4())
 
 
86
  display_chat_history()
87
 
 
88
  def display_chat_history():
89
  for entry in st.session_state['chat_history']:
90
  role = entry["role"]
 
92
  if 'text' in parts:
93
  st.markdown(f"{role.title()}: {parts['text']}")
94
  elif 'data' in parts:
 
95
  st.image(Image.open(io.BytesIO(base64.b64decode(parts['data']))), caption='Uploaded Image')
96
 
97
+ def get_chat_history_str():
98
+ chat_history_str = "\n".join(
99
+ f"{entry['role'].title()}: {part['text']}" if 'text' in part
100
+ else f"{entry['role'].title()}: (Image)"
101
+ for entry in st.session_state['chat_history']
102
+ for part in entry['parts']
103
+ )
104
+ return chat_history_str
105
+
106
  user_input = st.text_area("Enter your message here:", key="user_input")
107
 
 
108
  uploaded_files = st.file_uploader(
109
  "Upload images:",
110
  type=["png", "jpg", "jpeg"],
 
112
  key=st.session_state.file_uploader_key
113
  )
114
 
 
115
  send_button = st.button("Send", on_click=send_message)
116
 
 
117
  clear_button = st.button("Clear Conversation", on_click=clear_conversation)
118
 
119
+ # Function to download the chat history
120
+ def download_chat_history():
121
+ chat_history_str = get_chat_history_str()
122
+ return chat_history_str
123
+
124
+ # Add a button to download the chat history as a text file
125
+ download_button = st.download_button(
126
+ label="Download Chat",
127
+ data=download_chat_history(),
128
+ file_name="chat_history.txt",
129
+ mime="text/plain"
130
+ )
131
+
132
+ # Display the chat history
133
+ display_chat_history()
134
+
135
  # Ensure the file_uploader widget state is tied to the randomly generated key
136
+ st.session_state.uploaded_files = uploaded_files