ziyadsuper2017 committed
Commit 221a628 · 1 Parent(s): 6335d32

Update app.py

Files changed (1)
  1. app.py +257 -37
app.py CHANGED
@@ -1,50 +1,270 @@
- import streamlit as st
  import google.generativeai as genai
- from streamlit import file_uploader

- # Generative AI setup
- api_key = "AIzaSyC70u1sN87IkoxOoIj4XCAPw97ae2LZwNM"
- genai.configure(api_key=api_key)

- generation_config = {
-     "temperature": 0.9,
-     "max_output_tokens": 3000
- }

- safety_settings = []

  # Streamlit UI
- st.title("Chatbot")

- # Use text_input for text input by typing
- user_text = st.text_input("Type your text here:")

- # Use file_uploader for image input
- user_image = st.file_uploader("Upload an image file here", type=["png", "jpg", "jpeg"])

- # Check if the user has entered text or uploaded an image
- if user_text or user_image:
-     # Create the prompt parts accordingly
      if user_text:
-         prompt_parts = [user_text]
-         model_name = "gemini-pro"  # Use the text-only model
-     else:
-         prompt_parts = [{
-             "mime_type": user_image.type,
-             "data": user_image.read()
-         }]
-         model_name = "gemini-pro-vision"  # Use the multimodal model
-
-     # Model code
-     model = genai.GenerativeModel(
-         model_name=model_name,
-         generation_config=generation_config,
-         safety_settings=safety_settings
-     )
-
-     response = model.generate_content(prompt_parts)
-     response_text = response.text

      # Display the user input and the model response
-     st.markdown(f"**User:** {prompt_parts[0]['text']}")
      st.markdown(f"**Model:** {response_text}")

+ import os
+ import time
+ import uuid
+ from typing import List, Tuple, Optional, Dict, Union
+
  import google.generativeai as genai
+ import streamlit as st
+ from PIL import Image
+
+ print("google-generativeai:", genai.__version__)
+
+ GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
+
+ TITLE = """<h1 align="center">Gemini Playground 💬</h1>"""
+ SUBTITLE = """<h2 align="center">Play with Gemini Pro and Gemini Pro Vision API</h2>"""
+ DUPLICATE = """
+ <div style="text-align: center; display: flex; justify-content: center; align-items: center;">
+     <a href="https://huggingface.co/spaces/SkalskiP/ChatGemini?duplicate=true">
+         <img src="https://bit.ly/3gLdBN6" alt="Duplicate Space" style="margin-right: 10px;">
+     </a>
+     <span>Duplicate the Space and run securely with your
+         <a href="https://makersuite.google.com/app/apikey">GOOGLE API KEY</a>.
+     </span>
+ </div>
+ """
+
+ AVATAR_IMAGES = (
+     None,
+     "https://media.roboflow.com/spaces/gemini-icon.png"
+ )
+
+ IMAGE_CACHE_DIRECTORY = "/tmp"
+ IMAGE_WIDTH = 512
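+ # One CHAT_HISTORY entry pairs a user turn with a model turn; an image turn is
+ # stored as a 1-tuple holding the cached file path.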
+ CHAT_HISTORY = List[Tuple[Optional[Union[Tuple[str], str]], Optional[str]]]
+
+
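+ # Split the comma-separated sidebar value into a list, e.g. "STOP, END" -> ["STOP", "END"].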
+ def preprocess_stop_sequences(stop_sequences: str) -> Optional[List[str]]:
+     if not stop_sequences:
+         return None
+     return [sequence.strip() for sequence in stop_sequences.split(",")]
+
+
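+ # Resize uploads to a fixed 512 px width, preserving the aspect ratio.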
+ def preprocess_image(image: Image.Image) -> Optional[Image.Image]:
+     image_height = int(image.height * IMAGE_WIDTH / image.width)
+     return image.resize((IMAGE_WIDTH, image_height))
+
+
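+ # Persist an image as JPEG under IMAGE_CACHE_DIRECTORY with a UUID filename and return the path.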
+ def cache_pil_image(image: Image.Image) -> str:
+     image_filename = f"{uuid.uuid4()}.jpeg"
+     os.makedirs(IMAGE_CACHE_DIRECTORY, exist_ok=True)
+     image_path = os.path.join(IMAGE_CACHE_DIRECTORY, image_filename)
+     image.save(image_path, "JPEG")
+     return image_path
+
+
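+ # Flatten the (user, model) history into genai's messages format; image tuples
+ # are skipped because gemini-pro accepts text only.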
+ def preprocess_chat_history(
+     history: CHAT_HISTORY
+ ) -> List[Dict[str, Union[str, List[str]]]]:
+     messages = []
+     for user_message, model_message in history:
+         if isinstance(user_message, tuple):
+             pass
+         elif user_message is not None:
+             messages.append({'role': 'user', 'parts': [user_message]})
+         if model_message is not None:
+             messages.append({'role': 'model', 'parts': [model_message]})
+     return messages
+

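+ # Resize, cache, and append each uploaded image to the chat history as a (path,) tuple.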
+ def upload(files: Optional[List[str]], chatbot: CHAT_HISTORY) -> CHAT_HISTORY:
+     for file in files:
+         image = Image.open(file).convert('RGB')
+         image = preprocess_image(image)
+         image_path = cache_pil_image(image)
+         chatbot.append(((image_path,), None))
+     return chatbot


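+ # Append a typed prompt to the history; returns ("", history) so a caller can clear its input box.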
+ def user(text_prompt: str, chatbot: CHAT_HISTORY):
+     if text_prompt:
+         chatbot.append((text_prompt, None))
+     return "", chatbot
+
+
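+ # Stream a reply into the last history entry: gemini-pro-vision when image
+ # files are passed, gemini-pro for text chat.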
+ def bot(
+     google_key: str,
+     files: Optional[List[str]],
+     temperature: float,
+     max_output_tokens: int,
+     stop_sequences: str,
+     top_k: int,
+     top_p: float,
+     chatbot: CHAT_HISTORY
+ ):
+     if len(chatbot) == 0:
+         return chatbot
+
+     google_key = google_key if google_key else GOOGLE_API_KEY
+     if not google_key:
+         raise ValueError(
+             "GOOGLE_API_KEY is not set. "
+             "Please follow the instructions in the README to set it up.")
+
+     genai.configure(api_key=google_key)
+     generation_config = genai.types.GenerationConfig(
+         temperature=temperature,
+         max_output_tokens=max_output_tokens,
+         stop_sequences=preprocess_stop_sequences(stop_sequences=stop_sequences),
+         top_k=top_k,
+         top_p=top_p)
+
+     if files:
+         text_prompt = [chatbot[-1][0]] \
+             if chatbot[-1][0] and isinstance(chatbot[-1][0], str) \
+             else []
+         image_prompt = [Image.open(file).convert('RGB') for file in files]
+         model = genai.GenerativeModel('gemini-pro-vision')
+         response = model.generate_content(
+             text_prompt + image_prompt,
+             stream=True,
+             generation_config=generation_config)
+     else:
+         messages = preprocess_chat_history(chatbot)
+         model = genai.GenerativeModel('gemini-pro')
+         response = model.generate_content(
+             messages,
+             stream=True,
+             generation_config=generation_config)
+
+     # Streaming effect: reveal each chunk 10 characters at a time. History
+     # entries are tuples, so rebuild the last entry rather than mutating it
+     # in place (item assignment on a tuple would raise a TypeError).
+     chatbot[-1] = (chatbot[-1][0], "")
+     for chunk in response:
+         for i in range(0, len(chunk.text), 10):
+             section = chunk.text[i:i + 10]
+             chatbot[-1] = (chatbot[-1][0], chatbot[-1][1] + section)
+             time.sleep(0.01)
+             yield chatbot
+

  # Streamlit UI
+ st.markdown(TITLE, unsafe_allow_html=True)
+ st.markdown(SUBTITLE, unsafe_allow_html=True)
+ st.markdown(DUPLICATE, unsafe_allow_html=True)
+
+ # Sidebar for parameters
+ st.sidebar.header("Parameters")
+ google_key = st.sidebar.text_input(
+     label="GOOGLE API KEY",
+     value="",
+     type="password",
+     help="You have to provide your own GOOGLE_API_KEY for this app to function properly",
+     key="google_key"
+ )
+ temperature = st.sidebar.slider(
+     label="Temperature",
+     min_value=0.0,
+     max_value=1.0,
+     value=0.4,
+     step=0.05,
+     help=(
+         "Temperature controls the degree of randomness in token selection. Lower "
+         "temperatures are good for prompts that expect a true or correct response, "
+         "while higher temperatures can lead to more diverse or unexpected results."
+     ),
+     key="temperature"
+ )
+ max_output_tokens = st.sidebar.slider(
+     label="Token limit",
+     min_value=1,
+     max_value=2048,
+     value=1024,
+     step=1,
+     help=(
+         "Token limit determines the maximum amount of text output from one prompt. A "
+         "token is approximately four characters. The default value is 1024."
+     ),
+     key="max_output_tokens"
+ )
+ stop_sequences = st.sidebar.text_input(
+     label="Add stop sequence",
+     value="",
+     help=(
+         "A stop sequence is a series of characters (including spaces) that stops "
+         "response generation if the model encounters it. The sequence is not included "
+         "as part of the response. You can add up to five stop sequences."
+     ),
+     key="stop_sequences"
+ )
+ top_k = st.sidebar.slider(
+     label="Top-K",
+     min_value=1,
+     max_value=40,
+     value=32,
+     step=1,
+     help=(
+         "Top-k changes how the model selects tokens for output. A top-k of 1 means the "
+         "selected token is the most probable among all tokens in the model’s "
+         "vocabulary (also called greedy decoding), while a top-k of 3 means that the "
+         "next token is selected from among the 3 most probable tokens (using "
+         "temperature)."
+     ),
+     key="top_k"
+ )
+ top_p = st.sidebar.slider(
+     label="Top-P",
+     min_value=0.0,
+     max_value=1.0,
+     value=1.0,
+     step=0.01,
+     help=(
+         "Top-p changes how the model selects tokens for output. Tokens are selected "
+         "from most probable to least until the sum of their probabilities equals the "
+         "top-p value. For example, if tokens A, B, and C have a probability of .3, .2, "
+         "and .1 and the top-p value is .5, then the model will select either A or B as "
+         "the next token (using temperature)."
+     ),
+     key="top_p"
+ )
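+ # Worked example for the sampling knobs above: with candidate tokens A, B, C at
+ # probabilities .3, .2, .1, top_p=.5 keeps A and B (.3 + .2 = .5); top_k first
+ # caps the candidate list, and temperature then re-weights before sampling.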

+ # Main area for chatbot
+ st.header("Chatbot")
+ chatbot = st.session_state.get("chatbot", [])
+ if len(chatbot) % 2 == 0:
+     role = "user"
+ else:
+     role = "model"
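+ # NOTE: role is derived from the history length but is not used below.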

+ for user_message, model_message in chatbot:
+     if isinstance(user_message, tuple):
+         st.image(user_message[0], use_column_width=True)
+     elif user_message is not None:
+         st.markdown(f"**User:** {user_message}")
+     if model_message is not None:
+         st.markdown(f"**Model:** {model_message}")

+ # Text input for user message
+ user_text = st.text_input("Type your text here:", key="user_text")
+
+ # File uploader for user image
+ user_image = st.file_uploader("Upload an image file here", type=["png", "jpg", "jpeg"], key="user_image")
+
+ # Button for running the bot
+ run_button = st.button("Run", key="run_button")
+
+ # Logic for handling user input and bot response
+ if run_button or user_text or user_image:
+     # Append user input to chatbot history
      if user_text:
+         chatbot.append((user_text, None))
+     elif user_image:
+         image = Image.open(user_image).convert('RGB')
+         image = preprocess_image(image)
+         image_path = cache_pil_image(image)
+         chatbot.append(((image_path,), None))

+     # Call the bot function with parameters and chatbot history. bot() is a
+     # generator, so drain it to write the streamed reply into the history.
+     bot_response = bot(
+         google_key=google_key,
+         files=None,  # the cached upload is shown in the history; generation here uses the text model
+         temperature=temperature,
+         max_output_tokens=max_output_tokens,
+         stop_sequences=stop_sequences,
+         top_k=top_k,
+         top_p=top_p,
+         chatbot=chatbot
+     )
+     for chatbot in bot_response:
+         pass
+     response_text = chatbot[-1][1]
+
      # Display the user input and the model response
+     if user_text:
+         st.markdown(f"**User:** {user_text}")
+     elif user_image:
+         st.image(user_image, use_column_width=True)
      st.markdown(f"**Model:** {response_text}")
+
+     # Update the chatbot history with the model response
+     chatbot[-1] = (chatbot[-1][0], response_text)
+     st.session_state["chatbot"] = chatbot