Javi commited on
Commit
714cf07
·
1 Parent(s): af5047d

Cleaning and bug fixing

Browse files
Files changed (2) hide show
  1. images_mocker.py +31 -0
  2. streamlit_app.py +24 -36
images_mocker.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ import uuid
3
+ from mock import patch
4
+
5
+
6
+ class ImagesMocker:
7
+ """HACK ALERT: I needed a way to call the booste API without storing the images first
8
+ (as that is not allowed in streamlit sharing). If you have a better idea on hwo to this let me know!"""
9
+
10
+ def __init__(self):
11
+ self.pil_patch = patch('PIL.Image.open', lambda x: self.image_id2image(x))
12
+ self.path_patch = patch('os.path.exists', lambda x: True)
13
+ self.image_id2image_lookup = {}
14
+
15
+ def start_mocking(self):
16
+ self.pil_patch.start()
17
+ self.path_patch.start()
18
+
19
+ def stop_mocking(self):
20
+ self.pil_patch.stop()
21
+ self.path_patch.stop()
22
+
23
+ def image_id2image(self, image_id: str):
24
+ return self.image_id2image_lookup[image_id]
25
+
26
+ def calculate_image_id2image_lookup(self, images: List):
27
+ self.image_id2image_lookup = {str(uuid.uuid4()) + ".png": image for image in images}
28
+
29
+ @property
30
+ def image_ids(self):
31
+ return list(self.image_id2image_lookup.keys())
streamlit_app.py CHANGED
@@ -1,41 +1,14 @@
1
  import random
2
- from typing import Optional, List
3
- import uuid
4
 
5
  import streamlit as st
6
- from mock import patch
7
 
8
- class ImagesMocker:
9
- """HACK ALERT: I needed a way to call the booste API without storing the images first
10
- (as that is not allowed in streamlit sharing). If you have a better idea on hwo to this let me know!"""
11
-
12
- def __init__(self):
13
- self.pil_patch = patch('PIL.Image.open', lambda x: self.image_id2image(x))
14
- self.path_patch = patch('os.path.exists', lambda x: True)
15
- self.image_id2image_lookup = {}
16
-
17
- def start_mocking(self):
18
- self.pil_patch.start()
19
- self.path_patch.start()
20
-
21
- def stop_mocking(self):
22
- self.pil_patch.stop()
23
- self.path_patch.stop()
24
-
25
- def image_id2image(self, image_id: str):
26
- return self.image_id2image_lookup[image_id]
27
-
28
- def calculate_image_id2image_lookup(self, images: List):
29
- self.image_id2image_lookup = {str(uuid.uuid4()) + ".png": image for image in images}
30
- @property
31
- def image_ids(self):
32
- return list(self.image_id2image_lookup.keys())
33
 
34
  images_mocker = ImagesMocker()
35
  import booste
36
 
37
  from PIL import Image
38
- from session_state import SessionState, get_state
39
 
40
  # Unfortunately Streamlit sharing does not allow to hide enviroment variables yet.
41
  # Do not copy this API key, go to https://www.booste.io/ and get your own, it is free!
@@ -64,6 +37,18 @@ def select_random_dataset():
64
  return random.sample(IMAGES_LINKS, 10)
65
 
66
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  class Sections:
68
  @staticmethod
69
  def header():
@@ -80,7 +65,7 @@ class Sections:
80
  def image_uploader(state: SessionState, accept_multiple_files: bool):
81
  uploaded_images = st.file_uploader("Upload image", type=[".png", ".jpg", ".jpeg"],
82
  accept_multiple_files=accept_multiple_files)
83
- if uploaded_images is not None or (accept_multiple_files and len(uploaded_images) > 1):
84
  images = []
85
  if not accept_multiple_files:
86
  uploaded_images = [uploaded_images]
@@ -142,9 +127,11 @@ class Sections:
142
  st.markdown("Labels to choose from")
143
  if state.prompts is not None:
144
  for prompt in state.prompts:
145
- st.write(prompt[len(state.prompt_prefix):])
 
 
146
  else:
147
- st.warning("Enter the classes to classify from")
148
 
149
  @staticmethod
150
  def multiple_images_input_preview(state: SessionState):
@@ -219,8 +206,8 @@ class Sections:
219
 
220
 
221
  task_name: str = st.sidebar.radio("Task", options=["Image classification", "Image ranking", "Prompt ranking"])
 
222
  if task_name == "Image classification":
223
- session_state = get_state()
224
  Sections.header()
225
  Sections.image_uploader(session_state, accept_multiple_files=False)
226
  if session_state.images is None:
@@ -228,10 +215,10 @@ if task_name == "Image classification":
228
  Sections.image_picker(session_state)
229
  input_label = "Enter the classes to chose from separated by a semi-colon. (f.x. `banana; boat; honesty; apple`)"
230
  Sections.prompts_input(session_state, input_label, prompt_prefix='A picture of a ')
 
231
  Sections.single_image_input_preview(session_state)
232
  Sections.classification_output(session_state)
233
  elif task_name == "Prompt ranking":
234
- session_state = get_state()
235
  Sections.header()
236
  Sections.image_uploader(session_state, accept_multiple_files=False)
237
  if session_state.images is None:
@@ -240,16 +227,17 @@ elif task_name == "Prompt ranking":
240
  input_label = "Enter the prompts to choose from separated by a semi-colon. " \
241
  "(f.x. `An image that inspires; A feeling of loneliness; joyful and young; apple`)"
242
  Sections.prompts_input(session_state, input_label)
 
243
  Sections.single_image_input_preview(session_state)
244
  Sections.classification_output(session_state)
245
  elif task_name == "Image ranking":
246
- session_state = get_state()
247
  Sections.header()
248
  Sections.image_uploader(session_state, accept_multiple_files=True)
249
- if session_state.images is None:
250
  st.markdown("or use this random dataset")
251
  Sections.dataset_picker(session_state)
252
  Sections.prompts_input(session_state, "Enter the prompt to query the images by")
 
253
  Sections.multiple_images_input_preview(session_state)
254
  Sections.classification_output(session_state)
255
  print(session_state.images)
 
1
  import random
 
 
2
 
3
  import streamlit as st
 
4
 
5
+ from session_state import SessionState, get_state
6
+ from images_mocker import ImagesMocker
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  images_mocker = ImagesMocker()
9
  import booste
10
 
11
  from PIL import Image
 
12
 
13
  # Unfortunately Streamlit sharing does not allow to hide enviroment variables yet.
14
  # Do not copy this API key, go to https://www.booste.io/ and get your own, it is free!
 
37
  return random.sample(IMAGES_LINKS, 10)
38
 
39
 
40
+ def limit_number_images(state: SessionState):
41
+ """When moving between tasks sometimes the state of images can have too many samples"""
42
+ if state.images is not None and len(state.images) > 1:
43
+ state.images = [state.images[0]]
44
+
45
+
46
+ def limit_number_prompts(state: SessionState):
47
+ """When moving between tasks sometimes the state of prompts can have too many samples"""
48
+ if state.prompts is not None and len(state.prompts) > 1:
49
+ state.prompts = [state.prompts[0]]
50
+
51
+
52
  class Sections:
53
  @staticmethod
54
  def header():
 
65
  def image_uploader(state: SessionState, accept_multiple_files: bool):
66
  uploaded_images = st.file_uploader("Upload image", type=[".png", ".jpg", ".jpeg"],
67
  accept_multiple_files=accept_multiple_files)
68
+ if (not accept_multiple_files and uploaded_images is not None) or (accept_multiple_files and len(uploaded_images) >= 1):
69
  images = []
70
  if not accept_multiple_files:
71
  uploaded_images = [uploaded_images]
 
127
  st.markdown("Labels to choose from")
128
  if state.prompts is not None:
129
  for prompt in state.prompts:
130
+ st.markdown(f"* {prompt[len(state.prompt_prefix):]}")
131
+ if len(state.prompts) < 2:
132
+ st.warning("At least two prompts/classes are needed")
133
  else:
134
+ st.warning("Enter the prompts/classes to classify from")
135
 
136
  @staticmethod
137
  def multiple_images_input_preview(state: SessionState):
 
206
 
207
 
208
  task_name: str = st.sidebar.radio("Task", options=["Image classification", "Image ranking", "Prompt ranking"])
209
+ session_state = get_state()
210
  if task_name == "Image classification":
 
211
  Sections.header()
212
  Sections.image_uploader(session_state, accept_multiple_files=False)
213
  if session_state.images is None:
 
215
  Sections.image_picker(session_state)
216
  input_label = "Enter the classes to chose from separated by a semi-colon. (f.x. `banana; boat; honesty; apple`)"
217
  Sections.prompts_input(session_state, input_label, prompt_prefix='A picture of a ')
218
+ limit_number_images(session_state)
219
  Sections.single_image_input_preview(session_state)
220
  Sections.classification_output(session_state)
221
  elif task_name == "Prompt ranking":
 
222
  Sections.header()
223
  Sections.image_uploader(session_state, accept_multiple_files=False)
224
  if session_state.images is None:
 
227
  input_label = "Enter the prompts to choose from separated by a semi-colon. " \
228
  "(f.x. `An image that inspires; A feeling of loneliness; joyful and young; apple`)"
229
  Sections.prompts_input(session_state, input_label)
230
+ limit_number_images(session_state)
231
  Sections.single_image_input_preview(session_state)
232
  Sections.classification_output(session_state)
233
  elif task_name == "Image ranking":
 
234
  Sections.header()
235
  Sections.image_uploader(session_state, accept_multiple_files=True)
236
+ if session_state.images is None or len(session_state.images) < 2:
237
  st.markdown("or use this random dataset")
238
  Sections.dataset_picker(session_state)
239
  Sections.prompts_input(session_state, "Enter the prompt to query the images by")
240
+ limit_number_prompts(session_state)
241
  Sections.multiple_images_input_preview(session_state)
242
  Sections.classification_output(session_state)
243
  print(session_state.images)