vancauwe commited on
Commit
54319e9
·
1 Parent(s): c4d6745

feat: multi image input

Browse files
Dockerfile ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ From ubuntu:latest
2
+
3
+ RUN apt-get update
4
+ RUN apt-get install python3 python3-pip -y
5
+
6
+ # https://stackoverflow.com/questions/75608323/how-do-i-solve-error-externally-managed-environment-every-time-i-use-pip-3
7
+ # https://veronneau.org/python-311-pip-and-breaking-system-packages.html
8
+ ENV PIP_BREAK_SYSTEM_PACKAGES 1
9
+
10
+
11
+ ##################################################
12
+ # Ubuntu setup
13
+ ##################################################
14
+
15
+ RUN apt-get update \
16
+ && apt-get install -y wget \
17
+ && rm -rf /var/lib/apt/lists/*
18
+
19
+ RUN apt-get update && apt-get -y upgrade \
20
+ && apt-get install -y --no-install-recommends \
21
+ unzip \
22
+ nano \
23
+ git \
24
+ g++ \
25
+ gcc \
26
+ htop \
27
+ zip \
28
+ ca-certificates \
29
+ && rm -rf /var/lib/apt/lists/*
30
+
31
+ ##################################################
32
+ # ODTP setup
33
+ ##################################################
34
+
35
+ RUN mkdir /app
36
+ COPY . /saving-willy
37
+ RUN pip3 install --upgrade setuptools
38
+ RUN pip3 install -r /saving-willy/requirements.txt
39
+
40
+ WORKDIR /saving-willy
41
+
42
+ ENTRYPOINT bash
README.md CHANGED
@@ -28,7 +28,7 @@ pip install -r requirements.txt
28
  ```
29
 
30
  ```
31
- streamlit run app.py
32
  ```
33
 
34
 
 
28
  ```
29
 
30
  ```
31
+ streamlit run src/main.py
32
  ```
33
 
34
 
basic_map/app.py DELETED
@@ -1,21 +0,0 @@
1
- import pandas as pd
2
- import streamlit as st
3
- import folium
4
-
5
- from streamlit_folium import st_folium
6
- from streamlit_folium import folium_static
7
-
8
-
9
- visp_loc = 46.295833, 7.883333
10
- #m = folium.Map(location=visp_loc, zoom_start=9)
11
-
12
-
13
- st.markdown("# :whale: :whale: Cetaceans :red[& friends] :balloon:")
14
-
15
- m = folium.Map(location=visp_loc, zoom_start=9,
16
- tiles='https://tile.opentopomap.org/{z}/{x}/{y}.png',
17
- attr='<a href="https://opentopomap.org/">Open Topo Map</a>')
18
-
19
- folium_static(m)
20
-
21
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
basic_map/app1.py DELETED
@@ -1,42 +0,0 @@
1
- # lets try using map stuff without folium, maybe stlite doesnt support that.
2
-
3
- import streamlit as st
4
- import pandas as pd
5
-
6
- # Load data
7
- f = 'mountains_clr.csv'
8
- df = pd.read_csv(f).dropna()
9
-
10
- print(df)
11
-
12
- st.markdown("# :whale: :whale: Cetaceans :red[& friends] :balloon:")
13
-
14
- st.markdown("## :mountain: Mountains")
15
- st.markdown(f"library version: **{st.__version__}**")
16
- # not sure where my versions are getting pegged from, but we have a 1y spread :(
17
- # https://github.com/streamlit/streamlit/blob/1.24.1/lib/streamlit/elements/map.py
18
- # rather hard to find the docs for old versions, no selector unlike many libraries.
19
-
20
- visp_loc = 46.295833, 7.883333
21
- tile_xyz = 'https://tile.opentopomap.org/{z}/{x}/{y}.png'
22
- tile_attr = '<a href="https://opentopomap.org/">Open Topo Map</a>'
23
- st.map(df, latitude='lat', longitude='lon', color='color', size='size', zoom=7)
24
- #, tiles=tile_xyz, attr=tile_attr)
25
-
26
- #st.map(df)
27
-
28
- #st.map(df, latitude="col1", longitude="col2", size="col3", color="col4")
29
-
30
- import numpy as np
31
-
32
- df2 = pd.DataFrame(
33
- {
34
- "col1": np.random.randn(1000) / 50 + 37.76,
35
- "col2": np.random.randn(1000) / 50 + -122.4,
36
- "col3": np.random.randn(1000) * 100,
37
- "col4": np.random.rand(1000, 4).tolist(),
38
- }
39
- )
40
- #st.map(df, latitude="col1", longitude="col2", size="col3", color="col4")
41
-
42
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
basic_map/requirements.txt DELETED
@@ -1,4 +0,0 @@
1
- streamlit
2
- folium
3
- streamlit-folium
4
-
 
 
 
 
 
docs/app.md CHANGED
@@ -1,5 +0,0 @@
1
- Here is the documentation for the app code generating the streamlit front-end.
2
-
3
- # Streamlit App
4
-
5
- ::: basic_map.app
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
- numpy==1.24
2
  pandas==2.2.3
3
 
4
 
 
1
+ numpy==1.26.4
2
  pandas==2.2.3
3
 
4
 
src/input_handling.py CHANGED
@@ -12,9 +12,10 @@ from streamlit.delta_generator import DeltaGenerator
12
  import cv2
13
  import numpy as np
14
 
 
 
 
15
  m_logger = logging.getLogger(__name__)
16
- # we can set the log level locally for funcs in this module
17
- #g_m_logger.setLevel(logging.DEBUG)
18
  m_logger.setLevel(logging.INFO)
19
 
20
  '''
@@ -22,11 +23,8 @@ A module to setup the input handling for the whale observation guidance tool
22
 
23
  both the UI elements (setup_input_UI) and the validation functions.
24
  '''
25
- #allowed_image_types = ['webp']
26
  allowed_image_types = ['jpg', 'jpeg', 'png', 'webp']
27
 
28
- import random
29
- import string
30
  def generate_random_md5():
31
  # Generate a random string
32
  random_string = ''.join(random.choices(string.ascii_letters + string.digits, k=16))
@@ -115,8 +113,6 @@ class InputObservation:
115
  "author_email": self.author_email,
116
  "date": self.date,
117
  "time": self.time,
118
- # "date_option": self.date_option,
119
- # "time_option": self.time_option,
120
  "date_option": str(self.date_option),
121
  "time_option": str(self.time_option),
122
  "uploaded_filename": self.uploaded_filename
@@ -168,7 +164,8 @@ def is_valid_email(email:str) -> bool:
168
  return re.match(pattern, email) is not None
169
 
170
  # Function to extract date and time from image metadata
171
- def get_image_datetime(image_file: UploadedFile) -> str | None:
 
172
  """
173
  Extracts the original date and time from the EXIF metadata of an uploaded image file.
174
 
@@ -204,7 +201,6 @@ spoof_metadata = {
204
  "time": None,
205
  }
206
 
207
- #def display_whale(whale_classes:List[str], i:int, viewcontainer=None):
208
  def setup_input(
209
  viewcontainer: DeltaGenerator=None,
210
  _allowed_image_types: list=None, ) -> InputObservation:
@@ -232,61 +228,62 @@ def setup_input(
232
 
233
  viewcontainer.title("Input image and data")
234
 
235
- # 1. Image Selector
236
- uploaded_filename = viewcontainer.file_uploader("Upload an image", type=allowed_image_types)
237
- image_datetime = None # For storing date-time from image
238
-
239
- if uploaded_filename is not None:
240
- # Display the uploaded image
241
- #image = Image.open(uploaded_filename)
242
- # load image using cv2 format, so it is compatible with the ML models
243
- file_bytes = np.asarray(bytearray(uploaded_filename.read()), dtype=np.uint8)
244
- image = cv2.imdecode(file_bytes, 1)
245
-
246
-
247
- viewcontainer.image(image, caption='Uploaded Image.', use_column_width=True)
248
- # store the image in the session state
249
- st.session_state.image = image
250
-
251
-
252
- # Extract and display image date-time
253
- image_datetime = get_image_datetime(uploaded_filename)
254
- print(f"[D] image date extracted as {image_datetime}")
255
- m_logger.debug(f"image date extracted as {image_datetime} (from {uploaded_filename})")
256
-
257
-
258
- # 2. Latitude Entry Box
259
- latitude = viewcontainer.text_input("Latitude", spoof_metadata.get('latitude', ""))
260
- if latitude and not is_valid_number(latitude):
261
- viewcontainer.error("Please enter a valid latitude (numerical only).")
262
- m_logger.error(f"Invalid latitude entered: {latitude}.")
263
- # 3. Longitude Entry Box
264
- longitude = viewcontainer.text_input("Longitude", spoof_metadata.get('longitude', ""))
265
- if longitude and not is_valid_number(longitude):
266
- viewcontainer.error("Please enter a valid longitude (numerical only).")
267
- m_logger.error(f"Invalid latitude entered: {latitude}.")
268
-
269
- # 4. Author Box with Email Address Validator
270
  author_email = viewcontainer.text_input("Author Email", spoof_metadata.get('author_email', ""))
271
-
272
  if author_email and not is_valid_email(author_email):
273
  viewcontainer.error("Please enter a valid email address.")
274
 
275
- # 5. date/time
276
- ## first from image metadata
277
- if image_datetime is not None:
278
- time_value = datetime.datetime.strptime(image_datetime, '%Y:%m:%d %H:%M:%S').time()
279
- date_value = datetime.datetime.strptime(image_datetime, '%Y:%m:%d %H:%M:%S').date()
280
- else:
281
- time_value = datetime.datetime.now().time() # Default to current time
282
- date_value = datetime.datetime.now().date()
283
-
284
- ## if not, give user the option to enter manually
285
- date_option = st.sidebar.date_input("Date", value=date_value)
286
- time_option = st.sidebar.time_input("Time", time_value)
287
-
288
- observation = InputObservation(image=uploaded_filename, latitude=latitude, longitude=longitude,
289
- author_email=author_email, date=image_datetime, time=None,
290
- date_option=date_option, time_option=time_option)
291
- return observation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
292
 
 
12
  import cv2
13
  import numpy as np
14
 
15
+ import random
16
+ import string
17
+
18
  m_logger = logging.getLogger(__name__)
 
 
19
  m_logger.setLevel(logging.INFO)
20
 
21
  '''
 
23
 
24
  both the UI elements (setup_input_UI) and the validation functions.
25
  '''
 
26
  allowed_image_types = ['jpg', 'jpeg', 'png', 'webp']
27
 
 
 
28
  def generate_random_md5():
29
  # Generate a random string
30
  random_string = ''.join(random.choices(string.ascii_letters + string.digits, k=16))
 
113
  "author_email": self.author_email,
114
  "date": self.date,
115
  "time": self.time,
 
 
116
  "date_option": str(self.date_option),
117
  "time_option": str(self.time_option),
118
  "uploaded_filename": self.uploaded_filename
 
164
  return re.match(pattern, email) is not None
165
 
166
  # Function to extract date and time from image metadata
167
+ # def get_image_datetime(image_file: UploadedFile) -> str | None:
168
+ def get_image_datetime(image_file):
169
  """
170
  Extracts the original date and time from the EXIF metadata of an uploaded image file.
171
 
 
201
  "time": None,
202
  }
203
 
 
204
  def setup_input(
205
  viewcontainer: DeltaGenerator=None,
206
  _allowed_image_types: list=None, ) -> InputObservation:
 
228
 
229
  viewcontainer.title("Input image and data")
230
 
231
+ # 1. Input the author email
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
232
  author_email = viewcontainer.text_input("Author Email", spoof_metadata.get('author_email', ""))
 
233
  if author_email and not is_valid_email(author_email):
234
  viewcontainer.error("Please enter a valid email address.")
235
 
236
+ # 2. Image Selector
237
+ uploaded_files = viewcontainer.file_uploader("Upload an image", type=allowed_image_types, accept_multiple_files=True)
238
+ observations = {}
239
+ images = {}
240
+ if uploaded_files is not None:
241
+ for file in uploaded_files:
242
+
243
+ viewcontainer.title(f"Metadata for {file.name}")
244
+
245
+ # Display the uploaded image
246
+ # load image using cv2 format, so it is compatible with the ML models
247
+ file_bytes = np.asarray(bytearray(file.read()), dtype=np.uint8)
248
+ filename = file.name
249
+ image = cv2.imdecode(file_bytes, 1)
250
+ # Extract and display image date-time
251
+ image_datetime = None # For storing date-time from image
252
+ image_datetime = get_image_datetime(file)
253
+ m_logger.debug(f"image date extracted as {image_datetime} (from {uploaded_files})")
254
+
255
+
256
+ # 3. Latitude Entry Box
257
+ latitude = viewcontainer.text_input("Latitude for "+filename, spoof_metadata.get('latitude', ""))
258
+ if latitude and not is_valid_number(latitude):
259
+ viewcontainer.error("Please enter a valid latitude (numerical only).")
260
+ m_logger.error(f"Invalid latitude entered: {latitude}.")
261
+ # 4. Longitude Entry Box
262
+ longitude = viewcontainer.text_input("Longitude for "+filename, spoof_metadata.get('longitude', ""))
263
+ if longitude and not is_valid_number(longitude):
264
+ viewcontainer.error("Please enter a valid longitude (numerical only).")
265
+ m_logger.error(f"Invalid latitude entered: {latitude}.")
266
+ # 5. Date/time
267
+ ## first from image metadata
268
+ if image_datetime is not None:
269
+ time_value = datetime.datetime.strptime(image_datetime, '%Y:%m:%d %H:%M:%S').time()
270
+ date_value = datetime.datetime.strptime(image_datetime, '%Y:%m:%d %H:%M:%S').date()
271
+ else:
272
+ time_value = datetime.datetime.now().time() # Default to current time
273
+ date_value = datetime.datetime.now().date()
274
+
275
+ ## if not, give user the option to enter manually
276
+ date_option = st.sidebar.date_input("Date for "+filename, value=date_value)
277
+ time_option = st.sidebar.time_input("Time for "+filename, time_value)
278
+
279
+ observation = InputObservation(image=file, latitude=latitude, longitude=longitude,
280
+ author_email=author_email, date=image_datetime, time=None,
281
+ date_option=date_option, time_option=time_option)
282
+ observations[file.name] = observation
283
+ images[file.name] = image
284
+
285
+ st.session_state.image = images
286
+ st.session_state.files = uploaded_files
287
+
288
+ return observations
289
 
src/main.py CHANGED
@@ -77,14 +77,14 @@ def metadata2md() -> str:
77
 
78
  """
79
  markdown_str = "\n"
80
- for key, value in st.session_state.full_data.items():
81
  markdown_str += f"- **{key}**: {value}\n"
82
  return markdown_str
83
 
84
 
85
- def push_observation(tab_log:DeltaGenerator=None):
86
  """
87
- Push the observation to the Hugging Face dataset
88
 
89
  Args:
90
  tab_log (streamlit.container): The container to log messages to. If not provided,
@@ -94,12 +94,12 @@ def push_observation(tab_log:DeltaGenerator=None):
94
  """
95
  # we get the data from session state: 1 is the dict 2 is the image.
96
  # first, lets do an info display (popup)
97
- metadata_str = json.dumps(st.session_state.full_data)
98
 
99
- st.toast(f"Uploading observation: {metadata_str}", icon="🦭")
100
  tab_log = st.session_state.tab_log
101
  if tab_log is not None:
102
- tab_log.info(f"Uploading observation: {metadata_str}")
103
 
104
  # get huggingface api
105
  import os
@@ -111,7 +111,7 @@ def push_observation(tab_log:DeltaGenerator=None):
111
  f.close()
112
  st.info(f"temp file: {f.name} with metadata written...")
113
 
114
- path_in_repo= f"metadata/{st.session_state.full_data['author_email']}/{st.session_state.full_data['image_md5']}.json"
115
  msg = f"fname: {f.name} | path: {path_in_repo}"
116
  print(msg)
117
  st.warning(msg)
@@ -134,7 +134,7 @@ def main() -> None:
134
 
135
  The organisation is as follows:
136
 
137
- 1. data input (a new observation) is handled in the sidebar
138
  2. the rest of the interface is organised in tabs:
139
 
140
  - cetacean classifier
@@ -161,12 +161,12 @@ def main() -> None:
161
  st.session_state.tab_log = tab_log
162
 
163
 
164
- # create a sidebar, and parse all the input (returned as `observation` object)
165
- observation = sw_inp.setup_input(viewcontainer=st.sidebar)
166
 
167
 
168
  if 0:## WIP
169
- # goal of this code is to allow the user to override the ML prediction, before transmitting an observation
170
  predicted_class = st.sidebar.selectbox("Predicted Class", sw_wv.WHALE_CLASSES)
171
  override_prediction = st.sidebar.checkbox("Override Prediction")
172
 
@@ -236,18 +236,13 @@ def main() -> None:
236
  # Display submitted data
237
  if st.sidebar.button("Validate"):
238
  # create a dictionary with the submitted data
239
- submitted_data = observation.to_dict()
240
- #print(submitted_data)
241
-
242
- #full_data.update(**submitted_data)
243
- for k, v in submitted_data.items():
244
- st.session_state.full_data[k] = v
245
 
246
- #st.write(f"full dict of data: {json.dumps(submitted_data)}")
247
- #tab_inference.info(f"{st.session_state.full_data}")
248
  tab_log.info(f"{st.session_state.full_data}")
249
 
250
- df = pd.DataFrame(submitted_data, index=[0])
 
251
  with tab_data:
252
  st.table(df)
253
 
@@ -259,7 +254,7 @@ def main() -> None:
259
  # - the model predicts the top 3 most likely species from the input image
260
  # - these species are shown
261
  # - the user can override the species prediction using the dropdown
262
- # - an observation is uploaded if the user chooses.
263
 
264
  if tab_inference.button("Identify with cetacean classifier"):
265
  #pipe = pipeline("image-classification", model="Saving-Willy/cetacean-classifier", trust_remote_code=True)
@@ -271,44 +266,53 @@ def main() -> None:
271
  # TODO: cleaner design to disable the button until data input done?
272
  st.info("Please upload an image first.")
273
  else:
274
- # run classifier model on `image`, and persistently store the output
275
- out = cetacean_classifier(st.session_state.image) # get top 3 matches
276
- st.session_state.whale_prediction1 = out['predictions'][0]
277
- st.session_state.classify_whale_done = True
278
- msg = f"[D]2 classify_whale_done: {st.session_state.classify_whale_done}, whale_prediction1: {st.session_state.whale_prediction1}"
279
- st.info(msg)
280
- g_logger.info(msg)
281
-
282
- # dropdown for selecting/overriding the species prediction
283
- #st.info(f"[D] classify_whale_done: {st.session_state.classify_whale_done}, whale_prediction1: {st.session_state.whale_prediction1}")
284
- if not st.session_state.classify_whale_done:
285
- selected_class = tab_inference.sidebar.selectbox("Species", sw_wv.WHALE_CLASSES, index=None, placeholder="Species not yet identified...", disabled=True)
286
- else:
287
- pred1 = st.session_state.whale_prediction1
288
- # get index of pred1 from WHALE_CLASSES, none if not present
289
- print(f"[D] pred1: {pred1}")
290
- ix = sw_wv.WHALE_CLASSES.index(pred1) if pred1 in sw_wv.WHALE_CLASSES else None
291
- selected_class = tab_inference.selectbox("Species", sw_wv.WHALE_CLASSES, index=ix)
292
-
293
- st.session_state.full_data['predicted_class'] = selected_class
294
- if selected_class != st.session_state.whale_prediction1:
295
- st.session_state.full_data['class_overriden'] = selected_class
 
 
 
 
296
 
297
- btn = st.button("Upload observation to THE INTERNET!", on_click=push_observation)
298
- # TODO: the metadata only fills properly if `validate` was clicked.
299
- tab_inference.markdown(metadata2md())
300
-
301
- msg = f"[D] full data after inference: {st.session_state.full_data}"
302
- g_logger.debug(msg)
303
- print(msg)
304
- # TODO: add a link to more info on the model, next to the button.
305
-
306
- whale_classes = out['predictions'][:]
307
- # render images for the top 3 (that is what the model api returns)
308
- with tab_inference:
309
- st.markdown("## Species detected")
310
- for i in range(len(whale_classes)):
311
- sw_wv.display_whale(whale_classes, i)
 
 
 
 
 
312
 
313
 
314
 
@@ -325,27 +329,29 @@ def main() -> None:
325
 
326
  if st.session_state.image is None:
327
  st.info("Please upload an image first.")
328
- st.info(str(observation.to_dict()))
329
 
330
  else:
331
  col1, col2 = tab_hotdogs.columns(2)
332
-
333
- # display the image (use cached version, no need to reread)
334
- col1.image(st.session_state.image, use_column_width=True)
335
- # and then run inference on the image
336
- hotdog_image = Image.fromarray(st.session_state.image)
337
- predictions = pipeline_hot_dog(hotdog_image)
338
-
339
- col2.header("Probabilities")
340
- first = True
341
- for p in predictions:
342
- col2.subheader(f"{ p['label'] }: { round(p['score'] * 100, 1)}%")
343
- if first:
344
- st.session_state.full_data['predicted_class'] = p['label']
345
- st.session_state.full_data['predicted_score'] = round(p['score'] * 100, 1)
346
- first = False
347
-
348
- tab_hotdogs.write(f"Session Data: {json.dumps(st.session_state.full_data)}")
 
 
349
 
350
 
351
 
 
77
 
78
  """
79
  markdown_str = "\n"
80
+ for key, value in st.session_state.public_observation.items():
81
  markdown_str += f"- **{key}**: {value}\n"
82
  return markdown_str
83
 
84
 
85
+ def push_observations(tab_log:DeltaGenerator=None):
86
  """
87
+ Push the observations to the Hugging Face dataset
88
 
89
  Args:
90
  tab_log (streamlit.container): The container to log messages to. If not provided,
 
94
  """
95
  # we get the data from session state: 1 is the dict 2 is the image.
96
  # first, lets do an info display (popup)
97
+ metadata_str = json.dumps(st.session_state.public_observation)
98
 
99
+ st.toast(f"Uploading observations: {metadata_str}", icon="🦭")
100
  tab_log = st.session_state.tab_log
101
  if tab_log is not None:
102
+ tab_log.info(f"Uploading observations: {metadata_str}")
103
 
104
  # get huggingface api
105
  import os
 
111
  f.close()
112
  st.info(f"temp file: {f.name} with metadata written...")
113
 
114
+ path_in_repo= f"metadata/{st.session_state.public_observation['author_email']}/{st.session_state.public_observation['image_md5']}.json"
115
  msg = f"fname: {f.name} | path: {path_in_repo}"
116
  print(msg)
117
  st.warning(msg)
 
134
 
135
  The organisation is as follows:
136
 
137
+ 1. data input (new observations) is handled in the sidebar
138
  2. the rest of the interface is organised in tabs:
139
 
140
  - cetacean classifier
 
161
  st.session_state.tab_log = tab_log
162
 
163
 
164
+ # create a sidebar, and parse all the input (returned as `observations` object)
165
+ observations = sw_inp.setup_input(viewcontainer=st.sidebar)
166
 
167
 
168
  if 0:## WIP
169
+ # goal of this code is to allow the user to override the ML prediction, before transmitting the observations
170
  predicted_class = st.sidebar.selectbox("Predicted Class", sw_wv.WHALE_CLASSES)
171
  override_prediction = st.sidebar.checkbox("Override Prediction")
172
 
 
236
  # Display submitted data
237
  if st.sidebar.button("Validate"):
238
  # create a dictionary with the submitted data
239
+ submitted_data = observations
240
+ st.session_state.full_data = observations
 
 
 
 
241
 
 
 
242
  tab_log.info(f"{st.session_state.full_data}")
243
 
244
+ df = pd.DataFrame(submitted_data)
245
+ print("Dataframe Shape: ", df.shape)
246
  with tab_data:
247
  st.table(df)
248
 
 
254
  # - the model predicts the top 3 most likely species from the input image
255
  # - these species are shown
256
  # - the user can override the species prediction using the dropdown
257
+ # - observations are uploaded if the user chooses.
258
 
259
  if tab_inference.button("Identify with cetacean classifier"):
260
  #pipe = pipeline("image-classification", model="Saving-Willy/cetacean-classifier", trust_remote_code=True)
 
266
  # TODO: cleaner design to disable the button until data input done?
267
  st.info("Please upload an image first.")
268
  else:
269
+ files = st.session_state.files
270
+ images = st.session_state.images
271
+ full_data = st.session_state.full_data
272
+ for file in files:
273
+ image = images[file]
274
+ data = full_data[file]
275
+ # run classifier model on `image`, and persistently store the output
276
+ out = cetacean_classifier(image) # get top 3 matches
277
+ st.session_state.whale_prediction1 = out['predictions'][0]
278
+ st.session_state.classify_whale_done = True
279
+ msg = f"[D]2 classify_whale_done: {st.session_state.classify_whale_done}, whale_prediction1: {st.session_state.whale_prediction1}"
280
+ # st.info(msg)
281
+ g_logger.info(msg)
282
+
283
+ # dropdown for selecting/overriding the species prediction
284
+ #st.info(f"[D] classify_whale_done: {st.session_state.classify_whale_done}, whale_prediction1: {st.session_state.whale_prediction1}")
285
+ if not st.session_state.classify_whale_done:
286
+ selected_class = tab_inference.sidebar.selectbox("Species", sw_wv.WHALE_CLASSES,
287
+ index=None, placeholder="Species not yet identified...",
288
+ disabled=True)
289
+ else:
290
+ pred1 = st.session_state.whale_prediction1
291
+ # get index of pred1 from WHALE_CLASSES, none if not present
292
+ print(f"[D] pred1: {pred1}")
293
+ ix = sw_wv.WHALE_CLASSES.index(pred1) if pred1 in sw_wv.WHALE_CLASSES else None
294
+ selected_class = tab_inference.selectbox("Species", sw_wv.WHALE_CLASSES, index=ix)
295
 
296
+ data['predicted_class'] = selected_class
297
+ if selected_class != st.session_state.whale_prediction1:
298
+ data['class_overriden'] = selected_class
299
+
300
+ st.session_state.public_observation = data
301
+ st.button("Upload observations to THE INTERNET!", on_click=push_observations)
302
+ # TODO: the metadata only fills properly if `validate` was clicked.
303
+ tab_inference.markdown(metadata2md())
304
+
305
+ msg = f"[D] full data after inference: {data}"
306
+ g_logger.debug(msg)
307
+ print(msg)
308
+ # TODO: add a link to more info on the model, next to the button.
309
+
310
+ whale_classes = out['predictions'][:]
311
+ # render images for the top 3 (that is what the model api returns)
312
+ with tab_inference:
313
+ st.markdown("## Species detected")
314
+ for i in range(len(whale_classes)):
315
+ sw_wv.display_whale(whale_classes, i)
316
 
317
 
318
 
 
329
 
330
  if st.session_state.image is None:
331
  st.info("Please upload an image first.")
332
+ st.info(str(observations.to_dict()))
333
 
334
  else:
335
  col1, col2 = tab_hotdogs.columns(2)
336
+ for file in st.session_state.files:
337
+ image = st.session_state.images[file]
338
+ data = st.session_state.full_data[file]
339
+ # display the image (use cached version, no need to reread)
340
+ col1.image(image, use_column_width=True)
341
+ # and then run inference on the image
342
+ hotdog_image = Image.fromarray(image)
343
+ predictions = pipeline_hot_dog(hotdog_image)
344
+
345
+ col2.header("Probabilities")
346
+ first = True
347
+ for p in predictions:
348
+ col2.subheader(f"{ p['label'] }: { round(p['score'] * 100, 1)}%")
349
+ if first:
350
+ data['predicted_class'] = p['label']
351
+ data['predicted_score'] = round(p['score'] * 100, 1)
352
+ first = False
353
+
354
+ tab_hotdogs.write(f"Session Data: {json.dumps(data)}")
355
 
356
 
357