Spaces:
Runtime error
Runtime error
Javi
commited on
Commit
·
af5047d
1
Parent(s):
081801e
Introduced file uploader hack
Browse files- streamlit_app.py +83 -22
streamlit_app.py
CHANGED
@@ -1,9 +1,40 @@
|
|
1 |
import random
|
2 |
from typing import Optional, List
|
|
|
3 |
|
4 |
-
import booste
|
5 |
import streamlit as st
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
|
|
|
|
|
|
|
|
7 |
from session_state import SessionState, get_state
|
8 |
|
9 |
# Unfortunately Streamlit sharing does not allow to hide enviroment variables yet.
|
@@ -28,7 +59,7 @@ IMAGES_LINKS = ["https://cdn.pixabay.com/photo/2014/10/13/21/34/clipper-487503_9
|
|
28 |
"https://cdn.pixabay.com/photo/2016/11/29/04/52/architecture-1867411_960_720.jpg",
|
29 |
]
|
30 |
|
31 |
-
@st.cache
|
32 |
def select_random_dataset():
|
33 |
return random.sample(IMAGES_LINKS, 10)
|
34 |
|
@@ -46,9 +77,17 @@ class Sections:
|
|
46 |
st.markdown(" ")
|
47 |
|
48 |
@staticmethod
|
49 |
-
def image_uploader(accept_multiple_files: bool)
|
50 |
-
|
51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
|
53 |
@staticmethod
|
54 |
def image_picker(state: SessionState):
|
@@ -117,9 +156,12 @@ class Sections:
|
|
117 |
col1.image(state.images[idx], use_column_width=True)
|
118 |
else:
|
119 |
col2.image(state.images[idx], use_column_width=True)
|
|
|
|
|
120 |
else:
|
121 |
col1.warning("Select an image")
|
122 |
|
|
|
123 |
with col3:
|
124 |
st.markdown("Query prompt")
|
125 |
if state.prompts is not None:
|
@@ -133,10 +175,19 @@ class Sections:
|
|
133 |
# Possible way of customize this https://discuss.streamlit.io/t/st-button-in-a-custom-layout/2187/2
|
134 |
if st.button("Predict"):
|
135 |
with st.spinner("Predicting..."):
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
140 |
st.markdown("### Results")
|
141 |
# st.write(clip_response)
|
142 |
if len(state.images) == 1:
|
@@ -152,8 +203,13 @@ class Sections:
|
|
152 |
else:
|
153 |
st.markdown(f"### {state.prompts[0]}")
|
154 |
assert len(state.prompts) == 1
|
155 |
-
|
156 |
-
|
|
|
|
|
|
|
|
|
|
|
157 |
simplified_clip_results = sorted(simplified_clip_results, key=lambda x: x[1], reverse=True)
|
158 |
for image, probability in simplified_clip_results[:5]:
|
159 |
col1, col2 = st.beta_columns([1, 3])
|
@@ -162,35 +218,40 @@ class Sections:
|
|
162 |
col2.markdown(f"### ")
|
163 |
|
164 |
|
165 |
-
|
166 |
task_name: str = st.sidebar.radio("Task", options=["Image classification", "Image ranking", "Prompt ranking"])
|
167 |
-
session_state = get_state()
|
168 |
if task_name == "Image classification":
|
|
|
169 |
Sections.header()
|
170 |
-
Sections.image_uploader(accept_multiple_files=False)
|
171 |
-
|
172 |
-
|
|
|
173 |
input_label = "Enter the classes to chose from separated by a semi-colon. (f.x. `banana; boat; honesty; apple`)"
|
174 |
Sections.prompts_input(session_state, input_label, prompt_prefix='A picture of a ')
|
175 |
Sections.single_image_input_preview(session_state)
|
176 |
Sections.classification_output(session_state)
|
177 |
elif task_name == "Prompt ranking":
|
|
|
178 |
Sections.header()
|
179 |
-
Sections.image_uploader(accept_multiple_files=False)
|
180 |
-
|
181 |
-
|
|
|
182 |
input_label = "Enter the prompts to choose from separated by a semi-colon. " \
|
183 |
"(f.x. `An image that inspires; A feeling of loneliness; joyful and young; apple`)"
|
184 |
Sections.prompts_input(session_state, input_label)
|
185 |
Sections.single_image_input_preview(session_state)
|
186 |
Sections.classification_output(session_state)
|
187 |
elif task_name == "Image ranking":
|
|
|
188 |
Sections.header()
|
189 |
-
Sections.image_uploader(accept_multiple_files=True)
|
190 |
-
|
191 |
-
|
|
|
192 |
Sections.prompts_input(session_state, "Enter the prompt to query the images by")
|
193 |
Sections.multiple_images_input_preview(session_state)
|
194 |
Sections.classification_output(session_state)
|
|
|
195 |
|
196 |
session_state.sync()
|
|
|
1 |
import random
|
2 |
from typing import Optional, List
|
3 |
+
import uuid
|
4 |
|
|
|
5 |
import streamlit as st
|
6 |
+
from mock import patch
|
7 |
+
|
8 |
+
class ImagesMocker:
|
9 |
+
"""HACK ALERT: I needed a way to call the booste API without storing the images first
|
10 |
+
(as that is not allowed in streamlit sharing). If you have a better idea on hwo to this let me know!"""
|
11 |
+
|
12 |
+
def __init__(self):
|
13 |
+
self.pil_patch = patch('PIL.Image.open', lambda x: self.image_id2image(x))
|
14 |
+
self.path_patch = patch('os.path.exists', lambda x: True)
|
15 |
+
self.image_id2image_lookup = {}
|
16 |
+
|
17 |
+
def start_mocking(self):
|
18 |
+
self.pil_patch.start()
|
19 |
+
self.path_patch.start()
|
20 |
+
|
21 |
+
def stop_mocking(self):
|
22 |
+
self.pil_patch.stop()
|
23 |
+
self.path_patch.stop()
|
24 |
+
|
25 |
+
def image_id2image(self, image_id: str):
|
26 |
+
return self.image_id2image_lookup[image_id]
|
27 |
+
|
28 |
+
def calculate_image_id2image_lookup(self, images: List):
|
29 |
+
self.image_id2image_lookup = {str(uuid.uuid4()) + ".png": image for image in images}
|
30 |
+
@property
|
31 |
+
def image_ids(self):
|
32 |
+
return list(self.image_id2image_lookup.keys())
|
33 |
|
34 |
+
images_mocker = ImagesMocker()
|
35 |
+
import booste
|
36 |
+
|
37 |
+
from PIL import Image
|
38 |
from session_state import SessionState, get_state
|
39 |
|
40 |
# Unfortunately Streamlit sharing does not allow to hide enviroment variables yet.
|
|
|
59 |
"https://cdn.pixabay.com/photo/2016/11/29/04/52/architecture-1867411_960_720.jpg",
|
60 |
]
|
61 |
|
62 |
+
@st.cache # Cache this so that it doesn't change every time something changes in the page
|
63 |
def select_random_dataset():
|
64 |
return random.sample(IMAGES_LINKS, 10)
|
65 |
|
|
|
77 |
st.markdown(" ")
|
78 |
|
79 |
@staticmethod
|
80 |
+
def image_uploader(state: SessionState, accept_multiple_files: bool):
|
81 |
+
uploaded_images = st.file_uploader("Upload image", type=[".png", ".jpg", ".jpeg"],
|
82 |
+
accept_multiple_files=accept_multiple_files)
|
83 |
+
if uploaded_images is not None or (accept_multiple_files and len(uploaded_images) > 1):
|
84 |
+
images = []
|
85 |
+
if not accept_multiple_files:
|
86 |
+
uploaded_images = [uploaded_images]
|
87 |
+
for uploaded_image in uploaded_images:
|
88 |
+
images.append(Image.open(uploaded_image))
|
89 |
+
state.images = images
|
90 |
+
|
91 |
|
92 |
@staticmethod
|
93 |
def image_picker(state: SessionState):
|
|
|
156 |
col1.image(state.images[idx], use_column_width=True)
|
157 |
else:
|
158 |
col2.image(state.images[idx], use_column_width=True)
|
159 |
+
if len(state.images) < 2:
|
160 |
+
col2.warning("At least 2 images required")
|
161 |
else:
|
162 |
col1.warning("Select an image")
|
163 |
|
164 |
+
|
165 |
with col3:
|
166 |
st.markdown("Query prompt")
|
167 |
if state.prompts is not None:
|
|
|
175 |
# Possible way of customize this https://discuss.streamlit.io/t/st-button-in-a-custom-layout/2187/2
|
176 |
if st.button("Predict"):
|
177 |
with st.spinner("Predicting..."):
|
178 |
+
if isinstance(state.images[0], str):
|
179 |
+
print("Regular call!")
|
180 |
+
clip_response = booste.clip(BOOSTE_API_KEY,
|
181 |
+
prompts=state.prompts,
|
182 |
+
images=state.images)
|
183 |
+
else:
|
184 |
+
print("Hacky call!")
|
185 |
+
images_mocker.calculate_image_id2image_lookup(state.images)
|
186 |
+
images_mocker.start_mocking()
|
187 |
+
clip_response = booste.clip(BOOSTE_API_KEY,
|
188 |
+
prompts=state.prompts,
|
189 |
+
images=images_mocker.image_ids)
|
190 |
+
images_mocker.stop_mocking()
|
191 |
st.markdown("### Results")
|
192 |
# st.write(clip_response)
|
193 |
if len(state.images) == 1:
|
|
|
203 |
else:
|
204 |
st.markdown(f"### {state.prompts[0]}")
|
205 |
assert len(state.prompts) == 1
|
206 |
+
if isinstance(state.images[0], str):
|
207 |
+
simplified_clip_results = [(image, results["probabilityRelativeToImages"]) for image, results
|
208 |
+
in list(clip_response.values())[0].items()]
|
209 |
+
else:
|
210 |
+
simplified_clip_results = [(images_mocker.image_id2image(image),
|
211 |
+
results["probabilityRelativeToImages"]) for image, results
|
212 |
+
in list(clip_response.values())[0].items()]
|
213 |
simplified_clip_results = sorted(simplified_clip_results, key=lambda x: x[1], reverse=True)
|
214 |
for image, probability in simplified_clip_results[:5]:
|
215 |
col1, col2 = st.beta_columns([1, 3])
|
|
|
218 |
col2.markdown(f"### ")
|
219 |
|
220 |
|
|
|
221 |
task_name: str = st.sidebar.radio("Task", options=["Image classification", "Image ranking", "Prompt ranking"])
|
|
|
222 |
if task_name == "Image classification":
|
223 |
+
session_state = get_state()
|
224 |
Sections.header()
|
225 |
+
Sections.image_uploader(session_state, accept_multiple_files=False)
|
226 |
+
if session_state.images is None:
|
227 |
+
st.markdown("or choose one from")
|
228 |
+
Sections.image_picker(session_state)
|
229 |
input_label = "Enter the classes to chose from separated by a semi-colon. (f.x. `banana; boat; honesty; apple`)"
|
230 |
Sections.prompts_input(session_state, input_label, prompt_prefix='A picture of a ')
|
231 |
Sections.single_image_input_preview(session_state)
|
232 |
Sections.classification_output(session_state)
|
233 |
elif task_name == "Prompt ranking":
|
234 |
+
session_state = get_state()
|
235 |
Sections.header()
|
236 |
+
Sections.image_uploader(session_state, accept_multiple_files=False)
|
237 |
+
if session_state.images is None:
|
238 |
+
st.markdown("or choose one from")
|
239 |
+
Sections.image_picker(session_state)
|
240 |
input_label = "Enter the prompts to choose from separated by a semi-colon. " \
|
241 |
"(f.x. `An image that inspires; A feeling of loneliness; joyful and young; apple`)"
|
242 |
Sections.prompts_input(session_state, input_label)
|
243 |
Sections.single_image_input_preview(session_state)
|
244 |
Sections.classification_output(session_state)
|
245 |
elif task_name == "Image ranking":
|
246 |
+
session_state = get_state()
|
247 |
Sections.header()
|
248 |
+
Sections.image_uploader(session_state, accept_multiple_files=True)
|
249 |
+
if session_state.images is None:
|
250 |
+
st.markdown("or use this random dataset")
|
251 |
+
Sections.dataset_picker(session_state)
|
252 |
Sections.prompts_input(session_state, "Enter the prompt to query the images by")
|
253 |
Sections.multiple_images_input_preview(session_state)
|
254 |
Sections.classification_output(session_state)
|
255 |
+
print(session_state.images)
|
256 |
|
257 |
session_state.sync()
|