rmm committed on
Commit
c915f7c
·
1 Parent(s): 5823912

feat: extended InputObservation to contain species/prediction info

Browse files

- when manual validation is performed (dropdown selection among
species), the selection is written to the observations (and not to the
dynamically-created dicts).

- TODO: decide if we need to retain public_observations in
session_state, or just generate the dict each time it is needed.

src/classifier/classifier_image.py CHANGED
@@ -10,6 +10,7 @@ import whale_viewer as viewer
10
  from hf_push_observations import push_observations
11
  from utils.grid_maker import gridder
12
  from utils.metadata_handler import metadata2md
 
13
 
14
  def add_header_text() -> None:
15
  """
@@ -24,12 +25,11 @@ def add_header_text() -> None:
24
  def cetacean_just_classify(cetacean_classifier):
25
 
26
  images = st.session_state.images
27
- observations = st.session_state.observations
28
  hashes = st.session_state.image_hashes
29
 
30
  for hash in hashes:
31
  image = images[hash]
32
- observation = observations[hash].to_dict()
33
  # run classifier model on `image`, and persistently store the output
34
  out = cetacean_classifier(image) # get top 3 matches
35
  st.session_state.whale_prediction1[hash] = out['predictions'][0]
@@ -39,8 +39,6 @@ def cetacean_just_classify(cetacean_classifier):
39
  msg = f"[D]2 classify_whale_done for {hash}: {st.session_state.classify_whale_done[hash]}, whale_prediction1: {st.session_state.whale_prediction1[hash]}"
40
  g_logger.info(msg)
41
 
42
- # store the elements of the observation that will be transmitted (not image)
43
- st.session_state.public_observations[hash] = observation
44
  if st.session_state.MODE_DEV_STATEFUL:
45
  st.write(f"*[D] Observation {hash} classified as {st.session_state.whale_prediction1[hash]}*")
46
 
@@ -58,7 +56,8 @@ def cetacean_show_results_and_review():
58
 
59
  for hash in hashes:
60
  image = images[hash]
61
- observation = observations[hash].to_dict()
 
62
 
63
  with grid[col]:
64
  st.image(image, use_column_width=True)
@@ -75,14 +74,19 @@ def cetacean_show_results_and_review():
75
  ix = viewer.WHALE_CLASSES.index(pred1) if pred1 in viewer.WHALE_CLASSES else None
76
  selected_class = st.selectbox(f"Species for observation {str(o)}", viewer.WHALE_CLASSES, index=ix)
77
 
78
- observation['predicted_class'] = selected_class
79
- if selected_class != st.session_state.whale_prediction1[hash]:
80
- observation['class_overriden'] = selected_class # TODO: this should be boolean!
 
 
81
 
 
 
82
  st.session_state.public_observations[hash] = observation
 
83
  #st.button(f"Upload observation {str(o)} to THE INTERNET!", on_click=push_observations)
84
  # TODO: the metadata only fills properly if `validate` was clicked.
85
- st.markdown(metadata2md(hash))
86
 
87
  msg = f"[D] full observation after inference: {observation}"
88
  g_logger.debug(msg)
@@ -138,12 +142,7 @@ def cetacean_show_results():
138
 
139
  #st.button(f"Upload observation {str(o)} to THE INTERNET!", on_click=push_observations)
140
  #
141
- st.markdown(metadata2md(hash))
142
- # TODO: FIXME: this is the data taht will get pushed -- it DOESN'T reflect any adjustments
143
- # # made via the dropdown on the last step!!!!
144
- #st.markdown(f"- **selected species**: {observation['predicted_class']}")
145
- st.markdown(f"- **selected species**: {st.session_state.whale_prediction1[hash]}")
146
- st.markdown(f"- **hash**: {hash}")
147
 
148
  msg = f"[D] full observation after inference: {observation}"
149
  g_logger.debug(msg)
@@ -223,4 +222,4 @@ def cetacean_classify_show_and_review(cetacean_classifier):
223
  for i in range(len(whale_classes)):
224
  viewer.display_whale(whale_classes, i)
225
  o += 1
226
- col = (col + 1) % row_size
 
10
  from hf_push_observations import push_observations
11
  from utils.grid_maker import gridder
12
  from utils.metadata_handler import metadata2md
13
+ from input.input_observation import InputObservation
14
 
15
  def add_header_text() -> None:
16
  """
 
25
  def cetacean_just_classify(cetacean_classifier):
26
 
27
  images = st.session_state.images
28
+ #observations = st.session_state.observations
29
  hashes = st.session_state.image_hashes
30
 
31
  for hash in hashes:
32
  image = images[hash]
 
33
  # run classifier model on `image`, and persistently store the output
34
  out = cetacean_classifier(image) # get top 3 matches
35
  st.session_state.whale_prediction1[hash] = out['predictions'][0]
 
39
  msg = f"[D]2 classify_whale_done for {hash}: {st.session_state.classify_whale_done[hash]}, whale_prediction1: {st.session_state.whale_prediction1[hash]}"
40
  g_logger.info(msg)
41
 
 
 
42
  if st.session_state.MODE_DEV_STATEFUL:
43
  st.write(f"*[D] Observation {hash} classified as {st.session_state.whale_prediction1[hash]}*")
44
 
 
56
 
57
  for hash in hashes:
58
  image = images[hash]
59
+ #observation = observations[hash].to_dict()
60
+ _observation:InputObservation = observations[hash]
61
 
62
  with grid[col]:
63
  st.image(image, use_column_width=True)
 
74
  ix = viewer.WHALE_CLASSES.index(pred1) if pred1 in viewer.WHALE_CLASSES else None
75
  selected_class = st.selectbox(f"Species for observation {str(o)}", viewer.WHALE_CLASSES, index=ix)
76
 
77
+ _observation.set_selected_class(selected_class)
78
+ #observation['predicted_class'] = selected_class
79
+ # this logic is now in the InputObservation class automatically
80
+ #if selected_class != st.session_state.whale_prediction1[hash]:
81
+ # observation['class_overriden'] = selected_class # TODO: this should be boolean!
82
 
83
+ # store the elements of the observation that will be transmitted (not image)
84
+ observation = _observation.to_dict()
85
  st.session_state.public_observations[hash] = observation
86
+
87
  #st.button(f"Upload observation {str(o)} to THE INTERNET!", on_click=push_observations)
88
  # TODO: the metadata only fills properly if `validate` was clicked.
89
+ st.markdown(metadata2md(hash, debug=True))
90
 
91
  msg = f"[D] full observation after inference: {observation}"
92
  g_logger.debug(msg)
 
142
 
143
  #st.button(f"Upload observation {str(o)} to THE INTERNET!", on_click=push_observations)
144
  #
145
+ st.markdown(metadata2md(hash, debug=True))
 
 
 
 
 
146
 
147
  msg = f"[D] full observation after inference: {observation}"
148
  g_logger.debug(msg)
 
222
  for i in range(len(whale_classes)):
223
  viewer.display_whale(whale_classes, i)
224
  o += 1
225
+ col = (col + 1) % row_size
src/input/input_observation.py CHANGED
@@ -68,7 +68,10 @@ class InputObservation:
68
  self.time = time
69
  self.uploaded_file = uploaded_file
70
  self.image_md5 = image_md5
 
71
  self._top_predictions = []
 
 
72
 
73
  InputObservation._inst_count += 1
74
  self._inst_id = InputObservation._inst_count
@@ -81,11 +84,30 @@ class InputObservation:
81
 
82
  def set_top_predictions(self, top_predictions:list):
83
  self._top_predictions = top_predictions
 
 
84
 
85
- # add a method to get the top predictions (property?)
 
 
 
 
 
 
 
 
86
  @property
87
  def top_predictions(self):
88
  return self._top_predictions
 
 
 
 
 
 
 
 
 
89
 
90
  # add a method to assign the image_md5 only once
91
  def assign_image_md5(self):
@@ -194,6 +216,10 @@ class InputObservation:
194
  "image_datetime_raw": self.image_datetime_raw,
195
  "date": str(self.date),
196
  "time": str(self.time),
 
 
 
 
197
  #"uploaded_file": self.uploaded_file # can't serialize this in json, not sent to dataset anyway.
198
  }
199
 
 
68
  self.time = time
69
  self.uploaded_file = uploaded_file
70
  self.image_md5 = image_md5
71
+ # attributes that get set after predictions/processing
72
  self._top_predictions = []
73
+ self._selected_class = None
74
+ self._class_overriden = False
75
 
76
  InputObservation._inst_count += 1
77
  self._inst_id = InputObservation._inst_count
 
84
 
85
  def set_top_predictions(self, top_predictions:list):
86
  self._top_predictions = top_predictions
87
+ if len(top_predictions) > 0:
88
+ self.set_selected_class(top_predictions[0])
89
 
90
+ def set_selected_class(self, selected_class:str):
91
+ self._selected_class = selected_class
92
+ if selected_class != self._top_predictions[0]:
93
+ self.set_class_overriden(True)
94
+
95
+ def set_class_overriden(self, class_overriden:bool):
96
+ self._class_overriden = class_overriden
97
+
98
+ # add getters for the top_predictions, selected_class and class_overriden
99
  @property
100
  def top_predictions(self):
101
  return self._top_predictions
102
+
103
+ @property
104
+ def selected_class(self):
105
+ return self._selected_class
106
+
107
+ @property
108
+ def class_overriden(self):
109
+ return self._class_overriden
110
+
111
 
112
  # add a method to assign the image_md5 only once
113
  def assign_image_md5(self):
 
216
  "image_datetime_raw": self.image_datetime_raw,
217
  "date": str(self.date),
218
  "time": str(self.time),
219
+ "selected_class": self._selected_class,
220
+ "top_prediction": self._top_predictions[0] if len(self._top_predictions) else None,
221
+ "class_overriden": self._class_overriden,
222
+
223
  #"uploaded_file": self.uploaded_file # can't serialize this in json, not sent to dataset anyway.
224
  }
225
 
src/main.py CHANGED
@@ -237,7 +237,8 @@ def main() -> None:
237
  if st.sidebar.button(":white_check_mark:[**Validate**]"):
238
  # create a dictionary with the submitted observation
239
  tab_log.info(f"{st.session_state.observations}")
240
- df = pd.DataFrame(st.session_state.observations, index=[0])
 
241
  with tab_coords:
242
  st.table(df)
243
  # there doesn't seem to be any actual validation here?? TODO: find validator function (each element is validated by the input box, but is there something at the whole image level?)
@@ -320,7 +321,8 @@ def main() -> None:
320
  cetacean_show_results()
321
 
322
  st.divider()
323
- df = pd.DataFrame(st.session_state.observations, index=[0])
 
324
  st.table(df)
325
 
326
  # didn't decide what the next state is here - I think we are in the terminal state.
 
237
  if st.sidebar.button(":white_check_mark:[**Validate**]"):
238
  # create a dictionary with the submitted observation
239
  tab_log.info(f"{st.session_state.observations}")
240
+ df = pd.DataFrame([obs.to_dict() for obs in st.session_state.observations.values()])
241
+ #df = pd.DataFrame(st.session_state.observations, index=[0])
242
  with tab_coords:
243
  st.table(df)
244
  # there doesn't seem to be any actual validation here?? TODO: find validator function (each element is validated by the input box, but is there something at the whole image level?)
 
321
  cetacean_show_results()
322
 
323
  st.divider()
324
+ #df = pd.DataFrame(st.session_state.observations, index=[0])
325
+ df = pd.DataFrame([obs.to_dict() for obs in st.session_state.observations.values()])
326
  st.table(df)
327
 
328
  # didn't decide what the next state is here - I think we are in the terminal state.
src/utils/metadata_handler.py CHANGED
@@ -1,10 +1,11 @@
1
  import streamlit as st
2
 
3
- def metadata2md(image_hash:str) -> str:
4
  """Get metadata from cache and return as markdown-formatted key-value list
5
 
6
  Args:
7
  image_hash (str): The hash of the image to get metadata for
 
8
 
9
  Returns:
10
  str: Markdown-formatted key-value list of metadata
@@ -12,6 +13,8 @@ def metadata2md(image_hash:str) -> str:
12
  """
13
  markdown_str = "\n"
14
  keys_to_print = ["author_email", "latitude", "longitude", "date", "time"]
 
 
15
 
16
  observation = st.session_state.public_observations.get(image_hash, {})
17
 
 
1
  import streamlit as st
2
 
3
+ def metadata2md(image_hash:str, debug:bool=False) -> str:
4
  """Get metadata from cache and return as markdown-formatted key-value list
5
 
6
  Args:
7
  image_hash (str): The hash of the image to get metadata for
8
+ debug (bool, optional): Whether to print additional fields.
9
 
10
  Returns:
11
  str: Markdown-formatted key-value list of metadata
 
13
  """
14
  markdown_str = "\n"
15
  keys_to_print = ["author_email", "latitude", "longitude", "date", "time"]
16
+ if debug:
17
+ keys_to_print += ["iamge_md5", "selected_class", "top_prediction", "class_overriden"]
18
 
19
  observation = st.session_state.public_observations.get(image_hash, {})
20