JavierFnts committed on
Commit f16e869 · unverified · 2 Parent(s): fe3ce1c 40497e6

Merge pull request #3 from JaviFuentes94/feature--local-model

Files changed (3):
  1. streamlit_app.py → app.py +152 -167
  2. clip_model.py +71 -0
  3. requirements.txt +4 -3
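In short, this merge replaces remote inference through the booste API with a local CLIP model wrapped by the new clip_model.py. A minimal sketch of the before/after call shape, using only names that appear in the diff below (the image and prompts variables are illustrative placeholders):

    # Before: remote inference through the booste API (removed in this PR)
    # clip_response = booste.clip(BOOSTE_API_KEY, prompts=prompts, images=image_urls)

    # After: local inference through the new ClipModel wrapper
    from clip_model import ClipModel

    model = ClipModel()  # loads the 'RN50' checkpoint by default
    probs = model.compute_prompts_probabilities(image, prompts)  # one image vs. many prompts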
streamlit_app.py → app.py RENAMED
--- streamlit_app.py (before)
@@ -1,59 +1,62 @@
  import random

  import streamlit as st
-
- from session_state import SessionState, get_state
- from images_mocker import ImagesMocker
-
- images_mocker = ImagesMocker()
- import booste

  from PIL import Image

- # Unfortunately Streamlit sharing does not allow to hide enviroment variables yet.
- # Do not copy this API key, go to https://www.booste.io/ and get your own, it is free!
- BOOSTE_API_KEY = "3818ba84-3526-4029-9dc8-ef3038697ea2"
-
  IMAGES_LINKS = ["https://cdn.pixabay.com/photo/2014/10/13/21/34/clipper-487503_960_720.jpg",
                  "https://cdn.pixabay.com/photo/2019/09/06/04/25/beach-4455433_960_720.jpg",
-                 # "https://cdn.pixabay.com/photo/2019/10/19/12/21/hot-air-balloons-4561264_960_720.jpg",
-                 # "https://cdn.pixabay.com/photo/2019/12/17/18/20/peacock-4702197_960_720.jpg",
-                 # "https://cdn.pixabay.com/photo/2016/11/15/16/24/banana-1826760_960_720.jpg",
-                 # "https://cdn.pixabay.com/photo/2020/12/28/22/48/buddha-5868759_960_720.jpg",
                  "https://cdn.pixabay.com/photo/2019/11/11/14/30/zebra-4618513_960_720.jpg",
                  "https://cdn.pixabay.com/photo/2020/11/04/15/29/coffee-beans-5712780_960_720.jpg",
                  "https://cdn.pixabay.com/photo/2020/03/24/20/42/namibia-4965457_960_720.jpg",
                  "https://cdn.pixabay.com/photo/2020/08/27/07/31/restaurant-5521372_960_720.jpg",
-                 # "https://cdn.pixabay.com/photo/2020/08/28/06/13/building-5523630_960_720.jpg",
                  "https://cdn.pixabay.com/photo/2020/08/24/21/41/couple-5515141_960_720.jpg",
                  "https://cdn.pixabay.com/photo/2020/01/31/07/10/billboards-4807268_960_720.jpg",
                  "https://cdn.pixabay.com/photo/2017/07/31/20/48/shell-2560930_960_720.jpg",
                  "https://cdn.pixabay.com/photo/2020/08/13/01/29/koala-5483931_960_720.jpg",
-                 # "https://cdn.pixabay.com/photo/2016/11/29/04/52/architecture-1867411_960_720.jpg",
                  ]

  @st.cache  # Cache this so that it doesn't change every time something changes in the page
- def select_random_dataset():
-     return random.sample(IMAGES_LINKS, 10)


- def limit_number_images(state: SessionState):
      """When moving between tasks sometimes the state of images can have too many samples"""
-     if state.images is not None and len(state.images) > 1:
-         state.images = [state.images[0]]


- def limit_number_prompts(state: SessionState):
      """When moving between tasks sometimes the state of prompts can have too many samples"""
-     if state.prompts is not None and len(state.prompts) > 1:
-         state.prompts = [state.prompts[0]]


- def is_valid_prediction_state(state: SessionState) -> bool:
-     if state.images is None or len(state.images) < 1:
          st.error("Choose at least one image before predicting")
          return False
-     if state.prompts is None or len(state.prompts) < 1:
          st.error("Write at least one prompt before predicting")
          return False
      return True
@@ -101,16 +104,16 @@ class Sections:
          st.markdown("### Try OpenAI's CLIP model in your browser")
          st.markdown(" ")
          st.markdown(" ")
-         with st.beta_expander("What is CLIP?"):
              st.markdown("CLIP is a machine learning model that computes similarity between text "
                          "(also called prompts) and images. It has been trained on a dataset with millions of diverse"
                          " image-prompt pairs, which allows it to generalize to unseen examples."
                          " <br /> Check out [OpenAI's blogpost](https://openai.com/blog/clip/) for more details",
                          unsafe_allow_html=True)
-             col1, col2 = st.beta_columns(2)
              col1.image("https://openaiassets.blob.core.windows.net/$web/clip/draft/20210104b/overview-a.svg")
              col2.image("https://openaiassets.blob.core.windows.net/$web/clip/draft/20210104b/overview-b.svg")
-         with st.beta_expander("What can CLIP do?"):
              st.markdown("#### Prompt ranking")
              st.markdown("Given different prompts and an image CLIP will rank the different prompts based on how well they describe the image")
              st.markdown("#### Image ranking")
@@ -122,7 +125,7 @@ class Sections:
          st.markdown(" ")

      @staticmethod
-     def image_uploader(state: SessionState, accept_multiple_files: bool):
          uploaded_images = st.file_uploader("Upload image", type=[".png", ".jpg", ".jpeg"],
                                             accept_multiple_files=accept_multiple_files)
          if (not accept_multiple_files and uploaded_images is not None) or (accept_multiple_files and len(uploaded_images) >= 1):
@@ -133,199 +136,181 @@ class Sections:
              pil_image = Image.open(uploaded_image)
              pil_image = preprocess_image(pil_image)
              images.append(pil_image)
-             state.images = images


      @staticmethod
-     def image_picker(state: SessionState, default_text_input: str):
-         col1, col2, col3 = st.beta_columns(3)
          with col1:
-             default_image_1 = "https://cdn.pixabay.com/photo/2014/10/13/21/34/clipper-487503_960_720.jpg"
              st.image(default_image_1, use_column_width=True)
              if st.button("Select image 1"):
-                 state.images = [default_image_1]
-                 state.default_text_input = default_text_input
          with col2:
-             default_image_2 = "https://cdn.pixabay.com/photo/2019/11/11/14/30/zebra-4618513_960_720.jpg"
              st.image(default_image_2, use_column_width=True)
              if st.button("Select image 2"):
-                 state.images = [default_image_2]
-                 state.default_text_input = default_text_input
          with col3:
-             default_image_3 = "https://cdn.pixabay.com/photo/2016/11/15/16/24/banana-1826760_960_720.jpg"
              st.image(default_image_3, use_column_width=True)
              if st.button("Select image 3"):
-                 state.images = [default_image_3]
-                 state.default_text_input = default_text_input

      @staticmethod
-     def dataset_picker(state: SessionState):
-         columns = st.beta_columns(5)
-         state.dataset = select_random_dataset()
          image_idx = 0
          for col in columns:
-             col.image(state.dataset[image_idx])
              image_idx += 1
-             col.image(state.dataset[image_idx])
              image_idx += 1
          if st.button("Select random dataset"):
-             state.images = state.dataset
-             state.default_text_input = "A sign that says 'SLOW DOWN'"

      @staticmethod
-     def prompts_input(state: SessionState, input_label: str, prompt_prefix: str = ''):
          raw_text_input = st.text_input(input_label,
-                                        value=state.default_text_input if state.default_text_input is not None else "")
-         state.is_default_text_input = raw_text_input == state.default_text_input
          if raw_text_input:
-             state.prompts = [prompt_prefix + class_name for class_name in raw_text_input.split(";") if len(class_name) > 1]

      @staticmethod
-     def single_image_input_preview(state: SessionState):
          st.markdown("### Preview")
-         col1, col2 = st.beta_columns([1, 2])
          with col1:
              st.markdown("Image to classify")
-             if state.images is not None:
-                 st.image(state.images[0], use_column_width=True)
              else:
                  st.warning("Select an image")

          with col2:
              st.markdown("Labels to choose from")
-             if state.prompts is not None:
-                 for prompt in state.prompts:
                      st.markdown(f"* {prompt}")
-                 if len(state.prompts) < 2:
                      st.warning("At least two prompts/classes are needed")
              else:
                  st.warning("Enter the prompts/classes to classify from")

      @staticmethod
-     def multiple_images_input_preview(state: SessionState):
          st.markdown("### Preview")
          st.markdown("Images to classify")
-         col1, col2, col3 = st.beta_columns(3)
-         if state.images is not None:
-             for idx, image in enumerate(state.images):
-                 if idx < len(state.images) / 2:
-                     col1.image(state.images[idx], use_column_width=True)
                  else:
-                     col2.image(state.images[idx], use_column_width=True)
-             if len(state.images) < 2:
                  col2.warning("At least 2 images required")
          else:
              col1.warning("Select an image")

          with col3:
              st.markdown("Query prompt")
-             if state.prompts is not None:
-                 for prompt in state.prompts:
                      st.write(prompt)
              else:
                  st.warning("Enter the prompt to classify")

      @staticmethod
-     def classification_output(state: SessionState):
-         # Possible way of customize this https://discuss.streamlit.io/t/st-button-in-a-custom-layout/2187/2
-         if st.button("Predict") and is_valid_prediction_state(state):  # PREDICT 🚀
              with st.spinner("Predicting..."):
-                 if isinstance(state.images[0], str):
-                     clip_response = booste.clip(BOOSTE_API_KEY,
-                                                 prompts=state.prompts,
-                                                 images=state.images)
-                 else:
-                     images_mocker.calculate_image_id2image_lookup(state.images)
-                     images_mocker.start_mocking()
-                     clip_response = booste.clip(BOOSTE_API_KEY,
-                                                 prompts=state.prompts,
-                                                 images=images_mocker.image_ids)
-                     images_mocker.stop_mocking()
                  st.markdown("### Results")
-                 # st.write(clip_response)
-                 if len(state.images) == 1:
-                     simplified_clip_results = [(prompt,
-                                                 list(results.values())[0]["probabilityRelativeToPrompts"])
-                                                for prompt, results in clip_response.items()]
-                     simplified_clip_results = sorted(simplified_clip_results, key=lambda x: x[1], reverse=True)
-
-                     for prompt, probability in simplified_clip_results:
                          percentage_prob = int(probability * 100)
                          st.markdown(
-                             f"### ![prob](https://progress-bar.dev/{percentage_prob}/?width=200) &nbsp &nbsp {prompt}")
-                 else:
-                     st.markdown(f"### {state.prompts[0]}")
-                     assert len(state.prompts) == 1
-                     if isinstance(state.images[0], str):
-                         simplified_clip_results = [(image, results["probabilityRelativeToImages"]) for image, results
-                                                    in list(clip_response.values())[0].items()]
-                     else:
-                         simplified_clip_results = [(images_mocker.image_id2image(image),
-                                                     results["probabilityRelativeToImages"]) for image, results
-                                                    in list(clip_response.values())[0].items()]
-                     simplified_clip_results = sorted(simplified_clip_results, key=lambda x: x[1], reverse=True)
-                     for image, probability in simplified_clip_results[:5]:
-                         col1, col2 = st.beta_columns([1, 3])
                          col1.image(image, use_column_width=True)
                          percentage_prob = int(probability * 100)
                          col2.markdown(f"### ![prob](https://progress-bar.dev/{percentage_prob}/?width=200)")
-             is_default_image = isinstance(state.images[0], str)
-             is_default_prediction = is_default_image and state.is_default_text_input
-             if is_default_prediction:
-                 st.markdown("<br>:information_source: Try writing your own prompts and using your own pictures!",
-                             unsafe_allow_html=True)
-             elif is_default_image:
-                 st.markdown("<br>:information_source: You can also use your own pictures!",
-                             unsafe_allow_html=True)
-             elif state.is_default_text_input:
-                 st.markdown("<br>:information_source: Try writing your own prompts!"
-                             " It can be whatever you can think of",
-                             unsafe_allow_html=True)


- Sections.header()
- col1, col2 = st.beta_columns([1, 2])
- col1.markdown(" "); col1.markdown(" ")
- col1.markdown("#### Task selection")
- task_name: str = col2.selectbox("", options=["Prompt ranking", "Image ranking", "Image classification"])
- st.markdown("<br>", unsafe_allow_html=True)
-
- images_mocker.stop_mocking()  # Sometimes it gets stuck mocking
-
- session_state = get_state()
- if task_name == "Image classification":
-     Sections.image_uploader(session_state, accept_multiple_files=False)
-     if session_state.images is None:
-         st.markdown("or choose one from")
-     Sections.image_picker(session_state, default_text_input="banana; boat; bird")
-     input_label = "Enter the classes to chose from separated by a semi-colon. (f.x. `banana; boat; honesty; apple`)"
-     Sections.prompts_input(session_state, input_label, prompt_prefix='A picture of a ')
-     limit_number_images(session_state)
-     Sections.single_image_input_preview(session_state)
-     Sections.classification_output(session_state)
- elif task_name == "Prompt ranking":
-     Sections.image_uploader(session_state, accept_multiple_files=False)
-     if session_state.images is None:
-         st.markdown("or choose one from")
-     Sections.image_picker(session_state, default_text_input="A calm afternoon in the Mediterranean; "
-                                                             "A beautiful creature;"
-                                                             " Something that grows in tropical regions")
-     input_label = "Enter the prompts to choose from separated by a semi-colon. " \
-                   "(f.x. `An image that inspires; A feeling of loneliness; joyful and young; apple`)"
-     Sections.prompts_input(session_state, input_label)
-     limit_number_images(session_state)
-     Sections.single_image_input_preview(session_state)
-     Sections.classification_output(session_state)
- elif task_name == "Image ranking":
-     Sections.image_uploader(session_state, accept_multiple_files=True)
-     if session_state.images is None or len(session_state.images) < 2:
-         st.markdown("or use this random dataset")
-         Sections.dataset_picker(session_state)
-     Sections.prompts_input(session_state, "Enter the prompt to query the images by")
-     limit_number_prompts(session_state)
-     Sections.multiple_images_input_preview(session_state)
-     Sections.classification_output(session_state)
-
- st.markdown("<br><br><br><br>Made by [@JavierFnts](https://twitter.com/JavierFnts) | [How was CLIP Playground built?](https://twitter.com/JavierFnts/status/1363522529072214019)"
-             "", unsafe_allow_html=True)
- session_state.sync()
 
+++ app.py (after)
@@ -1,59 +1,62 @@
  import random
+ import requests

  import streamlit as st
+ from clip_model import ClipModel

  from PIL import Image

  IMAGES_LINKS = ["https://cdn.pixabay.com/photo/2014/10/13/21/34/clipper-487503_960_720.jpg",
                  "https://cdn.pixabay.com/photo/2019/09/06/04/25/beach-4455433_960_720.jpg",
                  "https://cdn.pixabay.com/photo/2019/11/11/14/30/zebra-4618513_960_720.jpg",
                  "https://cdn.pixabay.com/photo/2020/11/04/15/29/coffee-beans-5712780_960_720.jpg",
                  "https://cdn.pixabay.com/photo/2020/03/24/20/42/namibia-4965457_960_720.jpg",
                  "https://cdn.pixabay.com/photo/2020/08/27/07/31/restaurant-5521372_960_720.jpg",
                  "https://cdn.pixabay.com/photo/2020/08/24/21/41/couple-5515141_960_720.jpg",
                  "https://cdn.pixabay.com/photo/2020/01/31/07/10/billboards-4807268_960_720.jpg",
                  "https://cdn.pixabay.com/photo/2017/07/31/20/48/shell-2560930_960_720.jpg",
                  "https://cdn.pixabay.com/photo/2020/08/13/01/29/koala-5483931_960_720.jpg",
                  ]

  @st.cache  # Cache this so that it doesn't change every time something changes in the page
+ def load_default_dataset():
+     return [load_image_from_url(url) for url in IMAGES_LINKS]
+
+ def load_image_from_url(url: str) -> Image.Image:
+     return Image.open(requests.get(url, stream=True).raw)
+
+ @st.cache
+ def load_model() -> ClipModel:
+     return ClipModel()

+ def init_state():
+     if "images" not in st.session_state:
+         st.session_state.images = None
+     if "prompts" not in st.session_state:
+         st.session_state.prompts = None
+     if "predictions" not in st.session_state:
+         st.session_state.predictions = None
+     if "default_text_input" not in st.session_state:
+         st.session_state.default_text_input = None

+
+ def limit_number_images():
      """When moving between tasks sometimes the state of images can have too many samples"""
+     if st.session_state.images is not None and len(st.session_state.images) > 1:
+         st.session_state.images = [st.session_state.images[0]]


+ def limit_number_prompts():
      """When moving between tasks sometimes the state of prompts can have too many samples"""
+     if st.session_state.prompts is not None and len(st.session_state.prompts) > 1:
+         st.session_state.prompts = [st.session_state.prompts[0]]


+ def is_valid_prediction_state() -> bool:
+     if st.session_state.images is None or len(st.session_state.images) < 1:
          st.error("Choose at least one image before predicting")
          return False
+     if st.session_state.prompts is None or len(st.session_state.prompts) < 1:
          st.error("Write at least one prompt before predicting")
          return False
      return True
 
@@ -101,16 +104,16 @@ class Sections:
          st.markdown("### Try OpenAI's CLIP model in your browser")
          st.markdown(" ")
          st.markdown(" ")
+         with st.expander("What is CLIP?"):
              st.markdown("CLIP is a machine learning model that computes similarity between text "
                          "(also called prompts) and images. It has been trained on a dataset with millions of diverse"
                          " image-prompt pairs, which allows it to generalize to unseen examples."
                          " <br /> Check out [OpenAI's blogpost](https://openai.com/blog/clip/) for more details",
                          unsafe_allow_html=True)
+             col1, col2 = st.columns(2)
              col1.image("https://openaiassets.blob.core.windows.net/$web/clip/draft/20210104b/overview-a.svg")
              col2.image("https://openaiassets.blob.core.windows.net/$web/clip/draft/20210104b/overview-b.svg")
+         with st.expander("What can CLIP do?"):
              st.markdown("#### Prompt ranking")
              st.markdown("Given different prompts and an image CLIP will rank the different prompts based on how well they describe the image")
              st.markdown("#### Image ranking")
 
@@ -122,7 +125,7 @@ class Sections:
          st.markdown(" ")

      @staticmethod
+     def image_uploader(accept_multiple_files: bool):
          uploaded_images = st.file_uploader("Upload image", type=[".png", ".jpg", ".jpeg"],
                                             accept_multiple_files=accept_multiple_files)
          if (not accept_multiple_files and uploaded_images is not None) or (accept_multiple_files and len(uploaded_images) >= 1):
 
@@ -133,199 +136,181 @@ class Sections:
              pil_image = Image.open(uploaded_image)
              pil_image = preprocess_image(pil_image)
              images.append(pil_image)
+             st.session_state.images = images


      @staticmethod
+     def image_picker(default_text_input: str):
+         col1, col2, col3 = st.columns(3)
          with col1:
+             default_image_1 = load_image_from_url("https://cdn.pixabay.com/photo/2014/10/13/21/34/clipper-487503_960_720.jpg")
              st.image(default_image_1, use_column_width=True)
              if st.button("Select image 1"):
+                 st.session_state.images = [default_image_1]
+                 st.session_state.default_text_input = default_text_input
          with col2:
+             default_image_2 = load_image_from_url("https://cdn.pixabay.com/photo/2019/11/11/14/30/zebra-4618513_960_720.jpg")
              st.image(default_image_2, use_column_width=True)
              if st.button("Select image 2"):
+                 st.session_state.images = [default_image_2]
+                 st.session_state.default_text_input = default_text_input
          with col3:
+             default_image_3 = load_image_from_url("https://cdn.pixabay.com/photo/2016/11/15/16/24/banana-1826760_960_720.jpg")
              st.image(default_image_3, use_column_width=True)
              if st.button("Select image 3"):
+                 st.session_state.images = [default_image_3]
+                 st.session_state.default_text_input = default_text_input

      @staticmethod
+     def dataset_picker():
+         columns = st.columns(5)
+         st.session_state.dataset = load_default_dataset()
          image_idx = 0
          for col in columns:
+             col.image(st.session_state.dataset[image_idx])
              image_idx += 1
+             col.image(st.session_state.dataset[image_idx])
              image_idx += 1
          if st.button("Select random dataset"):
+             st.session_state.images = st.session_state.dataset
+             st.session_state.default_text_input = "A sign that says 'SLOW DOWN'"

      @staticmethod
+     def prompts_input(input_label: str, prompt_prefix: str = ''):
          raw_text_input = st.text_input(input_label,
+                                        value=st.session_state.default_text_input if st.session_state.default_text_input is not None else "")
+         st.session_state.is_default_text_input = raw_text_input == st.session_state.default_text_input
          if raw_text_input:
+             st.session_state.prompts = [prompt_prefix + class_name for class_name in raw_text_input.split(";") if len(class_name) > 1]

      @staticmethod
+     def single_image_input_preview():
          st.markdown("### Preview")
+         col1, col2 = st.columns([1, 2])
          with col1:
              st.markdown("Image to classify")
+             if st.session_state.images is not None:
+                 st.image(st.session_state.images[0], use_column_width=True)
              else:
                  st.warning("Select an image")

          with col2:
              st.markdown("Labels to choose from")
+             if st.session_state.prompts is not None:
+                 for prompt in st.session_state.prompts:
                      st.markdown(f"* {prompt}")
+                 if len(st.session_state.prompts) < 2:
                      st.warning("At least two prompts/classes are needed")
              else:
                  st.warning("Enter the prompts/classes to classify from")

      @staticmethod
+     def multiple_images_input_preview():
          st.markdown("### Preview")
          st.markdown("Images to classify")
+         col1, col2, col3 = st.columns(3)
+         if st.session_state.images is not None:
+             for idx, image in enumerate(st.session_state.images):
+                 if idx < len(st.session_state.images) / 2:
+                     col1.image(st.session_state.images[idx], use_column_width=True)
                  else:
+                     col2.image(st.session_state.images[idx], use_column_width=True)
+             if len(st.session_state.images) < 2:
                  col2.warning("At least 2 images required")
          else:
              col1.warning("Select an image")

          with col3:
              st.markdown("Query prompt")
+             if st.session_state.prompts is not None:
+                 for prompt in st.session_state.prompts:
                      st.write(prompt)
              else:
                  st.warning("Enter the prompt to classify")

      @staticmethod
+     def classification_output(model: ClipModel):
+         if st.button("Predict") and is_valid_prediction_state():
              with st.spinner("Predicting..."):
+
                  st.markdown("### Results")
+                 if len(st.session_state.images) == 1:
+                     scores = model.compute_prompts_probabilities(st.session_state.images[0], st.session_state.prompts)
+                     scored_prompts = [(prompt, score) for prompt, score in zip(st.session_state.prompts, scores)]
+                     sorted_scored_prompts = sorted(scored_prompts, key=lambda x: x[1], reverse=True)
+                     for prompt, probability in sorted_scored_prompts:
                          percentage_prob = int(probability * 100)
                          st.markdown(
+                             f"### ![prob](https://progress-bar.dev/{percentage_prob}/?width=200) {prompt}")
+                 elif len(st.session_state.prompts) == 1:
+                     st.markdown(f"### {st.session_state.prompts[0]}")
+
+                     scores = model.compute_images_probabilities(st.session_state.images, st.session_state.prompts[0])
+                     scored_images = [(image, score) for image, score in zip(st.session_state.images, scores)]
+                     sorted_scored_images = sorted(scored_images, key=lambda x: x[1], reverse=True)
+
+                     for image, probability in sorted_scored_images[:5]:
+                         col1, col2 = st.columns([1, 3])
                          col1.image(image, use_column_width=True)
                          percentage_prob = int(probability * 100)
                          col2.markdown(f"### ![prob](https://progress-bar.dev/{percentage_prob}/?width=200)")
+                 else:
+                     raise ValueError("Invalid state")
+
+         # is_default_image = isinstance(state.images[0], str)
+         # is_default_prediction = is_default_image and state.is_default_text_input
+         # if is_default_prediction:
+         #     st.markdown("<br>:information_source: Try writing your own prompts and using your own pictures!",
+         #                 unsafe_allow_html=True)
+         # elif is_default_image:
+         #     st.markdown("<br>:information_source: You can also use your own pictures!",
+         #                 unsafe_allow_html=True)
+         # elif state.is_default_text_input:
+         #     st.markdown("<br>:information_source: Try writing your own prompts!"
+         #                 " It can be whatever you can think of",
+         #                 unsafe_allow_html=True)

+
+ if __name__ == "__main__":
+     Sections.header()
+     col1, col2 = st.columns([1, 2])
+     col1.markdown(" "); col1.markdown(" ")
+     col1.markdown("#### Task selection")
+     task_name: str = col2.selectbox("", options=["Prompt ranking", "Image ranking", "Image classification"])
+     st.markdown("<br>", unsafe_allow_html=True)
+     init_state()
+     model = load_model()
+     if task_name == "Image classification":
+         Sections.image_uploader(accept_multiple_files=False)
+         if st.session_state.images is None:
+             st.markdown("or choose one from")
+         Sections.image_picker(default_text_input="banana; boat; bird")
+         input_label = "Enter the classes to choose from separated by a semi-colon. (f.x. `banana; boat; honesty; apple`)"
+         Sections.prompts_input(input_label, prompt_prefix='A picture of a ')
+         limit_number_images()
+         Sections.single_image_input_preview()
+         Sections.classification_output(model)
+     elif task_name == "Prompt ranking":
+         Sections.image_uploader(accept_multiple_files=False)
+         if st.session_state.images is None:
+             st.markdown("or choose one from")
+         Sections.image_picker(default_text_input="A calm afternoon in the Mediterranean; "
+                                                  "A beautiful creature;"
+                                                  " Something that grows in tropical regions")
+         input_label = "Enter the prompts to choose from separated by a semi-colon. " \
+                       "(f.x. `An image that inspires; A feeling of loneliness; joyful and young; apple`)"
+         Sections.prompts_input(input_label)
+         limit_number_images()
+         Sections.single_image_input_preview()
+         Sections.classification_output(model)
+     elif task_name == "Image ranking":
+         Sections.image_uploader(accept_multiple_files=True)
+         if st.session_state.images is None or len(st.session_state.images) < 2:
+             st.markdown("or use this random dataset")
+             Sections.dataset_picker()
+         Sections.prompts_input("Enter the prompt to query the images by")
+         limit_number_prompts()
+         Sections.multiple_images_input_preview()
+         Sections.classification_output(model)
+
+     st.markdown("<br><br><br><br>Made by [@JavierFnts](https://twitter.com/JavierFnts) | [How was CLIP Playground built?](https://twitter.com/JavierFnts/status/1363522529072214019)"
+                 "", unsafe_allow_html=True)
clip_model.py ADDED
@@ -0,0 +1,71 @@
+ import clip
+ from PIL.Image import Image
+ import torch
+
+
+ class ClipModel:
+     def __init__(self, model_name: str = 'RN50') -> None:
+         """
+         Available models:
+         ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32',
+          'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px']
+         """
+         self._model, self._img_preprocess = clip.load(model_name)
+
+     def predict(self, images: list[Image], prompts: list[str]) -> list[float]:
+         if len(images) == 1:
+             return self.compute_prompts_probabilities(images[0], prompts)
+         elif len(prompts) == 1:
+             return self.compute_images_probabilities(images, prompts[0])
+         else:
+             raise ValueError('Either images or prompts must be a single element')
+
+     def compute_prompts_probabilities(self, image: Image, prompts: list[str]) -> list[float]:
+         preprocessed_image = self._img_preprocess(image).unsqueeze(0)
+         tokenized_prompts = clip.tokenize(prompts)
+         with torch.inference_mode():
+             image_features = self._model.encode_image(preprocessed_image)
+             text_features = self._model.encode_text(tokenized_prompts)
+
+             # normalized features
+             image_features = image_features / image_features.norm(dim=1, keepdim=True)
+             text_features = text_features / text_features.norm(dim=1, keepdim=True)
+
+             # cosine similarity as logits
+             logit_scale = self._model.logit_scale.exp()
+             logits_per_image = logit_scale * image_features @ text_features.t()
+
+             probs = list(logits_per_image.softmax(dim=-1).cpu().numpy()[0])
+
+         return probs
+
+     def compute_images_probabilities(self, images: list[Image], prompt: str) -> list[float]:
+         preprocessed_images = [self._img_preprocess(image).unsqueeze(0) for image in images]
+         tokenized_prompts = clip.tokenize(prompt)
+         with torch.inference_mode():
+             image_features = torch.cat([self._model.encode_image(preprocessed_image) for preprocessed_image in preprocessed_images])
+             text_features = self._model.encode_text(tokenized_prompts)
+
+             # normalized features
+             image_features = image_features / image_features.norm(dim=1, keepdim=True)
+             text_features = text_features / text_features.norm(dim=1, keepdim=True)
+
+             # cosine similarity as logits
+             logit_scale = self._model.logit_scale.exp()
+             logits_per_image = logit_scale * text_features @ image_features.t()
+
+             probs = list(logits_per_image.softmax(dim=-1).cpu().numpy()[0])
+
+         return probs
+
+
+ if __name__ == "__main__":
+     from app import load_default_dataset
+
+     model = ClipModel()
+     images = load_default_dataset()
+     prompts = ['Hello', 'How are you', 'Goodbye']
+     prompts_scores = model.compute_prompts_probabilities(images[0], prompts)
+     images_scores = model.compute_images_probabilities(images, prompts[0])
+     print(f"Prompts scores: {prompts_scores}")
+     print(f"Images scores: {images_scores}")
requirements.txt CHANGED
--- requirements.txt (before)
@@ -1,4 +1,5 @@
- streamlit~=0.76.0
- booste==0.2.8
  Pillow==8.1.0
- mock==4.0.3
+++ requirements.txt (after)
+ streamlit~=1.11.1
+ git+https://github.com/openai/CLIP@b46f5ac
  Pillow==8.1.0
+ mock==4.0.3
+ protobuf==3.20.0  # It raises errors otherwise