Javi commited on
Commit
081801e
·
1 Parent(s): b0cb25e

Image ranking working

Browse files
Files changed (1) hide show
  1. streamlit_app.py +94 -27
streamlit_app.py CHANGED
@@ -1,8 +1,8 @@
 
1
  from typing import Optional, List
2
 
3
- from PIL import Image
4
- import streamlit as st
5
  import booste
 
6
 
7
  from session_state import SessionState, get_state
8
 
@@ -10,6 +10,28 @@ from session_state import SessionState, get_state
10
  # Do not copy this API key, go to https://www.booste.io/ and get your own, it is free!
11
  BOOSTE_API_KEY = "3818ba84-3526-4029-9dc8-ef3038697ea2"
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  class Sections:
15
  @staticmethod
@@ -35,17 +57,30 @@ class Sections:
35
  default_image_1 = "https://cdn.pixabay.com/photo/2014/10/13/21/34/clipper-487503_960_720.jpg"
36
  st.image(default_image_1, use_column_width=True)
37
  if st.button("Select image 1"):
38
- state.image = default_image_1
39
  with col2:
40
  default_image_2 = "https://cdn.pixabay.com/photo/2019/12/17/18/20/peacock-4702197_960_720.jpg"
41
  st.image(default_image_2, use_column_width=True)
42
  if st.button("Select image 2"):
43
- state.image = default_image_2
44
  with col3:
45
  default_image_3 = "https://cdn.pixabay.com/photo/2016/11/15/16/24/banana-1826760_960_720.jpg"
46
  st.image(default_image_3, use_column_width=True)
47
  if st.button("Select image 3"):
48
- state.image = default_image_3
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
  @staticmethod
51
  def prompts_input(state: SessionState, input_label: str, prompt_prefix: str = ''):
@@ -55,23 +90,44 @@ class Sections:
55
  state.prompt_prefix = prompt_prefix
56
 
57
  @staticmethod
58
- def input_preview(state: SessionState):
59
  col1, col2 = st.beta_columns([2, 1])
60
  with col1:
61
  st.markdown("Image to classify")
62
- if state.image is not None:
63
- st.image(state.image, use_column_width=True)
64
  else:
65
  st.warning("Select an image")
66
 
67
  with col2:
68
  st.markdown("Labels to choose from")
69
- if state.processed_classes is not None:
70
  for prompt in state.prompts:
71
  st.write(prompt[len(state.prompt_prefix):])
72
  else:
73
  st.warning("Enter the classes to classify from")
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  @staticmethod
76
  def classification_output(state: SessionState):
77
  # Possible way of customize this https://discuss.streamlit.io/t/st-button-in-a-custom-layout/2187/2
@@ -79,19 +135,32 @@ class Sections:
79
  with st.spinner("Predicting..."):
80
  clip_response = booste.clip(BOOSTE_API_KEY,
81
  prompts=state.prompts,
82
- images=[state.image],
83
  pretty_print=True)
84
  st.markdown("### Results")
85
- simplified_clip_results = [(prompt[len(state.prompt_prefix):],
86
- list(results.values())[0]["probabilityRelativeToPrompts"])
87
- for prompt, results in clip_response.items()]
88
- simplified_clip_results = sorted(simplified_clip_results, key=lambda x: x[1], reverse=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
- for prompt, probability in simplified_clip_results:
91
- percentage_prob = int(probability * 100)
92
- st.markdown(
93
- f"### ![prob](https://progress-bar.dev/{percentage_prob}/?width=200) &nbsp &nbsp {prompt}")
94
- st.write(clip_response)
95
 
96
 
97
  task_name: str = st.sidebar.radio("Task", options=["Image classification", "Image ranking", "Prompt ranking"])
@@ -103,7 +172,7 @@ if task_name == "Image classification":
103
  Sections.image_picker(session_state)
104
  input_label = "Enter the classes to chose from separated by a semi-colon. (f.x. `banana; boat; honesty; apple`)"
105
  Sections.prompts_input(session_state, input_label, prompt_prefix='A picture of a ')
106
- Sections.input_preview(session_state)
107
  Sections.classification_output(session_state)
108
  elif task_name == "Prompt ranking":
109
  Sections.header()
@@ -113,17 +182,15 @@ elif task_name == "Prompt ranking":
113
  input_label = "Enter the prompts to choose from separated by a semi-colon. " \
114
  "(f.x. `An image that inspires; A feeling of loneliness; joyful and young; apple`)"
115
  Sections.prompts_input(session_state, input_label)
116
- Sections.input_preview(session_state)
117
  Sections.classification_output(session_state)
118
  elif task_name == "Image ranking":
119
  Sections.header()
120
  Sections.image_uploader(accept_multiple_files=True)
121
  st.markdown("or use random dataset")
122
- Sections.image_picker(session_state)
123
-
124
-
 
125
 
126
  session_state.sync()
127
-
128
-
129
-
 
1
+ import random
2
  from typing import Optional, List
3
 
 
 
4
  import booste
5
+ import streamlit as st
6
 
7
  from session_state import SessionState, get_state
8
 
 
10
  # Do not copy this API key, go to https://www.booste.io/ and get your own, it is free!
11
  BOOSTE_API_KEY = "3818ba84-3526-4029-9dc8-ef3038697ea2"
12
 
13
+ IMAGES_LINKS = ["https://cdn.pixabay.com/photo/2014/10/13/21/34/clipper-487503_960_720.jpg",
14
+ "https://cdn.pixabay.com/photo/2019/09/06/04/25/beach-4455433_960_720.jpg",
15
+ "https://cdn.pixabay.com/photo/2019/10/19/12/21/hot-air-balloons-4561264_960_720.jpg",
16
+ "https://cdn.pixabay.com/photo/2019/12/17/18/20/peacock-4702197_960_720.jpg",
17
+ "https://cdn.pixabay.com/photo/2016/11/15/16/24/banana-1826760_960_720.jpg",
18
+ "https://cdn.pixabay.com/photo/2020/12/28/22/48/buddha-5868759_960_720.jpg",
19
+ "https://cdn.pixabay.com/photo/2019/11/11/14/30/zebra-4618513_960_720.jpg",
20
+ "https://cdn.pixabay.com/photo/2020/11/04/15/29/coffee-beans-5712780_960_720.jpg",
21
+ "https://cdn.pixabay.com/photo/2020/03/24/20/42/namibia-4965457_960_720.jpg",
22
+ "https://cdn.pixabay.com/photo/2020/08/27/07/31/restaurant-5521372_960_720.jpg",
23
+ "https://cdn.pixabay.com/photo/2020/08/28/06/13/building-5523630_960_720.jpg",
24
+ "https://cdn.pixabay.com/photo/2020/08/24/21/41/couple-5515141_960_720.jpg",
25
+ "https://cdn.pixabay.com/photo/2020/01/31/07/10/billboards-4807268_960_720.jpg",
26
+ "https://cdn.pixabay.com/photo/2017/07/31/20/48/shell-2560930_960_720.jpg",
27
+ "https://cdn.pixabay.com/photo/2020/08/13/01/29/koala-5483931_960_720.jpg",
28
+ "https://cdn.pixabay.com/photo/2016/11/29/04/52/architecture-1867411_960_720.jpg",
29
+ ]
30
+
31
+ @st.cache
32
+ def select_random_dataset():
33
+ return random.sample(IMAGES_LINKS, 10)
34
+
35
 
36
  class Sections:
37
  @staticmethod
 
57
  default_image_1 = "https://cdn.pixabay.com/photo/2014/10/13/21/34/clipper-487503_960_720.jpg"
58
  st.image(default_image_1, use_column_width=True)
59
  if st.button("Select image 1"):
60
+ state.images = [default_image_1]
61
  with col2:
62
  default_image_2 = "https://cdn.pixabay.com/photo/2019/12/17/18/20/peacock-4702197_960_720.jpg"
63
  st.image(default_image_2, use_column_width=True)
64
  if st.button("Select image 2"):
65
+ state.images = [default_image_2]
66
  with col3:
67
  default_image_3 = "https://cdn.pixabay.com/photo/2016/11/15/16/24/banana-1826760_960_720.jpg"
68
  st.image(default_image_3, use_column_width=True)
69
  if st.button("Select image 3"):
70
+ state.images = [default_image_3]
71
+
72
+ @staticmethod
73
+ def dataset_picker(state: SessionState):
74
+ columns = st.beta_columns(5)
75
+ state.dataset = select_random_dataset()
76
+ image_idx = 0
77
+ for col in columns:
78
+ col.image(state.dataset[image_idx])
79
+ image_idx += 1
80
+ col.image(state.dataset[image_idx])
81
+ image_idx += 1
82
+ if st.button("Select random dataset"):
83
+ state.images = state.dataset
84
 
85
  @staticmethod
86
  def prompts_input(state: SessionState, input_label: str, prompt_prefix: str = ''):
 
90
  state.prompt_prefix = prompt_prefix
91
 
92
  @staticmethod
93
+ def single_image_input_preview(state: SessionState):
94
  col1, col2 = st.beta_columns([2, 1])
95
  with col1:
96
  st.markdown("Image to classify")
97
+ if state.images is not None:
98
+ st.image(state.images[0], use_column_width=True)
99
  else:
100
  st.warning("Select an image")
101
 
102
  with col2:
103
  st.markdown("Labels to choose from")
104
+ if state.prompts is not None:
105
  for prompt in state.prompts:
106
  st.write(prompt[len(state.prompt_prefix):])
107
  else:
108
  st.warning("Enter the classes to classify from")
109
 
110
+ @staticmethod
111
+ def multiple_images_input_preview(state: SessionState):
112
+ st.markdown("Images to classify")
113
+ col1, col2, col3 = st.beta_columns(3)
114
+ if state.images is not None:
115
+ for idx, image in enumerate(state.images):
116
+ if idx < len(state.images) / 2:
117
+ col1.image(state.images[idx], use_column_width=True)
118
+ else:
119
+ col2.image(state.images[idx], use_column_width=True)
120
+ else:
121
+ col1.warning("Select an image")
122
+
123
+ with col3:
124
+ st.markdown("Query prompt")
125
+ if state.prompts is not None:
126
+ for prompt in state.prompts:
127
+ st.write(prompt[len(state.prompt_prefix):])
128
+ else:
129
+ st.warning("Enter the prompt to classify")
130
+
131
  @staticmethod
132
  def classification_output(state: SessionState):
133
  # Possible way of customize this https://discuss.streamlit.io/t/st-button-in-a-custom-layout/2187/2
 
135
  with st.spinner("Predicting..."):
136
  clip_response = booste.clip(BOOSTE_API_KEY,
137
  prompts=state.prompts,
138
+ images=state.images,
139
  pretty_print=True)
140
  st.markdown("### Results")
141
+ # st.write(clip_response)
142
+ if len(state.images) == 1:
143
+ simplified_clip_results = [(prompt[len(state.prompt_prefix):],
144
+ list(results.values())[0]["probabilityRelativeToPrompts"])
145
+ for prompt, results in clip_response.items()]
146
+ simplified_clip_results = sorted(simplified_clip_results, key=lambda x: x[1], reverse=True)
147
+
148
+ for prompt, probability in simplified_clip_results:
149
+ percentage_prob = int(probability * 100)
150
+ st.markdown(
151
+ f"### ![prob](https://progress-bar.dev/{percentage_prob}/?width=200) &nbsp &nbsp {prompt}")
152
+ else:
153
+ st.markdown(f"### {state.prompts[0]}")
154
+ assert len(state.prompts) == 1
155
+ simplified_clip_results = [(image, results["probabilityRelativeToImages"]) for image, results
156
+ in list(clip_response.values())[0].items()]
157
+ simplified_clip_results = sorted(simplified_clip_results, key=lambda x: x[1], reverse=True)
158
+ for image, probability in simplified_clip_results[:5]:
159
+ col1, col2 = st.beta_columns([1, 3])
160
+ col1.image(image, use_column_width=True)
161
+ percentage_prob = int(probability * 100)
162
+ col2.markdown(f"### ![prob](https://progress-bar.dev/{percentage_prob}/?width=200)")
163
 
 
 
 
 
 
164
 
165
 
166
  task_name: str = st.sidebar.radio("Task", options=["Image classification", "Image ranking", "Prompt ranking"])
 
172
  Sections.image_picker(session_state)
173
  input_label = "Enter the classes to chose from separated by a semi-colon. (f.x. `banana; boat; honesty; apple`)"
174
  Sections.prompts_input(session_state, input_label, prompt_prefix='A picture of a ')
175
+ Sections.single_image_input_preview(session_state)
176
  Sections.classification_output(session_state)
177
  elif task_name == "Prompt ranking":
178
  Sections.header()
 
182
  input_label = "Enter the prompts to choose from separated by a semi-colon. " \
183
  "(f.x. `An image that inspires; A feeling of loneliness; joyful and young; apple`)"
184
  Sections.prompts_input(session_state, input_label)
185
+ Sections.single_image_input_preview(session_state)
186
  Sections.classification_output(session_state)
187
  elif task_name == "Image ranking":
188
  Sections.header()
189
  Sections.image_uploader(accept_multiple_files=True)
190
  st.markdown("or use random dataset")
191
+ Sections.dataset_picker(session_state)
192
+ Sections.prompts_input(session_state, "Enter the prompt to query the images by")
193
+ Sections.multiple_images_input_preview(session_state)
194
+ Sections.classification_output(session_state)
195
 
196
  session_state.sync()