Spaces:
Runtime error
Runtime error
Javi
commited on
Commit
·
081801e
1
Parent(s):
b0cb25e
Image ranking working
Browse files- streamlit_app.py +94 -27
streamlit_app.py
CHANGED
@@ -1,8 +1,8 @@
|
|
|
|
1 |
from typing import Optional, List
|
2 |
|
3 |
-
from PIL import Image
|
4 |
-
import streamlit as st
|
5 |
import booste
|
|
|
6 |
|
7 |
from session_state import SessionState, get_state
|
8 |
|
@@ -10,6 +10,28 @@ from session_state import SessionState, get_state
|
|
10 |
# Do not copy this API key, go to https://www.booste.io/ and get your own, it is free!
|
11 |
BOOSTE_API_KEY = "3818ba84-3526-4029-9dc8-ef3038697ea2"
|
12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
class Sections:
|
15 |
@staticmethod
|
@@ -35,17 +57,30 @@ class Sections:
|
|
35 |
default_image_1 = "https://cdn.pixabay.com/photo/2014/10/13/21/34/clipper-487503_960_720.jpg"
|
36 |
st.image(default_image_1, use_column_width=True)
|
37 |
if st.button("Select image 1"):
|
38 |
-
state.
|
39 |
with col2:
|
40 |
default_image_2 = "https://cdn.pixabay.com/photo/2019/12/17/18/20/peacock-4702197_960_720.jpg"
|
41 |
st.image(default_image_2, use_column_width=True)
|
42 |
if st.button("Select image 2"):
|
43 |
-
state.
|
44 |
with col3:
|
45 |
default_image_3 = "https://cdn.pixabay.com/photo/2016/11/15/16/24/banana-1826760_960_720.jpg"
|
46 |
st.image(default_image_3, use_column_width=True)
|
47 |
if st.button("Select image 3"):
|
48 |
-
state.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
|
50 |
@staticmethod
|
51 |
def prompts_input(state: SessionState, input_label: str, prompt_prefix: str = ''):
|
@@ -55,23 +90,44 @@ class Sections:
|
|
55 |
state.prompt_prefix = prompt_prefix
|
56 |
|
57 |
@staticmethod
|
58 |
-
def
|
59 |
col1, col2 = st.beta_columns([2, 1])
|
60 |
with col1:
|
61 |
st.markdown("Image to classify")
|
62 |
-
if state.
|
63 |
-
st.image(state.
|
64 |
else:
|
65 |
st.warning("Select an image")
|
66 |
|
67 |
with col2:
|
68 |
st.markdown("Labels to choose from")
|
69 |
-
if state.
|
70 |
for prompt in state.prompts:
|
71 |
st.write(prompt[len(state.prompt_prefix):])
|
72 |
else:
|
73 |
st.warning("Enter the classes to classify from")
|
74 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
@staticmethod
|
76 |
def classification_output(state: SessionState):
|
77 |
# Possible way of customize this https://discuss.streamlit.io/t/st-button-in-a-custom-layout/2187/2
|
@@ -79,19 +135,32 @@ class Sections:
|
|
79 |
with st.spinner("Predicting..."):
|
80 |
clip_response = booste.clip(BOOSTE_API_KEY,
|
81 |
prompts=state.prompts,
|
82 |
-
images=
|
83 |
pretty_print=True)
|
84 |
st.markdown("### Results")
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
|
90 |
-
for prompt, probability in simplified_clip_results:
|
91 |
-
percentage_prob = int(probability * 100)
|
92 |
-
st.markdown(
|
93 |
-
f"###      {prompt}")
|
94 |
-
st.write(clip_response)
|
95 |
|
96 |
|
97 |
task_name: str = st.sidebar.radio("Task", options=["Image classification", "Image ranking", "Prompt ranking"])
|
@@ -103,7 +172,7 @@ if task_name == "Image classification":
|
|
103 |
Sections.image_picker(session_state)
|
104 |
input_label = "Enter the classes to chose from separated by a semi-colon. (f.x. `banana; boat; honesty; apple`)"
|
105 |
Sections.prompts_input(session_state, input_label, prompt_prefix='A picture of a ')
|
106 |
-
Sections.
|
107 |
Sections.classification_output(session_state)
|
108 |
elif task_name == "Prompt ranking":
|
109 |
Sections.header()
|
@@ -113,17 +182,15 @@ elif task_name == "Prompt ranking":
|
|
113 |
input_label = "Enter the prompts to choose from separated by a semi-colon. " \
|
114 |
"(f.x. `An image that inspires; A feeling of loneliness; joyful and young; apple`)"
|
115 |
Sections.prompts_input(session_state, input_label)
|
116 |
-
Sections.
|
117 |
Sections.classification_output(session_state)
|
118 |
elif task_name == "Image ranking":
|
119 |
Sections.header()
|
120 |
Sections.image_uploader(accept_multiple_files=True)
|
121 |
st.markdown("or use random dataset")
|
122 |
-
Sections.
|
123 |
-
|
124 |
-
|
|
|
125 |
|
126 |
session_state.sync()
|
127 |
-
|
128 |
-
|
129 |
-
|
|
|
1 |
+
import random
|
2 |
from typing import Optional, List
|
3 |
|
|
|
|
|
4 |
import booste
|
5 |
+
import streamlit as st
|
6 |
|
7 |
from session_state import SessionState, get_state
|
8 |
|
|
|
10 |
# Do not copy this API key, go to https://www.booste.io/ and get your own, it is free!
|
11 |
BOOSTE_API_KEY = "3818ba84-3526-4029-9dc8-ef3038697ea2"
|
12 |
|
13 |
+
IMAGES_LINKS = ["https://cdn.pixabay.com/photo/2014/10/13/21/34/clipper-487503_960_720.jpg",
|
14 |
+
"https://cdn.pixabay.com/photo/2019/09/06/04/25/beach-4455433_960_720.jpg",
|
15 |
+
"https://cdn.pixabay.com/photo/2019/10/19/12/21/hot-air-balloons-4561264_960_720.jpg",
|
16 |
+
"https://cdn.pixabay.com/photo/2019/12/17/18/20/peacock-4702197_960_720.jpg",
|
17 |
+
"https://cdn.pixabay.com/photo/2016/11/15/16/24/banana-1826760_960_720.jpg",
|
18 |
+
"https://cdn.pixabay.com/photo/2020/12/28/22/48/buddha-5868759_960_720.jpg",
|
19 |
+
"https://cdn.pixabay.com/photo/2019/11/11/14/30/zebra-4618513_960_720.jpg",
|
20 |
+
"https://cdn.pixabay.com/photo/2020/11/04/15/29/coffee-beans-5712780_960_720.jpg",
|
21 |
+
"https://cdn.pixabay.com/photo/2020/03/24/20/42/namibia-4965457_960_720.jpg",
|
22 |
+
"https://cdn.pixabay.com/photo/2020/08/27/07/31/restaurant-5521372_960_720.jpg",
|
23 |
+
"https://cdn.pixabay.com/photo/2020/08/28/06/13/building-5523630_960_720.jpg",
|
24 |
+
"https://cdn.pixabay.com/photo/2020/08/24/21/41/couple-5515141_960_720.jpg",
|
25 |
+
"https://cdn.pixabay.com/photo/2020/01/31/07/10/billboards-4807268_960_720.jpg",
|
26 |
+
"https://cdn.pixabay.com/photo/2017/07/31/20/48/shell-2560930_960_720.jpg",
|
27 |
+
"https://cdn.pixabay.com/photo/2020/08/13/01/29/koala-5483931_960_720.jpg",
|
28 |
+
"https://cdn.pixabay.com/photo/2016/11/29/04/52/architecture-1867411_960_720.jpg",
|
29 |
+
]
|
30 |
+
|
31 |
+
@st.cache
|
32 |
+
def select_random_dataset():
|
33 |
+
return random.sample(IMAGES_LINKS, 10)
|
34 |
+
|
35 |
|
36 |
class Sections:
|
37 |
@staticmethod
|
|
|
57 |
default_image_1 = "https://cdn.pixabay.com/photo/2014/10/13/21/34/clipper-487503_960_720.jpg"
|
58 |
st.image(default_image_1, use_column_width=True)
|
59 |
if st.button("Select image 1"):
|
60 |
+
state.images = [default_image_1]
|
61 |
with col2:
|
62 |
default_image_2 = "https://cdn.pixabay.com/photo/2019/12/17/18/20/peacock-4702197_960_720.jpg"
|
63 |
st.image(default_image_2, use_column_width=True)
|
64 |
if st.button("Select image 2"):
|
65 |
+
state.images = [default_image_2]
|
66 |
with col3:
|
67 |
default_image_3 = "https://cdn.pixabay.com/photo/2016/11/15/16/24/banana-1826760_960_720.jpg"
|
68 |
st.image(default_image_3, use_column_width=True)
|
69 |
if st.button("Select image 3"):
|
70 |
+
state.images = [default_image_3]
|
71 |
+
|
72 |
+
@staticmethod
|
73 |
+
def dataset_picker(state: SessionState):
|
74 |
+
columns = st.beta_columns(5)
|
75 |
+
state.dataset = select_random_dataset()
|
76 |
+
image_idx = 0
|
77 |
+
for col in columns:
|
78 |
+
col.image(state.dataset[image_idx])
|
79 |
+
image_idx += 1
|
80 |
+
col.image(state.dataset[image_idx])
|
81 |
+
image_idx += 1
|
82 |
+
if st.button("Select random dataset"):
|
83 |
+
state.images = state.dataset
|
84 |
|
85 |
@staticmethod
|
86 |
def prompts_input(state: SessionState, input_label: str, prompt_prefix: str = ''):
|
|
|
90 |
state.prompt_prefix = prompt_prefix
|
91 |
|
92 |
@staticmethod
|
93 |
+
def single_image_input_preview(state: SessionState):
|
94 |
col1, col2 = st.beta_columns([2, 1])
|
95 |
with col1:
|
96 |
st.markdown("Image to classify")
|
97 |
+
if state.images is not None:
|
98 |
+
st.image(state.images[0], use_column_width=True)
|
99 |
else:
|
100 |
st.warning("Select an image")
|
101 |
|
102 |
with col2:
|
103 |
st.markdown("Labels to choose from")
|
104 |
+
if state.prompts is not None:
|
105 |
for prompt in state.prompts:
|
106 |
st.write(prompt[len(state.prompt_prefix):])
|
107 |
else:
|
108 |
st.warning("Enter the classes to classify from")
|
109 |
|
110 |
+
@staticmethod
|
111 |
+
def multiple_images_input_preview(state: SessionState):
|
112 |
+
st.markdown("Images to classify")
|
113 |
+
col1, col2, col3 = st.beta_columns(3)
|
114 |
+
if state.images is not None:
|
115 |
+
for idx, image in enumerate(state.images):
|
116 |
+
if idx < len(state.images) / 2:
|
117 |
+
col1.image(state.images[idx], use_column_width=True)
|
118 |
+
else:
|
119 |
+
col2.image(state.images[idx], use_column_width=True)
|
120 |
+
else:
|
121 |
+
col1.warning("Select an image")
|
122 |
+
|
123 |
+
with col3:
|
124 |
+
st.markdown("Query prompt")
|
125 |
+
if state.prompts is not None:
|
126 |
+
for prompt in state.prompts:
|
127 |
+
st.write(prompt[len(state.prompt_prefix):])
|
128 |
+
else:
|
129 |
+
st.warning("Enter the prompt to classify")
|
130 |
+
|
131 |
@staticmethod
|
132 |
def classification_output(state: SessionState):
|
133 |
# Possible way of customize this https://discuss.streamlit.io/t/st-button-in-a-custom-layout/2187/2
|
|
|
135 |
with st.spinner("Predicting..."):
|
136 |
clip_response = booste.clip(BOOSTE_API_KEY,
|
137 |
prompts=state.prompts,
|
138 |
+
images=state.images,
|
139 |
pretty_print=True)
|
140 |
st.markdown("### Results")
|
141 |
+
# st.write(clip_response)
|
142 |
+
if len(state.images) == 1:
|
143 |
+
simplified_clip_results = [(prompt[len(state.prompt_prefix):],
|
144 |
+
list(results.values())[0]["probabilityRelativeToPrompts"])
|
145 |
+
for prompt, results in clip_response.items()]
|
146 |
+
simplified_clip_results = sorted(simplified_clip_results, key=lambda x: x[1], reverse=True)
|
147 |
+
|
148 |
+
for prompt, probability in simplified_clip_results:
|
149 |
+
percentage_prob = int(probability * 100)
|
150 |
+
st.markdown(
|
151 |
+
f"###      {prompt}")
|
152 |
+
else:
|
153 |
+
st.markdown(f"### {state.prompts[0]}")
|
154 |
+
assert len(state.prompts) == 1
|
155 |
+
simplified_clip_results = [(image, results["probabilityRelativeToImages"]) for image, results
|
156 |
+
in list(clip_response.values())[0].items()]
|
157 |
+
simplified_clip_results = sorted(simplified_clip_results, key=lambda x: x[1], reverse=True)
|
158 |
+
for image, probability in simplified_clip_results[:5]:
|
159 |
+
col1, col2 = st.beta_columns([1, 3])
|
160 |
+
col1.image(image, use_column_width=True)
|
161 |
+
percentage_prob = int(probability * 100)
|
162 |
+
col2.markdown(f"### ")
|
163 |
|
|
|
|
|
|
|
|
|
|
|
164 |
|
165 |
|
166 |
task_name: str = st.sidebar.radio("Task", options=["Image classification", "Image ranking", "Prompt ranking"])
|
|
|
172 |
Sections.image_picker(session_state)
|
173 |
input_label = "Enter the classes to chose from separated by a semi-colon. (f.x. `banana; boat; honesty; apple`)"
|
174 |
Sections.prompts_input(session_state, input_label, prompt_prefix='A picture of a ')
|
175 |
+
Sections.single_image_input_preview(session_state)
|
176 |
Sections.classification_output(session_state)
|
177 |
elif task_name == "Prompt ranking":
|
178 |
Sections.header()
|
|
|
182 |
input_label = "Enter the prompts to choose from separated by a semi-colon. " \
|
183 |
"(f.x. `An image that inspires; A feeling of loneliness; joyful and young; apple`)"
|
184 |
Sections.prompts_input(session_state, input_label)
|
185 |
+
Sections.single_image_input_preview(session_state)
|
186 |
Sections.classification_output(session_state)
|
187 |
elif task_name == "Image ranking":
|
188 |
Sections.header()
|
189 |
Sections.image_uploader(accept_multiple_files=True)
|
190 |
st.markdown("or use random dataset")
|
191 |
+
Sections.dataset_picker(session_state)
|
192 |
+
Sections.prompts_input(session_state, "Enter the prompt to query the images by")
|
193 |
+
Sections.multiple_images_input_preview(session_state)
|
194 |
+
Sections.classification_output(session_state)
|
195 |
|
196 |
session_state.sync()
|
|
|
|
|
|