Spaces:
Running
Running
rmm
committed on
Commit
·
c915f7c
1
Parent(s):
5823912
feat: extended InputObservation to contain species/prediction info
Browse files
- when manual validation is performed (dropdown selection among
species), it is written to the observations (and not the
dynamically-created dicts).
- TODO: decide if we need to retain public_observations in
session_state, or just generate the dict each time it is needed.
- src/classifier/classifier_image.py +15 -16
- src/input/input_observation.py +27 -1
- src/main.py +4 -2
- src/utils/metadata_handler.py +4 -1
src/classifier/classifier_image.py
CHANGED
@@ -10,6 +10,7 @@ import whale_viewer as viewer
|
|
10 |
from hf_push_observations import push_observations
|
11 |
from utils.grid_maker import gridder
|
12 |
from utils.metadata_handler import metadata2md
|
|
|
13 |
|
14 |
def add_header_text() -> None:
|
15 |
"""
|
@@ -24,12 +25,11 @@ def add_header_text() -> None:
|
|
24 |
def cetacean_just_classify(cetacean_classifier):
|
25 |
|
26 |
images = st.session_state.images
|
27 |
-
observations = st.session_state.observations
|
28 |
hashes = st.session_state.image_hashes
|
29 |
|
30 |
for hash in hashes:
|
31 |
image = images[hash]
|
32 |
-
observation = observations[hash].to_dict()
|
33 |
# run classifier model on `image`, and persistently store the output
|
34 |
out = cetacean_classifier(image) # get top 3 matches
|
35 |
st.session_state.whale_prediction1[hash] = out['predictions'][0]
|
@@ -39,8 +39,6 @@ def cetacean_just_classify(cetacean_classifier):
|
|
39 |
msg = f"[D]2 classify_whale_done for {hash}: {st.session_state.classify_whale_done[hash]}, whale_prediction1: {st.session_state.whale_prediction1[hash]}"
|
40 |
g_logger.info(msg)
|
41 |
|
42 |
-
# store the elements of the observation that will be transmitted (not image)
|
43 |
-
st.session_state.public_observations[hash] = observation
|
44 |
if st.session_state.MODE_DEV_STATEFUL:
|
45 |
st.write(f"*[D] Observation {hash} classified as {st.session_state.whale_prediction1[hash]}*")
|
46 |
|
@@ -58,7 +56,8 @@ def cetacean_show_results_and_review():
|
|
58 |
|
59 |
for hash in hashes:
|
60 |
image = images[hash]
|
61 |
-
observation = observations[hash].to_dict()
|
|
|
62 |
|
63 |
with grid[col]:
|
64 |
st.image(image, use_column_width=True)
|
@@ -75,14 +74,19 @@ def cetacean_show_results_and_review():
|
|
75 |
ix = viewer.WHALE_CLASSES.index(pred1) if pred1 in viewer.WHALE_CLASSES else None
|
76 |
selected_class = st.selectbox(f"Species for observation {str(o)}", viewer.WHALE_CLASSES, index=ix)
|
77 |
|
78 |
-
|
79 |
-
|
80 |
-
|
|
|
|
|
81 |
|
|
|
|
|
82 |
st.session_state.public_observations[hash] = observation
|
|
|
83 |
#st.button(f"Upload observation {str(o)} to THE INTERNET!", on_click=push_observations)
|
84 |
# TODO: the metadata only fills properly if `validate` was clicked.
|
85 |
-
st.markdown(metadata2md(hash))
|
86 |
|
87 |
msg = f"[D] full observation after inference: {observation}"
|
88 |
g_logger.debug(msg)
|
@@ -138,12 +142,7 @@ def cetacean_show_results():
|
|
138 |
|
139 |
#st.button(f"Upload observation {str(o)} to THE INTERNET!", on_click=push_observations)
|
140 |
#
|
141 |
-
st.markdown(metadata2md(hash))
|
142 |
-
# TODO: FIXME: this is the data that will get pushed -- it DOESN'T reflect any adjustments
|
143 |
-
# # made via the dropdown on the last step!!!!
|
144 |
-
#st.markdown(f"- **selected species**: {observation['predicted_class']}")
|
145 |
-
st.markdown(f"- **selected species**: {st.session_state.whale_prediction1[hash]}")
|
146 |
-
st.markdown(f"- **hash**: {hash}")
|
147 |
|
148 |
msg = f"[D] full observation after inference: {observation}"
|
149 |
g_logger.debug(msg)
|
@@ -223,4 +222,4 @@ def cetacean_classify_show_and_review(cetacean_classifier):
|
|
223 |
for i in range(len(whale_classes)):
|
224 |
viewer.display_whale(whale_classes, i)
|
225 |
o += 1
|
226 |
-
col = (col + 1) % row_size
|
|
|
10 |
from hf_push_observations import push_observations
|
11 |
from utils.grid_maker import gridder
|
12 |
from utils.metadata_handler import metadata2md
|
13 |
+
from input.input_observation import InputObservation
|
14 |
|
15 |
def add_header_text() -> None:
|
16 |
"""
|
|
|
25 |
def cetacean_just_classify(cetacean_classifier):
|
26 |
|
27 |
images = st.session_state.images
|
28 |
+
#observations = st.session_state.observations
|
29 |
hashes = st.session_state.image_hashes
|
30 |
|
31 |
for hash in hashes:
|
32 |
image = images[hash]
|
|
|
33 |
# run classifier model on `image`, and persistently store the output
|
34 |
out = cetacean_classifier(image) # get top 3 matches
|
35 |
st.session_state.whale_prediction1[hash] = out['predictions'][0]
|
|
|
39 |
msg = f"[D]2 classify_whale_done for {hash}: {st.session_state.classify_whale_done[hash]}, whale_prediction1: {st.session_state.whale_prediction1[hash]}"
|
40 |
g_logger.info(msg)
|
41 |
|
|
|
|
|
42 |
if st.session_state.MODE_DEV_STATEFUL:
|
43 |
st.write(f"*[D] Observation {hash} classified as {st.session_state.whale_prediction1[hash]}*")
|
44 |
|
|
|
56 |
|
57 |
for hash in hashes:
|
58 |
image = images[hash]
|
59 |
+
#observation = observations[hash].to_dict()
|
60 |
+
_observation:InputObservation = observations[hash]
|
61 |
|
62 |
with grid[col]:
|
63 |
st.image(image, use_column_width=True)
|
|
|
74 |
ix = viewer.WHALE_CLASSES.index(pred1) if pred1 in viewer.WHALE_CLASSES else None
|
75 |
selected_class = st.selectbox(f"Species for observation {str(o)}", viewer.WHALE_CLASSES, index=ix)
|
76 |
|
77 |
+
_observation.set_selected_class(selected_class)
|
78 |
+
#observation['predicted_class'] = selected_class
|
79 |
+
# this logic is now in the InputObservation class automatically
|
80 |
+
#if selected_class != st.session_state.whale_prediction1[hash]:
|
81 |
+
# observation['class_overriden'] = selected_class # TODO: this should be boolean!
|
82 |
|
83 |
+
# store the elements of the observation that will be transmitted (not image)
|
84 |
+
observation = _observation.to_dict()
|
85 |
st.session_state.public_observations[hash] = observation
|
86 |
+
|
87 |
#st.button(f"Upload observation {str(o)} to THE INTERNET!", on_click=push_observations)
|
88 |
# TODO: the metadata only fills properly if `validate` was clicked.
|
89 |
+
st.markdown(metadata2md(hash, debug=True))
|
90 |
|
91 |
msg = f"[D] full observation after inference: {observation}"
|
92 |
g_logger.debug(msg)
|
|
|
142 |
|
143 |
#st.button(f"Upload observation {str(o)} to THE INTERNET!", on_click=push_observations)
|
144 |
#
|
145 |
+
st.markdown(metadata2md(hash, debug=True))
|
|
|
|
|
|
|
|
|
|
|
146 |
|
147 |
msg = f"[D] full observation after inference: {observation}"
|
148 |
g_logger.debug(msg)
|
|
|
222 |
for i in range(len(whale_classes)):
|
223 |
viewer.display_whale(whale_classes, i)
|
224 |
o += 1
|
225 |
+
col = (col + 1) % row_size
|
src/input/input_observation.py
CHANGED
@@ -68,7 +68,10 @@ class InputObservation:
|
|
68 |
self.time = time
|
69 |
self.uploaded_file = uploaded_file
|
70 |
self.image_md5 = image_md5
|
|
|
71 |
self._top_predictions = []
|
|
|
|
|
72 |
|
73 |
InputObservation._inst_count += 1
|
74 |
self._inst_id = InputObservation._inst_count
|
@@ -81,11 +84,30 @@ class InputObservation:
|
|
81 |
|
82 |
def set_top_predictions(self, top_predictions:list):
|
83 |
self._top_predictions = top_predictions
|
|
|
|
|
84 |
|
85 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
@property
|
87 |
def top_predictions(self):
|
88 |
return self._top_predictions
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
|
90 |
# add a method to assign the image_md5 only once
|
91 |
def assign_image_md5(self):
|
@@ -194,6 +216,10 @@ class InputObservation:
|
|
194 |
"image_datetime_raw": self.image_datetime_raw,
|
195 |
"date": str(self.date),
|
196 |
"time": str(self.time),
|
|
|
|
|
|
|
|
|
197 |
#"uploaded_file": self.uploaded_file # can't serialize this in json, not sent to dataset anyway.
|
198 |
}
|
199 |
|
|
|
68 |
self.time = time
|
69 |
self.uploaded_file = uploaded_file
|
70 |
self.image_md5 = image_md5
|
71 |
+
# attributes that get set after predictions/processing
|
72 |
self._top_predictions = []
|
73 |
+
self._selected_class = None
|
74 |
+
self._class_overriden = False
|
75 |
|
76 |
InputObservation._inst_count += 1
|
77 |
self._inst_id = InputObservation._inst_count
|
|
|
84 |
|
85 |
def set_top_predictions(self, top_predictions:list):
|
86 |
self._top_predictions = top_predictions
|
87 |
+
if len(top_predictions) > 0:
|
88 |
+
self.set_selected_class(top_predictions[0])
|
89 |
|
90 |
+
def set_selected_class(self, selected_class:str):
|
91 |
+
self._selected_class = selected_class
|
92 |
+
if selected_class != self._top_predictions[0]:
|
93 |
+
self.set_class_overriden(True)
|
94 |
+
|
95 |
+
def set_class_overriden(self, class_overriden:bool):
|
96 |
+
self._class_overriden = class_overriden
|
97 |
+
|
98 |
+
# add getters for the top_predictions, selected_class and class_overriden
|
99 |
@property
|
100 |
def top_predictions(self):
|
101 |
return self._top_predictions
|
102 |
+
|
103 |
+
@property
|
104 |
+
def selected_class(self):
|
105 |
+
return self._selected_class
|
106 |
+
|
107 |
+
@property
|
108 |
+
def class_overriden(self):
|
109 |
+
return self._class_overriden
|
110 |
+
|
111 |
|
112 |
# add a method to assign the image_md5 only once
|
113 |
def assign_image_md5(self):
|
|
|
216 |
"image_datetime_raw": self.image_datetime_raw,
|
217 |
"date": str(self.date),
|
218 |
"time": str(self.time),
|
219 |
+
"selected_class": self._selected_class,
|
220 |
+
"top_prediction": self._top_predictions[0] if len(self._top_predictions) else None,
|
221 |
+
"class_overriden": self._class_overriden,
|
222 |
+
|
223 |
#"uploaded_file": self.uploaded_file # can't serialize this in json, not sent to dataset anyway.
|
224 |
}
|
225 |
|
src/main.py
CHANGED
@@ -237,7 +237,8 @@ def main() -> None:
|
|
237 |
if st.sidebar.button(":white_check_mark:[**Validate**]"):
|
238 |
# create a dictionary with the submitted observation
|
239 |
tab_log.info(f"{st.session_state.observations}")
|
240 |
-
df = pd.DataFrame(st.session_state.observations
|
|
|
241 |
with tab_coords:
|
242 |
st.table(df)
|
243 |
# there doesn't seem to be any actual validation here?? TODO: find validator function (each element is validated by the input box, but is there something at the whole image level?)
|
@@ -320,7 +321,8 @@ def main() -> None:
|
|
320 |
cetacean_show_results()
|
321 |
|
322 |
st.divider()
|
323 |
-
df = pd.DataFrame(st.session_state.observations, index=[0])
|
|
|
324 |
st.table(df)
|
325 |
|
326 |
# didn't decide what the next state is here - I think we are in the terminal state.
|
|
|
237 |
if st.sidebar.button(":white_check_mark:[**Validate**]"):
|
238 |
# create a dictionary with the submitted observation
|
239 |
tab_log.info(f"{st.session_state.observations}")
|
240 |
+
df = pd.DataFrame([obs.to_dict() for obs in st.session_state.observations.values()])
|
241 |
+
#df = pd.DataFrame(st.session_state.observations, index=[0])
|
242 |
with tab_coords:
|
243 |
st.table(df)
|
244 |
# there doesn't seem to be any actual validation here?? TODO: find validator function (each element is validated by the input box, but is there something at the whole image level?)
|
|
|
321 |
cetacean_show_results()
|
322 |
|
323 |
st.divider()
|
324 |
+
#df = pd.DataFrame(st.session_state.observations, index=[0])
|
325 |
+
df = pd.DataFrame([obs.to_dict() for obs in st.session_state.observations.values()])
|
326 |
st.table(df)
|
327 |
|
328 |
# didn't decide what the next state is here - I think we are in the terminal state.
|
src/utils/metadata_handler.py
CHANGED
@@ -1,10 +1,11 @@
|
|
1 |
import streamlit as st
|
2 |
|
3 |
-
def metadata2md(image_hash:str) -> str:
|
4 |
"""Get metadata from cache and return as markdown-formatted key-value list
|
5 |
|
6 |
Args:
|
7 |
image_hash (str): The hash of the image to get metadata for
|
|
|
8 |
|
9 |
Returns:
|
10 |
str: Markdown-formatted key-value list of metadata
|
@@ -12,6 +13,8 @@ def metadata2md(image_hash:str) -> str:
|
|
12 |
"""
|
13 |
markdown_str = "\n"
|
14 |
keys_to_print = ["author_email", "latitude", "longitude", "date", "time"]
|
|
|
|
|
15 |
|
16 |
observation = st.session_state.public_observations.get(image_hash, {})
|
17 |
|
|
|
1 |
import streamlit as st
|
2 |
|
3 |
+
def metadata2md(image_hash:str, debug:bool=False) -> str:
|
4 |
"""Get metadata from cache and return as markdown-formatted key-value list
|
5 |
|
6 |
Args:
|
7 |
image_hash (str): The hash of the image to get metadata for
|
8 |
+
debug (bool, optional): Whether to print additional fields.
|
9 |
|
10 |
Returns:
|
11 |
str: Markdown-formatted key-value list of metadata
|
|
|
13 |
"""
|
14 |
markdown_str = "\n"
|
15 |
keys_to_print = ["author_email", "latitude", "longitude", "date", "time"]
|
16 |
+
if debug:
|
17 |
+
keys_to_print += ["iamge_md5", "selected_class", "top_prediction", "class_overriden"]
|
18 |
|
19 |
observation = st.session_state.public_observations.get(image_hash, {})
|
20 |
|