Spaces:
Running
Running
Merge pull request #29 from sdsc-ordes/feat/stateful-workflow
Browse files- .github/workflows/python-pycov-onPR.yml +39 -0
- requirements.txt +2 -1
- src/classifier/classifier_image.py +202 -8
- src/classifier_image.py +0 -70
- src/hf_push_observations.py +74 -8
- src/input/input_handling.py +370 -72
- src/input/input_observation.py +190 -43
- src/input/input_validator.py +28 -9
- src/main.py +149 -76
- src/maps/obs_map.py +2 -2
- src/utils/metadata_handler.py +15 -5
- src/utils/st_logs.py +11 -0
- src/utils/workflow_state.py +108 -0
- src/utils/workflow_ui.py +48 -0
- src/whale_viewer.py +3 -0
- tests/test_input_handling.py +2 -5
- tests/test_whale_viewer.py +1 -3
.github/workflows/python-pycov-onPR.yml
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This workflow will install dependencies, create coverage tests and run Pytest Coverage Commentator
|
2 |
+
# For more information see: https://github.com/coroo/pytest-coverage-commentator
|
3 |
+
name: pytest-coverage-in-PR
|
4 |
+
on:
|
5 |
+
pull_request:
|
6 |
+
branches:
|
7 |
+
- '*'
|
8 |
+
jobs:
|
9 |
+
build:
|
10 |
+
runs-on: ubuntu-latest
|
11 |
+
permissions:
|
12 |
+
contents: write
|
13 |
+
pull-requests: write
|
14 |
+
steps:
|
15 |
+
- uses: actions/checkout@v4
|
16 |
+
- name: Set up Python 3.10
|
17 |
+
uses: actions/setup-python@v3
|
18 |
+
with:
|
19 |
+
python-version: "3.10"
|
20 |
+
- name: Install dependencies
|
21 |
+
run: |
|
22 |
+
python -m pip install --upgrade pip
|
23 |
+
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
|
24 |
+
if [ -f tests/requirements.txt ]; then pip install -r tests/requirements.txt; fi
|
25 |
+
|
26 |
+
- name: Build coverage files for mishakav commenter action
|
27 |
+
run: |
|
28 |
+
pytest --junitxml=pytest.xml --cov-report=term-missing:skip-covered --cov=src tests/ | tee pytest-coverage.txt
|
29 |
+
echo "working dir:" && pwd
|
30 |
+
echo "files in cwd:" && ls -ltr
|
31 |
+
|
32 |
+
- name: Pytest coverage comment
|
33 |
+
uses: MishaKav/pytest-coverage-comment@main
|
34 |
+
with:
|
35 |
+
pytest-coverage-path: ./pytest-coverage.txt
|
36 |
+
junitxml-path: ./pytest.xml
|
37 |
+
|
38 |
+
#- name: Comment coverage
|
39 |
+
# uses: coroo/[email protected]
|
requirements.txt
CHANGED
@@ -10,7 +10,8 @@ streamlit_folium==0.23.1
|
|
10 |
|
11 |
# backend
|
12 |
datasets==3.0.2
|
13 |
-
|
|
|
14 |
|
15 |
# running ML models
|
16 |
|
|
|
10 |
|
11 |
# backend
|
12 |
datasets==3.0.2
|
13 |
+
## FSM
|
14 |
+
transitions==0.9.2
|
15 |
|
16 |
# running ML models
|
17 |
|
src/classifier/classifier_image.py
CHANGED
@@ -10,13 +10,207 @@ import whale_viewer as viewer
|
|
10 |
from hf_push_observations import push_observations
|
11 |
from utils.grid_maker import gridder
|
12 |
from utils.metadata_handler import metadata2md
|
|
|
13 |
|
14 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
"""Cetacean classifier using the saving-willy model from Saving Willy Hugging Face space.
|
16 |
For each image in the session state, classify the image and display the top 3 predictions.
|
17 |
Args:
|
18 |
cetacean_classifier ([type]): saving-willy model from Saving Willy Hugging Face space
|
19 |
"""
|
|
|
20 |
images = st.session_state.images
|
21 |
observations = st.session_state.observations
|
22 |
hashes = st.session_state.image_hashes
|
@@ -33,25 +227,25 @@ def cetacean_classify(cetacean_classifier):
|
|
33 |
observation = observations[hash].to_dict()
|
34 |
# run classifier model on `image`, and persistently store the output
|
35 |
out = cetacean_classifier(image) # get top 3 matches
|
36 |
-
st.session_state.whale_prediction1 = out['predictions'][0]
|
37 |
-
st.session_state.classify_whale_done = True
|
38 |
-
msg = f"[D]2 classify_whale_done: {st.session_state.classify_whale_done}, whale_prediction1: {st.session_state.whale_prediction1}"
|
39 |
g_logger.info(msg)
|
40 |
|
41 |
# dropdown for selecting/overriding the species prediction
|
42 |
-
if not st.session_state.classify_whale_done:
|
43 |
selected_class = st.sidebar.selectbox("Species", viewer.WHALE_CLASSES,
|
44 |
index=None, placeholder="Species not yet identified...",
|
45 |
disabled=True)
|
46 |
else:
|
47 |
-
pred1 = st.session_state.whale_prediction1
|
48 |
# get index of pred1 from WHALE_CLASSES, none if not present
|
49 |
print(f"[D] pred1: {pred1}")
|
50 |
ix = viewer.WHALE_CLASSES.index(pred1) if pred1 in viewer.WHALE_CLASSES else None
|
51 |
selected_class = st.selectbox(f"Species for observation {str(o)}", viewer.WHALE_CLASSES, index=ix)
|
52 |
|
53 |
observation['predicted_class'] = selected_class
|
54 |
-
if selected_class != st.session_state.whale_prediction1:
|
55 |
observation['class_overriden'] = selected_class
|
56 |
|
57 |
st.session_state.public_observation = observation
|
@@ -70,4 +264,4 @@ def cetacean_classify(cetacean_classifier):
|
|
70 |
for i in range(len(whale_classes)):
|
71 |
viewer.display_whale(whale_classes, i)
|
72 |
o += 1
|
73 |
-
col = (col + 1) % row_size
|
|
|
10 |
from hf_push_observations import push_observations
|
11 |
from utils.grid_maker import gridder
|
12 |
from utils.metadata_handler import metadata2md
|
13 |
+
from input.input_observation import InputObservation
|
14 |
|
15 |
+
def init_classifier_session_states() -> None:
|
16 |
+
'''
|
17 |
+
Initialise the session state variables used in classification
|
18 |
+
'''
|
19 |
+
if "classify_whale_done" not in st.session_state:
|
20 |
+
st.session_state.classify_whale_done = {}
|
21 |
+
|
22 |
+
if "whale_prediction1" not in st.session_state:
|
23 |
+
st.session_state.whale_prediction1 = {}
|
24 |
+
|
25 |
+
|
26 |
+
def add_classifier_header() -> None:
|
27 |
+
"""
|
28 |
+
Add brief explainer text about cetacean classification to the tab
|
29 |
+
"""
|
30 |
+
st.markdown("""
|
31 |
+
*Run classifer to identify the species of cetean on the uploaded image.
|
32 |
+
Once inference is complete, the top three predictions are shown.
|
33 |
+
You can override the prediction by selecting a species from the dropdown.*""")
|
34 |
+
|
35 |
+
|
36 |
+
# func to just run classification, store results.
|
37 |
+
def cetacean_just_classify(cetacean_classifier):
|
38 |
+
"""
|
39 |
+
Infer cetacean species for all observations in the session state.
|
40 |
+
|
41 |
+
- this function runs the classifier, and stores results in the session state.
|
42 |
+
- the top 3 predictions are stored in the observation object, which is retained
|
43 |
+
in st.session_state.observations
|
44 |
+
- to display results use cetacean_show_results() or cetacean_show_results_and_review()
|
45 |
+
|
46 |
+
Args:
|
47 |
+
cetacean_classifier ([type]): saving-willy model from Saving Willy Hugging Face space
|
48 |
+
"""
|
49 |
+
|
50 |
+
images = st.session_state.images
|
51 |
+
#observations = st.session_state.observations
|
52 |
+
hashes = st.session_state.image_hashes
|
53 |
+
|
54 |
+
for hash in hashes:
|
55 |
+
image = images[hash]
|
56 |
+
# run classifier model on `image`, and persistently store the output
|
57 |
+
out = cetacean_classifier(image) # get top 3 matches
|
58 |
+
st.session_state.whale_prediction1[hash] = out['predictions'][0]
|
59 |
+
st.session_state.classify_whale_done[hash] = True
|
60 |
+
st.session_state.observations[hash].set_top_predictions(out['predictions'][:])
|
61 |
+
|
62 |
+
msg = f"[D]2 classify_whale_done for {hash}: {st.session_state.classify_whale_done[hash]}, whale_prediction1: {st.session_state.whale_prediction1[hash]}"
|
63 |
+
g_logger.info(msg)
|
64 |
+
|
65 |
+
if st.session_state.MODE_DEV_STATEFUL:
|
66 |
+
st.write(f"*[D] Observation {hash} classified as {st.session_state.whale_prediction1[hash]}*")
|
67 |
+
|
68 |
+
|
69 |
+
# func to show results and allow review
|
70 |
+
def cetacean_show_results_and_review() -> None:
|
71 |
+
"""
|
72 |
+
Present classification results and allow user to review and override the prediction.
|
73 |
+
|
74 |
+
- for each observation in the session state, displays the image, summarised
|
75 |
+
metadata, and the top 3 predictions.
|
76 |
+
- allows user to override the prediction by selecting a species from the dropdown.
|
77 |
+
- the selected species is stored in the observation object, which is retained in
|
78 |
+
st.session_state.observations
|
79 |
+
|
80 |
+
"""
|
81 |
+
|
82 |
+
images = st.session_state.images
|
83 |
+
observations = st.session_state.observations
|
84 |
+
hashes = st.session_state.image_hashes
|
85 |
+
batch_size, row_size, page = gridder(hashes)
|
86 |
+
|
87 |
+
grid = st.columns(row_size)
|
88 |
+
col = 0
|
89 |
+
o = 1
|
90 |
+
|
91 |
+
for hash in hashes:
|
92 |
+
image = images[hash]
|
93 |
+
#observation = observations[hash].to_dict()
|
94 |
+
_observation:InputObservation = observations[hash]
|
95 |
+
|
96 |
+
with grid[col]:
|
97 |
+
st.image(image, use_column_width=True)
|
98 |
+
|
99 |
+
# dropdown for selecting/overriding the species prediction
|
100 |
+
if not st.session_state.classify_whale_done[hash]:
|
101 |
+
selected_class = st.sidebar.selectbox("Species", viewer.WHALE_CLASSES,
|
102 |
+
index=None, placeholder="Species not yet identified...",
|
103 |
+
disabled=True)
|
104 |
+
else:
|
105 |
+
pred1 = st.session_state.whale_prediction1[hash]
|
106 |
+
# get index of pred1 from WHALE_CLASSES, none if not present
|
107 |
+
print(f"[D] {o:3} pred1: {pred1:30} | {hash}")
|
108 |
+
ix = viewer.WHALE_CLASSES.index(pred1) if pred1 in viewer.WHALE_CLASSES else None
|
109 |
+
selected_class = st.selectbox(f"Species for observation {str(o)}", viewer.WHALE_CLASSES, index=ix)
|
110 |
+
|
111 |
+
_observation.set_selected_class(selected_class)
|
112 |
+
#observation['predicted_class'] = selected_class
|
113 |
+
# this logic is now in the InputObservation class automatially
|
114 |
+
#if selected_class != st.session_state.whale_prediction1[hash]:
|
115 |
+
# observation['class_overriden'] = selected_class # TODO: this should be boolean!
|
116 |
+
|
117 |
+
# store the elements of the observation that will be transmitted (not image)
|
118 |
+
observation = _observation.to_dict()
|
119 |
+
st.session_state.public_observations[hash] = observation
|
120 |
+
|
121 |
+
#st.button(f"Upload observation {str(o)} to THE INTERNET!", on_click=push_observations)
|
122 |
+
# TODO: the metadata only fills properly if `validate` was clicked.
|
123 |
+
st.markdown(metadata2md(hash, debug=True))
|
124 |
+
|
125 |
+
msg = f"[D] full observation after inference: {observation}"
|
126 |
+
g_logger.debug(msg)
|
127 |
+
print(msg)
|
128 |
+
# TODO: add a link to more info on the model, next to the button.
|
129 |
+
|
130 |
+
whale_classes = observations[hash].top_predictions
|
131 |
+
# render images for the top 3 (that is what the model api returns)
|
132 |
+
n = len(whale_classes)
|
133 |
+
st.markdown(f"**Top {n} Predictions for observation {str(o)}**")
|
134 |
+
for i in range(n):
|
135 |
+
viewer.display_whale(whale_classes, i)
|
136 |
+
o += 1
|
137 |
+
col = (col + 1) % row_size
|
138 |
+
|
139 |
+
|
140 |
+
# func to just present results
|
141 |
+
def cetacean_show_results():
|
142 |
+
"""
|
143 |
+
Present classification results that may be pushed to the online dataset.
|
144 |
+
|
145 |
+
- for each observation in the session state, displays the image, summarised
|
146 |
+
metadata, the top 3 predictions, and the selected species (which may have
|
147 |
+
been manually selected, or the top prediction accepted).
|
148 |
+
|
149 |
+
"""
|
150 |
+
images = st.session_state.images
|
151 |
+
observations = st.session_state.observations
|
152 |
+
hashes = st.session_state.image_hashes
|
153 |
+
batch_size, row_size, page = gridder(hashes)
|
154 |
+
|
155 |
+
|
156 |
+
grid = st.columns(row_size)
|
157 |
+
col = 0
|
158 |
+
o = 1
|
159 |
+
|
160 |
+
for hash in hashes:
|
161 |
+
image = images[hash]
|
162 |
+
observation = observations[hash].to_dict()
|
163 |
+
|
164 |
+
with grid[col]:
|
165 |
+
st.image(image, use_column_width=True)
|
166 |
+
|
167 |
+
# # dropdown for selecting/overriding the species prediction
|
168 |
+
# if not st.session_state.classify_whale_done[hash]:
|
169 |
+
# selected_class = st.sidebar.selectbox("Species", viewer.WHALE_CLASSES,
|
170 |
+
# index=None, placeholder="Species not yet identified...",
|
171 |
+
# disabled=True)
|
172 |
+
# else:
|
173 |
+
# pred1 = st.session_state.whale_prediction1[hash]
|
174 |
+
# # get index of pred1 from WHALE_CLASSES, none if not present
|
175 |
+
# print(f"[D] pred1: {pred1}")
|
176 |
+
# ix = viewer.WHALE_CLASSES.index(pred1) if pred1 in viewer.WHALE_CLASSES else None
|
177 |
+
# selected_class = st.selectbox(f"Species for observation {str(o)}", viewer.WHALE_CLASSES, index=ix)
|
178 |
+
|
179 |
+
# observation['predicted_class'] = selected_class
|
180 |
+
# if selected_class != st.session_state.whale_prediction1[hash]:
|
181 |
+
# observation['class_overriden'] = selected_class # TODO: this should be boolean!
|
182 |
+
|
183 |
+
# st.session_state.public_observation = observation
|
184 |
+
|
185 |
+
#st.button(f"Upload observation {str(o)} to THE INTERNET!", on_click=push_observations)
|
186 |
+
#
|
187 |
+
st.markdown(metadata2md(hash, debug=True))
|
188 |
+
|
189 |
+
msg = f"[D] full observation after inference: {observation}"
|
190 |
+
g_logger.debug(msg)
|
191 |
+
print(msg)
|
192 |
+
# TODO: add a link to more info on the model, next to the button.
|
193 |
+
|
194 |
+
whale_classes = observations[hash].top_predictions
|
195 |
+
# render images for the top 3 (that is what the model api returns)
|
196 |
+
n = len(whale_classes)
|
197 |
+
st.markdown(f"**Top {n} Predictions for observation {str(o)}**")
|
198 |
+
for i in range(n):
|
199 |
+
viewer.display_whale(whale_classes, i)
|
200 |
+
o += 1
|
201 |
+
col = (col + 1) % row_size
|
202 |
+
|
203 |
+
|
204 |
+
|
205 |
+
|
206 |
+
# func to do all in one
|
207 |
+
def cetacean_classify_show_and_review(cetacean_classifier):
|
208 |
"""Cetacean classifier using the saving-willy model from Saving Willy Hugging Face space.
|
209 |
For each image in the session state, classify the image and display the top 3 predictions.
|
210 |
Args:
|
211 |
cetacean_classifier ([type]): saving-willy model from Saving Willy Hugging Face space
|
212 |
"""
|
213 |
+
raise DeprecationWarning("This function is deprecated. Use individual steps instead")
|
214 |
images = st.session_state.images
|
215 |
observations = st.session_state.observations
|
216 |
hashes = st.session_state.image_hashes
|
|
|
227 |
observation = observations[hash].to_dict()
|
228 |
# run classifier model on `image`, and persistently store the output
|
229 |
out = cetacean_classifier(image) # get top 3 matches
|
230 |
+
st.session_state.whale_prediction1[hash] = out['predictions'][0]
|
231 |
+
st.session_state.classify_whale_done[hash] = True
|
232 |
+
msg = f"[D]2 classify_whale_done for {hash}: {st.session_state.classify_whale_done[hash]}, whale_prediction1: {st.session_state.whale_prediction1[hash]}"
|
233 |
g_logger.info(msg)
|
234 |
|
235 |
# dropdown for selecting/overriding the species prediction
|
236 |
+
if not st.session_state.classify_whale_done[hash]:
|
237 |
selected_class = st.sidebar.selectbox("Species", viewer.WHALE_CLASSES,
|
238 |
index=None, placeholder="Species not yet identified...",
|
239 |
disabled=True)
|
240 |
else:
|
241 |
+
pred1 = st.session_state.whale_prediction1[hash]
|
242 |
# get index of pred1 from WHALE_CLASSES, none if not present
|
243 |
print(f"[D] pred1: {pred1}")
|
244 |
ix = viewer.WHALE_CLASSES.index(pred1) if pred1 in viewer.WHALE_CLASSES else None
|
245 |
selected_class = st.selectbox(f"Species for observation {str(o)}", viewer.WHALE_CLASSES, index=ix)
|
246 |
|
247 |
observation['predicted_class'] = selected_class
|
248 |
+
if selected_class != st.session_state.whale_prediction1[hash]:
|
249 |
observation['class_overriden'] = selected_class
|
250 |
|
251 |
st.session_state.public_observation = observation
|
|
|
264 |
for i in range(len(whale_classes)):
|
265 |
viewer.display_whale(whale_classes, i)
|
266 |
o += 1
|
267 |
+
col = (col + 1) % row_size
|
src/classifier_image.py
DELETED
@@ -1,70 +0,0 @@
|
|
1 |
-
import streamlit as st
|
2 |
-
import logging
|
3 |
-
import os
|
4 |
-
|
5 |
-
# get a global var for logger accessor in this module
|
6 |
-
LOG_LEVEL = logging.DEBUG
|
7 |
-
g_logger = logging.getLogger(__name__)
|
8 |
-
g_logger.setLevel(LOG_LEVEL)
|
9 |
-
|
10 |
-
from grid_maker import gridder
|
11 |
-
import hf_push_observations as sw_push_obs
|
12 |
-
import utils.metadata_handler as meta_handler
|
13 |
-
import whale_viewer as sw_wv
|
14 |
-
|
15 |
-
def cetacean_classify(cetacean_classifier, tab_inference):
|
16 |
-
files = st.session_state.files
|
17 |
-
images = st.session_state.images
|
18 |
-
observations = st.session_state.observations
|
19 |
-
|
20 |
-
batch_size, row_size, page = gridder(files)
|
21 |
-
|
22 |
-
grid = st.columns(row_size)
|
23 |
-
col = 0
|
24 |
-
|
25 |
-
for file in files:
|
26 |
-
image = images[file.name]
|
27 |
-
|
28 |
-
with grid[col]:
|
29 |
-
st.image(image, use_column_width=True)
|
30 |
-
observation = observations[file.name].to_dict()
|
31 |
-
# run classifier model on `image`, and persistently store the output
|
32 |
-
out = cetacean_classifier(image) # get top 3 matches
|
33 |
-
st.session_state.whale_prediction1 = out['predictions'][0]
|
34 |
-
st.session_state.classify_whale_done = True
|
35 |
-
msg = f"[D]2 classify_whale_done: {st.session_state.classify_whale_done}, whale_prediction1: {st.session_state.whale_prediction1}"
|
36 |
-
g_logger.info(msg)
|
37 |
-
|
38 |
-
# dropdown for selecting/overriding the species prediction
|
39 |
-
if not st.session_state.classify_whale_done:
|
40 |
-
selected_class = st.sidebar.selectbox("Species", sw_wv.WHALE_CLASSES,
|
41 |
-
index=None, placeholder="Species not yet identified...",
|
42 |
-
disabled=True)
|
43 |
-
else:
|
44 |
-
pred1 = st.session_state.whale_prediction1
|
45 |
-
# get index of pred1 from WHALE_CLASSES, none if not present
|
46 |
-
print(f"[D] pred1: {pred1}")
|
47 |
-
ix = sw_wv.WHALE_CLASSES.index(pred1) if pred1 in sw_wv.WHALE_CLASSES else None
|
48 |
-
selected_class = tab_inference.selectbox("Species", sw_wv.WHALE_CLASSES, index=ix)
|
49 |
-
|
50 |
-
observation['predicted_class'] = selected_class
|
51 |
-
if selected_class != st.session_state.whale_prediction1:
|
52 |
-
observation['class_overriden'] = selected_class
|
53 |
-
|
54 |
-
st.session_state.public_observation = observation
|
55 |
-
st.button(f"Upload observation for {file.name} to THE INTERNET!", on_click=sw_push_obs.push_observations)
|
56 |
-
# TODO: the metadata only fills properly if `validate` was clicked.
|
57 |
-
st.markdown(meta_handler.metadata2md())
|
58 |
-
|
59 |
-
msg = f"[D] full observation after inference: {observation}"
|
60 |
-
g_logger.debug(msg)
|
61 |
-
print(msg)
|
62 |
-
# TODO: add a link to more info on the model, next to the button.
|
63 |
-
|
64 |
-
whale_classes = out['predictions'][:]
|
65 |
-
# render images for the top 3 (that is what the model api returns)
|
66 |
-
#with tab_inference:
|
67 |
-
st.title(f"Species detected for {file.name}")
|
68 |
-
for i in range(len(whale_classes)):
|
69 |
-
sw_wv.display_whale(whale_classes, i)
|
70 |
-
col = (col + 1) % row_size
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/hf_push_observations.py
CHANGED
@@ -1,15 +1,82 @@
|
|
1 |
-
|
2 |
-
import streamlit as st
|
3 |
-
from huggingface_hub import HfApi
|
4 |
import json
|
5 |
import tempfile
|
6 |
import logging
|
7 |
|
|
|
|
|
|
|
|
|
|
|
8 |
# get a global var for logger accessor in this module
|
9 |
LOG_LEVEL = logging.DEBUG
|
10 |
g_logger = logging.getLogger(__name__)
|
11 |
g_logger.setLevel(LOG_LEVEL)
|
12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
def push_observations(tab_log:DeltaGenerator=None):
|
14 |
"""
|
15 |
Push the observations to the Hugging Face dataset
|
@@ -20,17 +87,16 @@ def push_observations(tab_log:DeltaGenerator=None):
|
|
20 |
push any observation since generating the logger)
|
21 |
|
22 |
"""
|
|
|
|
|
23 |
# we get the observation from session state: 1 is the dict 2 is the image.
|
24 |
# first, lets do an info display (popup)
|
25 |
metadata_str = json.dumps(st.session_state.public_observation)
|
26 |
|
27 |
st.toast(f"Uploading observations: {metadata_str}", icon="🦭")
|
28 |
-
|
29 |
-
if tab_log is not None:
|
30 |
-
tab_log.info(f"Uploading observations: {metadata_str}")
|
31 |
|
32 |
# get huggingface api
|
33 |
-
import os
|
34 |
token = os.environ.get("HF_TOKEN", None)
|
35 |
api = HfApi(token=token)
|
36 |
|
@@ -53,4 +119,4 @@ def push_observations(tab_log:DeltaGenerator=None):
|
|
53 |
# msg = f"observation attempted tx to repo happy walrus: {rv}"
|
54 |
g_logger.info(msg)
|
55 |
st.info(msg)
|
56 |
-
|
|
|
1 |
+
import os
|
|
|
|
|
2 |
import json
|
3 |
import tempfile
|
4 |
import logging
|
5 |
|
6 |
+
from streamlit.delta_generator import DeltaGenerator
|
7 |
+
import streamlit as st
|
8 |
+
from huggingface_hub import HfApi, CommitInfo
|
9 |
+
|
10 |
+
|
11 |
# get a global var for logger accessor in this module
|
12 |
LOG_LEVEL = logging.DEBUG
|
13 |
g_logger = logging.getLogger(__name__)
|
14 |
g_logger.setLevel(LOG_LEVEL)
|
15 |
|
16 |
+
def push_observation(image_hash:str, api:HfApi, enable_push:False) -> CommitInfo:
|
17 |
+
'''
|
18 |
+
push one observation to the Hugging Face dataset
|
19 |
+
|
20 |
+
'''
|
21 |
+
# get the observation
|
22 |
+
observation = st.session_state.public_observations.get(image_hash)
|
23 |
+
if observation is None:
|
24 |
+
msg = f"Could not find observation with hash {image_hash}"
|
25 |
+
g_logger.error(msg)
|
26 |
+
st.error(msg)
|
27 |
+
return None
|
28 |
+
|
29 |
+
# convert to json
|
30 |
+
metadata_str = json.dumps(observation) # doesn't work yet, TODO
|
31 |
+
|
32 |
+
st.toast(f"Uploading observation: {metadata_str}", icon="🦭")
|
33 |
+
g_logger.info(f"Uploading observation: {metadata_str}")
|
34 |
+
|
35 |
+
# write to temp file so we can send it (why is this not using context mgr?)
|
36 |
+
f = tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False)
|
37 |
+
f.write(metadata_str)
|
38 |
+
f.close()
|
39 |
+
#st.info(f"temp file: {f.name} with metadata written...")
|
40 |
+
|
41 |
+
path_in_repo = f"metadata/{observation['author_email']}/{observation['image_md5']}.json"
|
42 |
+
|
43 |
+
msg = f"fname: {f.name} | path: {path_in_repo}"
|
44 |
+
print(msg)
|
45 |
+
st.warning(msg)
|
46 |
+
|
47 |
+
if enable_push:
|
48 |
+
rv = api.upload_file(
|
49 |
+
path_or_fileobj=f.name,
|
50 |
+
path_in_repo=path_in_repo,
|
51 |
+
repo_id="Saving-Willy/temp_dataset",
|
52 |
+
repo_type="dataset",
|
53 |
+
)
|
54 |
+
print(rv)
|
55 |
+
msg = f"observation attempted tx to repo happy walrus: {rv}"
|
56 |
+
g_logger.info(msg)
|
57 |
+
st.info(msg)
|
58 |
+
else:
|
59 |
+
rv = None # temp don't send anything
|
60 |
+
|
61 |
+
return rv
|
62 |
+
|
63 |
+
|
64 |
+
|
65 |
+
def push_all_observations(enable_push:bool=False):
|
66 |
+
'''
|
67 |
+
open an API connection to Hugging Face, and push all observation one by one
|
68 |
+
'''
|
69 |
+
|
70 |
+
# get huggingface api
|
71 |
+
token = os.environ.get("HF_TOKEN", None)
|
72 |
+
api = HfApi(token=token)
|
73 |
+
|
74 |
+
# iterate over the list of observations
|
75 |
+
for hash in st.session_state.public_observations.keys():
|
76 |
+
rv = push_observation(hash, api, enable_push=enable_push)
|
77 |
+
|
78 |
+
|
79 |
+
|
80 |
def push_observations(tab_log:DeltaGenerator=None):
|
81 |
"""
|
82 |
Push the observations to the Hugging Face dataset
|
|
|
87 |
push any observation since generating the logger)
|
88 |
|
89 |
"""
|
90 |
+
raise DeprecationWarning("This function is deprecated. Use push_all_observations instead.")
|
91 |
+
|
92 |
# we get the observation from session state: 1 is the dict 2 is the image.
|
93 |
# first, lets do an info display (popup)
|
94 |
metadata_str = json.dumps(st.session_state.public_observation)
|
95 |
|
96 |
st.toast(f"Uploading observations: {metadata_str}", icon="🦭")
|
97 |
+
g_logger.info(f"Uploading observations: {metadata_str}")
|
|
|
|
|
98 |
|
99 |
# get huggingface api
|
|
|
100 |
token = os.environ.get("HF_TOKEN", None)
|
101 |
api = HfApi(token=token)
|
102 |
|
|
|
119 |
# msg = f"observation attempted tx to repo happy walrus: {rv}"
|
120 |
g_logger.info(msg)
|
121 |
st.info(msg)
|
122 |
+
|
src/input/input_handling.py
CHANGED
@@ -1,14 +1,17 @@
|
|
|
|
1 |
import datetime
|
2 |
import logging
|
|
|
3 |
|
4 |
import streamlit as st
|
5 |
from streamlit.delta_generator import DeltaGenerator
|
|
|
6 |
|
7 |
import cv2
|
8 |
import numpy as np
|
9 |
|
10 |
from input.input_observation import InputObservation
|
11 |
-
from input.input_validator import get_image_datetime, is_valid_email, is_valid_number
|
12 |
|
13 |
m_logger = logging.getLogger(__name__)
|
14 |
m_logger.setLevel(logging.INFO)
|
@@ -23,99 +26,394 @@ allowed_image_types = ['jpg', 'jpeg', 'png', 'webp']
|
|
23 |
# an arbitrary set of defaults so testing is less painful...
|
24 |
# ideally we add in some randomization to the defaults
|
25 |
spoof_metadata = {
|
26 |
-
"latitude":
|
27 |
"longitude": 44,
|
28 |
"author_email": "[email protected]",
|
29 |
"date": None,
|
30 |
"time": None,
|
31 |
}
|
32 |
|
33 |
-
def
|
34 |
-
viewcontainer: DeltaGenerator=None,
|
35 |
-
_allowed_image_types: list=None, ) -> InputObservation:
|
36 |
"""
|
37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
|
39 |
-
|
40 |
-
|
41 |
|
42 |
-
|
43 |
-
|
44 |
-
|
|
|
|
|
|
|
45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
Returns:
|
47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
|
|
|
|
|
|
|
49 |
"""
|
50 |
-
|
51 |
-
|
52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
|
54 |
-
|
55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
|
57 |
|
58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
|
|
|
|
|
|
|
|
|
|
64 |
|
65 |
-
#
|
66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
observations = {}
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
|
|
|
85 |
|
86 |
-
# 3. Latitude Entry Box
|
87 |
-
latitude = viewcontainer.text_input("Latitude for "+filename, spoof_metadata.get('latitude', ""))
|
88 |
-
if latitude and not is_valid_number(latitude):
|
89 |
-
viewcontainer.error("Please enter a valid latitude (numerical only).")
|
90 |
-
m_logger.error(f"Invalid latitude entered: {latitude}.")
|
91 |
-
# 4. Longitude Entry Box
|
92 |
-
longitude = viewcontainer.text_input("Longitude for "+filename, spoof_metadata.get('longitude', ""))
|
93 |
-
if longitude and not is_valid_number(longitude):
|
94 |
-
viewcontainer.error("Please enter a valid longitude (numerical only).")
|
95 |
-
m_logger.error(f"Invalid latitude entered: {latitude}.")
|
96 |
-
# 5. Date/time
|
97 |
-
## first from image metadata
|
98 |
-
if image_datetime is not None:
|
99 |
-
time_value = datetime.datetime.strptime(image_datetime, '%Y:%m:%d %H:%M:%S').time()
|
100 |
-
date_value = datetime.datetime.strptime(image_datetime, '%Y:%m:%d %H:%M:%S').date()
|
101 |
-
else:
|
102 |
-
time_value = datetime.datetime.now().time() # Default to current time
|
103 |
-
date_value = datetime.datetime.now().date()
|
104 |
|
105 |
-
|
106 |
-
|
107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
|
117 |
-
st.session_state.images = images
|
118 |
-
st.session_state.files = uploaded_files
|
119 |
-
st.session_state.observations = observations
|
120 |
-
st.session_state.image_hashes = image_hashes
|
121 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Tuple
|
2 |
import datetime
|
3 |
import logging
|
4 |
+
import hashlib
|
5 |
|
6 |
import streamlit as st
|
7 |
from streamlit.delta_generator import DeltaGenerator
|
8 |
+
from streamlit.runtime.uploaded_file_manager import UploadedFile
|
9 |
|
10 |
import cv2
|
11 |
import numpy as np
|
12 |
|
13 |
from input.input_observation import InputObservation
|
14 |
+
from input.input_validator import get_image_datetime, is_valid_email, is_valid_number, get_image_latlon
|
15 |
|
16 |
m_logger = logging.getLogger(__name__)
|
17 |
m_logger.setLevel(logging.INFO)
|
|
|
26 |
# an arbitrary set of defaults so testing is less painful...
|
27 |
# ideally we add in some randomization to the defaults
|
28 |
spoof_metadata = {
|
29 |
+
"latitude": 0.5,
|
30 |
"longitude": 44,
|
31 |
"author_email": "[email protected]",
|
32 |
"date": None,
|
33 |
"time": None,
|
34 |
}
|
35 |
|
36 |
+
def check_inputs_are_set(empty_ok: bool = False, debug: bool = False) -> bool:
    """
    Check whether all expected inputs have been entered.

    Implementation: via the Streamlit session state.

    Args:
        empty_ok (bool): If True, returns True if no inputs are set. Default is False.
        debug (bool): If True, prints and logs the status of each expected input key. Default is False.

    Returns:
        bool: True if all expected input keys are set, False otherwise.
    """
    image_hashes = st.session_state.image_hashes
    if len(image_hashes) == 0:
        return empty_ok

    # per-image keys; each stub gets suffixed with the image hash
    exp_input_key_stubs = ["input_latitude", "input_longitude", "input_date", "input_time"]
    #exp_input_key_stubs = ["input_latitude", "input_longitude", "input_author_email", "input_date", "input_time",

    vals = []
    # the author_email is global/one-off - no hash extension.
    if "input_author_email" in st.session_state:
        val = st.session_state["input_author_email"]
        vals.append(val)
        if debug:
            msg = f"{'input_author_email':15}, {(val is not None):8}, {val}"
            m_logger.debug(msg)
            print(msg)

    for image_hash in image_hashes:
        for stub in exp_input_key_stubs:
            key = f"{stub}_{image_hash}"
            # missing key -> None; dict-style get matches the original in/[] lookup
            val = st.session_state.get(key, None)

            # handle cases where it is defined but empty:
            # an empty string or empty list counts as unset
            # (not sure what UI elements would return a list?)
            if isinstance(val, (str, list)) and not val:
                val = None
            # number 0 is ok - possibly. could be on the equator, e.g.

            vals.append(val)
            if debug:
                msg = f"{key:15}, {(val is not None):8}, {val}"
                m_logger.debug(msg)
                print(msg)

    return all(v is not None for v in vals)
|
91 |
+
|
92 |
+
|
93 |
+
def buffer_uploaded_files():
    """
    Buffer uploaded files to session_state (images, image_hashes, filenames).

    Buffers uploaded files by extracting and storing filenames, images, and
    image hashes in the session state.

    Adds the following keys to `st.session_state`:
        - `images`: dict mapping image hashes to image data (numpy arrays)
        - `files`: list of uploaded files
        - `image_hashes`: list of image hashes
        - `image_filenames`: list of filenames
    """
    # buffer info from the file_uploader that doesn't require further user
    # input: the image, the hash, the filename. A separate function takes care
    # of per-file user inputs for metadata -- necessary because dynamically
    # producing more widgets should be avoided inside callbacks (tl;dr: they
    # dissapear).
    #
    # - UploadedFile objects have file_ids, unique to each file, but these are
    #   not persistent between sessions; they seem to be random identifiers.

    # get files from state
    uploaded_files = st.session_state.file_uploader_data

    filenames = []
    images = {}
    image_hashes = []

    for ix, file in enumerate(uploaded_files):
        filename: str = file.name
        # NOTE(review): "(unknown)" below looks like a redacted filename placeholder -- confirm
        print(f"[D] processing {ix}th file (unknown). {file.file_id} {file.type} {file.size}")
        # image-to-np and hashing both require reading the file, so do them together
        image, image_hash = load_file_and_hash(file)

        filenames.append(filename)
        image_hashes.append(image_hash)
        images[image_hash] = image

    st.session_state.images = images
    st.session_state.files = uploaded_files
    st.session_state.image_hashes = image_hashes
    st.session_state.image_filenames = filenames
|
140 |
+
|
141 |
+
|
142 |
+
def load_file_and_hash(file: UploadedFile) -> Tuple[np.ndarray, str]:
    """
    Load an image file and compute its MD5 hash.

    Since both operations require reading the full file contents, they are
    done together for efficiency.

    Args:
        file (UploadedFile): The uploaded file to be processed.

    Returns:
        Tuple[np.ndarray, str]: A tuple containing the decoded image as a
        NumPy array and the MD5 hash of the file's contents.
    """
    # read once, then derive both the hash and the decoded image from the bytes
    raw_bytes = file.read()
    file_hash = hashlib.md5(raw_bytes).hexdigest()
    # cv2 flag 1 == IMREAD_COLOR
    decoded: np.ndarray = cv2.imdecode(np.frombuffer(raw_bytes, dtype=np.uint8), 1)

    return (decoded, file_hash)
|
161 |
|
162 |
+
|
163 |
+
|
164 |
+
def metadata_inputs_one_file(file: UploadedFile, image_hash: str, dbg_ix: int = 0) -> InputObservation:
    """
    Create and parse metadata inputs for a single file.

    Args:
        file (UploadedFile): The uploaded file for which metadata is being handled.
        image_hash (str): The hash of the image.
        dbg_ix (int, optional): Debug index to differentiate data in each input group. Defaults to 0.

    Returns:
        InputObservation: An object containing the metadata and other information for the input file.
    """
    # dbg_ix is a hack to have different data in each input group, checking persistence

    if st.session_state.container_metadata_inputs is not None:
        _viewcontainer = st.session_state.container_metadata_inputs
    else:
        _viewcontainer = st.sidebar
        m_logger.warning(f"[W] `container_metadata_inputs` is None, using sidebar")

    author_email = st.session_state["input_author_email"]
    filename = file.name
    image_datetime_raw = get_image_datetime(file)
    latitude0, longitude0 = get_image_latlon(file)
    # NOTE(review): "(unknown)" below looks like a redacted filename placeholder -- confirm
    msg = f"[D] (unknown): lat, lon from image metadata: {latitude0}, {longitude0}"
    m_logger.debug(msg)

    if latitude0 is None:  # get some default values if not found in exifdata
        latitude0: float = spoof_metadata.get('latitude', 0) + dbg_ix
    if longitude0 is None:
        longitude0: float = spoof_metadata.get('longitude', 0) - dbg_ix

    image = st.session_state.images.get(image_hash, None)
    # add the UI elements
    #viewcontainer.title(f"Metadata for (unknown)")
    viewcontainer = _viewcontainer.expander(f"Metadata for {file.name}", expanded=True)

    # TODO: use session state so any changes are persisted within session -- currently I think
    # we are going to take the defaults over and over again -- if the user adjusts coords, or date, it will get lost
    # - it is a bit complicated, if no values change, they persist (the widget definition: params, name, key, etc)
    #   even if the code is re-run. but if the value changes, it is lost.

    # 3. Latitude Entry Box
    latitude = viewcontainer.text_input(
        "Latitude for " + filename,
        latitude0,
        key=f"input_latitude_{image_hash}")
    if latitude and not is_valid_number(latitude):
        viewcontainer.error("Please enter a valid latitude (numerical only).")
        m_logger.error(f"Invalid latitude entered: {latitude}.")
    # 4. Longitude Entry Box
    longitude = viewcontainer.text_input(
        "Longitude for " + filename,
        longitude0,
        key=f"input_longitude_{image_hash}")
    if longitude and not is_valid_number(longitude):
        viewcontainer.error("Please enter a valid longitude (numerical only).")
        # bugfix: this previously logged "Invalid latitude entered: {latitude}"
        # (copy-paste error) -- now reports the offending longitude value.
        m_logger.error(f"Invalid longitude entered: {longitude}.")

    # 5. Date/time
    ## first from image metadata
    if image_datetime_raw is not None:
        # EXIF DateTimeOriginal format is 'YYYY:MM:DD HH:MM:SS'
        time_value = datetime.datetime.strptime(image_datetime_raw, '%Y:%m:%d %H:%M:%S').time()
        date_value = datetime.datetime.strptime(image_datetime_raw, '%Y:%m:%d %H:%M:%S').date()
    else:
        time_value = datetime.datetime.now().time()  # Default to current time
        date_value = datetime.datetime.now().date()

    ## either way, give user the option to enter manually (or correct, e.g. if camera has no rtc clock)
    date = viewcontainer.date_input("Date for " + filename, value=date_value, key=f"input_date_{image_hash}")
    time = viewcontainer.time_input("Time for " + filename, time_value, key=f"input_time_{image_hash}")

    observation = InputObservation(image=image, latitude=latitude, longitude=longitude,
                                   author_email=author_email, image_datetime_raw=image_datetime_raw,
                                   date=date, time=time,
                                   uploaded_file=file, image_md5=image_hash
                                   )

    return observation
|
245 |
+
|
246 |
|
247 |
+
|
248 |
+
def _setup_dynamic_inputs() -> None:
    """
    Setup metadata inputs dynamically for each uploaded file, and process.

    This operates on the data buffered in the session state, and writes
    the observation objects back to the session state.
    """
    # for each file uploaded:
    # - add the UI elements for the metadata
    # - validate the data
    # end of cycle should have observation objects set for each file,
    # and these go into session state.

    # load the files from the session state
    uploaded_files = st.session_state.files
    hashes = st.session_state.image_hashes
    #images = st.session_state.images
    observations = {}

    for ix, file in enumerate(uploaded_files):
        # (renamed from `hash` to avoid shadowing the builtin)
        image_hash = hashes[ix]
        observation = metadata_inputs_one_file(file, image_hash, ix)
        old_obs = st.session_state.observations.get(image_hash, None)

        if old_obs is None:
            m_logger.debug(f"[D] {ix}th observation is new (image_hash not seen before). Storing")
            observations[image_hash] = observation
        elif old_obs == observation:
            # unchanged: keep the prior object so downstream state survives
            m_logger.debug(f"[D] {ix}th observation is the same as before. retaining")
            observations[image_hash] = old_obs
        else:
            m_logger.debug(f"[D] {ix}th observation is different from before. updating")
            observations[image_hash] = observation
            observation.show_diff(old_obs)

    st.session_state.observations = observations
|
286 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
287 |
|
288 |
+
def _setup_oneoff_inputs() -> None:
    '''
    Add the UI input elements for which we have one covering all files

    - author email
    - file uploader (accepts multiple files)
    '''
    # fetch the container for the file uploader input elements
    uploader_container = st.session_state.container_file_uploader

    with uploader_container:
        # 1. Input the author email
        author_email = st.text_input("Author Email", spoof_metadata.get('author_email', ""),
                                     key="input_author_email")
        if author_email and not is_valid_email(author_email):
            st.error("Please enter a valid email address.")

        # 2. Image Selector (callback buffers the files into session state)
        st.file_uploader(
            "Upload one or more images", type=["png", 'jpg', 'jpeg', 'webp'],
            accept_multiple_files=True,
            key="file_uploader_data", on_change=buffer_uploaded_files)
|
311 |
+
|
312 |
+
|
313 |
+
|
314 |
+
|
315 |
|
|
|
|
|
|
|
|
|
316 |
|
317 |
+
|
318 |
+
def setup_input() -> None:
    '''
    Set up the user input handling (files and metadata)

    It provides input fields for an image upload, and author email.
    Then for each uploaded image,
    - it provides input fields for lat/lon, date-time.
    - In the ideal case, the image metadata will be used to populate location and datetime.

    Data is stored in the Streamlit session state for downstream processing,
    nothing is returned
    '''
    # configure the author email and file_uploader (with callback to buffer files)
    _setup_oneoff_inputs()

    # setup dynamic UI input elements, based on the data that is buffered in session_state
    _setup_dynamic_inputs()
|
336 |
+
|
337 |
+
|
338 |
+
def init_input_container_states() -> None:
    '''
    Initialise the layout containers used in the input handling
    '''
    #if "container_per_file_input_elems" not in st.session_state:
    #    st.session_state.container_per_file_input_elems = None

    # both containers start out unset; they are created in add_input_UI_elements
    for container_key in ("container_file_uploader", "container_metadata_inputs"):
        if container_key not in st.session_state:
            st.session_state[container_key] = None
|
350 |
+
|
351 |
+
def init_input_data_session_states() -> None:
    '''
    Initialise the session state variables used in the input handling
    '''
    # TODO: ideally just use image_hashes, but need a unique key for the ui elements
    # to track the user input phase; and these are created before the hash is generated.
    # NOTE(review): "files" is initialised as a dict here but buffer_uploaded_files
    # later stores the uploader's list in it -- confirm intended type.
    defaults = {
        "image_hashes": [],
        "image_filenames": [],
        "observations": {},
        "images": {},
        "files": {},
        "public_observations": {},
    }
    # only set keys that are not already present (fresh literals each call,
    # so no mutable default is shared between sessions)
    for key, default in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = default
|
375 |
+
|
376 |
+
|
377 |
+
|
378 |
+
def add_input_UI_elements() -> None:
    '''
    Create the containers within which user input elements will be placed
    '''
    # we make containers ahead of time, allowing consistent order of elements
    # which are not created in the same order.

    st.divider()
    st.title("Input image and data")

    # create and style a container for the file uploader/other one-off inputs
    st.markdown('<style>.st-key-container_file_uploader_id { border: 1px solid skyblue; border-radius: 5px; }</style>', unsafe_allow_html=True)
    st.session_state.container_file_uploader = st.container(border=True, key="container_file_uploader_id")

    # create and style a container for the dynamic metadata inputs
    st.markdown('<style>.st-key-container_metadata_inputs_id { border: 1px solid lightgreen; border-radius: 5px; }</style>', unsafe_allow_html=True)
    meta_container = st.container(border=True, key="container_metadata_inputs_id")
    meta_container.write("Metadata Inputs... wait for file upload ")
    st.session_state.container_metadata_inputs = meta_container
|
398 |
+
|
399 |
+
|
400 |
+
def dbg_show_observation_hashes() -> None:
    """
    Displays information about each observation including the hash

    - debug usage, keeping track of the hashes and persistence of the InputObservations.
    - it renders text to the current container, not intended for final app.
    """
    # a debug: we seem to be losing the whale classes?
    st.write(f"[D] num observations: {len(st.session_state.observations)}")
    bullet_lines = (
        f"- [D] observation {obs_hash} ({obs._inst_id}) has {len(obs.top_predictions)} predictions\n"
        for obs_hash, obs in st.session_state.observations.items()
    )
    #s += f"  - {repr(obs)}\n" # check the str / repr method
    #print(obs)
    st.markdown("".join(bullet_lines))
|
src/input/input_observation.py
CHANGED
@@ -1,13 +1,18 @@
|
|
1 |
import hashlib
|
2 |
from input.input_validator import generate_random_md5
|
3 |
|
|
|
|
|
|
|
|
|
|
|
4 |
# autogenerated class to hold the input data
|
5 |
class InputObservation:
|
6 |
"""
|
7 |
A class to hold an input observation and associated metadata
|
8 |
|
9 |
Attributes:
|
10 |
-
image (
|
11 |
The image associated with the observation.
|
12 |
latitude (float):
|
13 |
The latitude where the observation was made.
|
@@ -15,16 +20,16 @@ class InputObservation:
|
|
15 |
The longitude where the observation was made.
|
16 |
author_email (str):
|
17 |
The email of the author of the observation.
|
18 |
-
|
19 |
-
The
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
The
|
28 |
|
29 |
Methods:
|
30 |
__str__():
|
@@ -35,8 +40,8 @@ class InputObservation:
|
|
35 |
Checks if two observations are equal.
|
36 |
__ne__(other):
|
37 |
Checks if two observations are not equal.
|
38 |
-
|
39 |
-
|
40 |
to_dict():
|
41 |
Converts the observation to a dictionary.
|
42 |
from_dict(data):
|
@@ -44,66 +49,208 @@ class InputObservation:
|
|
44 |
from_input(input):
|
45 |
Creates an observation from another input observation.
|
46 |
"""
|
47 |
-
|
48 |
-
|
49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
self.image = image
|
51 |
self.latitude = latitude
|
52 |
self.longitude = longitude
|
53 |
self.author_email = author_email
|
|
|
54 |
self.date = date
|
55 |
self.time = time
|
56 |
-
self.
|
57 |
-
self.
|
58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
|
60 |
def __str__(self):
|
61 |
-
|
|
|
|
|
|
|
|
|
|
|
62 |
|
63 |
def __repr__(self):
|
64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
|
66 |
def __eq__(self, other):
|
67 |
-
|
68 |
-
|
69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
|
71 |
def __ne__(self, other):
|
72 |
return not self.__eq__(other)
|
73 |
|
74 |
-
def __hash__(self):
|
75 |
-
return hash((self.image, self.latitude, self.longitude, self.author_email, self.date, self.time, self.date_option, self.time_option, self.uploaded_filename))
|
76 |
-
|
77 |
def to_dict(self):
|
78 |
return {
|
79 |
#"image": self.image,
|
80 |
-
"image_filename": self.
|
81 |
-
"image_md5":
|
|
|
82 |
"latitude": self.latitude,
|
83 |
"longitude": self.longitude,
|
84 |
"author_email": self.author_email,
|
85 |
-
"
|
86 |
-
"
|
87 |
-
"
|
88 |
-
"
|
89 |
-
"
|
|
|
|
|
|
|
90 |
}
|
91 |
|
92 |
@classmethod
|
93 |
def from_dict(cls, data):
|
94 |
-
return cls(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
|
96 |
@classmethod
|
97 |
def from_input(cls, input):
|
98 |
-
return cls(
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
|
104 |
-
@staticmethod
|
105 |
-
def from_dict(data):
|
106 |
-
return InputObservation(data["image"], data["latitude"], data["longitude"], data["author_email"], data["date"], data["time"], data["date_option"], data["time_option"], data["uploaded_filename"])
|
107 |
|
108 |
|
109 |
|
|
|
1 |
import hashlib
|
2 |
from input.input_validator import generate_random_md5
|
3 |
|
4 |
+
from numpy import ndarray
|
5 |
+
from streamlit.runtime.uploaded_file_manager import UploadedFile
|
6 |
+
import datetime
|
7 |
+
|
8 |
+
|
9 |
# autogenerated class to hold the input data
|
10 |
# autogenerated class to hold the input data
class InputObservation:
    """
    A class to hold an input observation and associated metadata

    Attributes:
        image (ndarray):
            The image associated with the observation.
        latitude (float):
            The latitude where the observation was made.
        longitude (float):
            The longitude where the observation was made.
        author_email (str):
            The email of the author of the observation.
        image_datetime_raw (str):
            The datetime extracted from the observation file
        date (datetime.date):
            Date of the observation
        time (datetime.time):
            Time of the observation
        uploaded_file (UploadedFile):
            The uploaded file associated with the observation.
        image_md5 (str):
            The MD5 hash of the image associated with the observation.

    Methods:
        __str__():
            Returns a string representation of the observation.
        __repr__():
            Returns a detailed string representation of the observation.
        __eq__(other):
            Checks if two observations are equal.
        __ne__(other):
            Checks if two observations are not equal.
        show_diff(other):
            Shows the differences between two observations.
        to_dict():
            Converts the observation to a dictionary.
        from_dict(data):
            Creates an observation from a dictionary.
        from_input(input):
            Creates an observation from another input observation.
    """

    # class-level counter; each instance gets a unique debug id (_inst_id)
    _inst_count = 0

    def __init__(
        self, image:ndarray=None, latitude:float=None, longitude:float=None,
        author_email:str=None, image_datetime_raw:str=None,
        date:datetime.date=None,
        time:datetime.time=None,
        uploaded_file:UploadedFile=None, image_md5:str=None):

        self.image = image
        self.latitude = latitude
        self.longitude = longitude
        self.author_email = author_email
        self.image_datetime_raw = image_datetime_raw
        self.date = date
        self.time = time
        self.uploaded_file = uploaded_file
        self.image_md5 = image_md5
        # attributes that get set after predictions/processing
        self._top_predictions = []
        self._selected_class = None
        self._class_overriden = False

        InputObservation._inst_count += 1
        self._inst_id = InputObservation._inst_count

        #dbg - temporarily give up if hash is not provided
        if self.image_md5 is None:
            raise ValueError(f"Image MD5 hash is required - {self._inst_id:3}.")

    def set_top_predictions(self, top_predictions:list):
        """Store the model's top predictions; the first entry becomes the selected class."""
        self._top_predictions = top_predictions
        if len(top_predictions) > 0:
            self.set_selected_class(top_predictions[0])

    def set_selected_class(self, selected_class:str):
        """Set the selected class; flags an override if it differs from the top prediction."""
        self._selected_class = selected_class
        # bugfix: guard against empty _top_predictions (previously raised
        # IndexError if called before set_top_predictions)
        if self._top_predictions and selected_class != self._top_predictions[0]:
            self.set_class_overriden(True)

    def set_class_overriden(self, class_overriden:bool):
        self._class_overriden = class_overriden

    # read-only accessors for the post-processing attributes
    @property
    def top_predictions(self):
        return self._top_predictions

    @property
    def selected_class(self):
        return self._selected_class

    @property
    def class_overriden(self):
        return self._class_overriden

    # add a method to assign the image_md5 only once
    def assign_image_md5(self):
        # deprecated: the hash is a required constructor argument now
        raise DeprecationWarning("This method is deprecated. hash is a required constructor argument.")
        # NOTE(review): unreachable; also references m_logger which is not
        # imported in this module
        if not self.image_md5:
            self.image_md5 = hashlib.md5(self.uploaded_file.read()).hexdigest() if self.uploaded_file else generate_random_md5()
            m_logger.debug(f"[D] Assigned image md5: {self.image_md5} for {self.uploaded_file}")

    def __str__(self):
        _im_str = "None" if self.image is None else f"image dims: {self.image.shape}"
        return (
            f"Observation: {_im_str}, {self.latitude}, {self.longitude}, "
            f"{self.author_email}, {self.image_datetime_raw}, {self.date}, "
            f"{self.time}, {self.uploaded_file}, {self.image_md5}"
        )

    def __repr__(self):
        _im_str = "None" if self.image is None else f"image dims: {self.image.shape}"
        return (
            f"Observation: "
            f"Image: {_im_str}, "
            f"Latitude: {self.latitude}, "
            f"Longitude: {self.longitude}, "
            f"Author Email: {self.author_email}, "
            f"raw timestamp: {self.image_datetime_raw}, "
            f"Date: {self.date}, "
            f"Time: {self.time}, "
            f"Uploaded Filename: {self.uploaded_file}"
            f"Image MD5 hash: {self.image_md5}"
        )

    def __eq__(self, other):
        # TODO: ensure this covers all the attributes (some have been added?)
        # - except inst_id which is unique
        if self.image is None or other.image is None:
            # bugfix: comparing `ndarray == None` elementwise yields an array,
            # which raised "truth value is ambiguous" in the `and` chain below.
            # equal iff both are None; one None and one array is unequal.
            _image_equality = self.image is None and other.image is None
        else:  # maybe strong assumption: both are correctly ndarray.. should I test types intead?
            _image_equality = (self.image == other.image).all()
        equality = (
            _image_equality and
            self.latitude == other.latitude and
            self.longitude == other.longitude and
            self.author_email == other.author_email and
            self.image_datetime_raw == other.image_datetime_raw and
            self.date == other.date and
            # temporarily skip time, it is followed by the clock and that is always differnt
            #self.time == other.time and
            self.uploaded_file == other.uploaded_file and
            self.image_md5 == other.image_md5
        )
        return equality

    def show_diff(self, other):
        """Show the differences between two observations (prints only the differing fields)."""
        differences = []
        if self.image is None or other.image is None:
            # bugfix: `!=` on ndarray-vs-None is elementwise; compare presence instead
            if (self.image is None) != (other.image is None):
                differences.append(f"  Image is different. (types mismatch: {type(self.image)} vs {type(other.image)})")
        else:
            if (self.image != other.image).any():
                cnt = (self.image != other.image).sum()
                differences.append(f"  Image is different: {cnt} different pixels.")
        if self.latitude != other.latitude:
            differences.append(f"  Latitude is different. (self: {self.latitude}, other: {other.latitude})")
        if self.longitude != other.longitude:
            differences.append(f"  Longitude is different. (self: {self.longitude}, other: {other.longitude})")
        if self.author_email != other.author_email:
            differences.append(f"  Author email is different. (self: {self.author_email}, other: {other.author_email})")
        if self.image_datetime_raw != other.image_datetime_raw:
            differences.append(f"  Date is different. (self: {self.image_datetime_raw}, other: {other.image_datetime_raw})")
        if self.date != other.date:
            differences.append(f"  Date is different. (self: {self.date}, other: {other.date})")
        if self.time != other.time:
            differences.append(f"  Time is different. (self: {self.time}, other: {other.time})")
        if self.uploaded_file != other.uploaded_file:
            differences.append("  Uploaded filename is different.")
        if self.image_md5 != other.image_md5:
            differences.append("  Image MD5 hash is different.")

        if differences:
            print(f"Observations have {len(differences)} differences:")
            for diff in differences:
                print(diff)
        else:
            print("Observations are the same.")

    def __ne__(self, other):
        return not self.__eq__(other)

    def to_dict(self):
        """Serialise the observation (minus non-serialisable members) to a dict."""
        return {
            #"image": self.image,
            "image_filename": self.uploaded_file.name if self.uploaded_file else None,
            "image_md5": self.image_md5,
            #"image_md5": hashlib.md5(self.uploaded_file.read()).hexdigest() if self.uploaded_file else generate_random_md5(),
            "latitude": self.latitude,
            "longitude": self.longitude,
            "author_email": self.author_email,
            "image_datetime_raw": self.image_datetime_raw,
            "date": str(self.date),
            "time": str(self.time),
            "selected_class": self._selected_class,
            "top_prediction": self._top_predictions[0] if len(self._top_predictions) else None,
            "class_overriden": self._class_overriden,

            #"uploaded_file": self.uploaded_file # can't serialize this in json, not sent to dataset anyway.
        }

    @classmethod
    def from_dict(cls, data):
        """Alternate constructor from a dict such as produced by to_dict()."""
        # bugfix: the constructor keyword is `image_md5`, not `image_hash`
        # (the original raised TypeError)
        return cls(
            image=data.get("image"),
            latitude=data.get("latitude"),
            longitude=data.get("longitude"),
            author_email=data.get("author_email"),
            image_datetime_raw=data.get("image_datetime_raw"),
            date=data.get("date"),
            time=data.get("time"),
            uploaded_file=data.get("uploaded_file"),
            image_md5=data.get("image_md5")
        )

    @classmethod
    def from_input(cls, input):
        """Alternate constructor: copy another InputObservation's fields."""
        # bugfix: source attribute is `image_md5` (no `image_hash` attribute
        # exists) and the constructor keyword is `image_md5`
        return cls(
            image=input.image,
            latitude=input.latitude,
            longitude=input.longitude,
            author_email=input.author_email,
            image_datetime_raw=input.image_datetime_raw,
            date=input.date,
            time=input.time,
            uploaded_file=input.uploaded_file,
            image_md5=input.image_md5
        )
|
253 |
|
|
|
|
|
|
|
254 |
|
255 |
|
256 |
|
src/input/input_validator.py
CHANGED
@@ -1,22 +1,33 @@
|
|
|
|
1 |
import random
|
2 |
import string
|
3 |
import hashlib
|
4 |
import re
|
5 |
-
import streamlit as st
|
6 |
from fractions import Fraction
|
7 |
-
|
8 |
from PIL import Image
|
9 |
from PIL import ExifTags
|
10 |
|
|
|
11 |
from streamlit.runtime.uploaded_file_manager import UploadedFile
|
12 |
|
13 |
-
def generate_random_md5():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
# Generate a random string
|
15 |
-
random_string = ''.join(random.choices(string.ascii_letters + string.digits,
|
16 |
# Encode the string and compute its MD5 hash
|
17 |
md5_hash = hashlib.md5(random_string.encode()).hexdigest()
|
18 |
return md5_hash
|
19 |
|
|
|
20 |
def is_valid_number(number:str) -> bool:
|
21 |
"""
|
22 |
Check if the given string is a valid number (int or float, sign ok)
|
@@ -30,6 +41,7 @@ def is_valid_number(number:str) -> bool:
|
|
30 |
pattern = r'^[-+]?[0-9]*\.?[0-9]+$'
|
31 |
return re.match(pattern, number) is not None
|
32 |
|
|
|
33 |
# Function to validate email address
|
34 |
def is_valid_email(email:str) -> bool:
|
35 |
"""
|
@@ -41,11 +53,14 @@ def is_valid_email(email:str) -> bool:
|
|
41 |
Returns:
|
42 |
bool: True if the email address is valid, False otherwise.
|
43 |
"""
|
44 |
-
pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
|
|
|
|
|
45 |
return re.match(pattern, email) is not None
|
46 |
|
|
|
47 |
# Function to extract date and time from image metadata
|
48 |
-
def get_image_datetime(image_file):
|
49 |
"""
|
50 |
Extracts the original date and time from the EXIF metadata of an uploaded image file.
|
51 |
|
@@ -69,6 +84,7 @@ def get_image_datetime(image_file):
|
|
69 |
# TODO: add to logger
|
70 |
return None
|
71 |
|
|
|
72 |
def decimal_coords(coords:tuple, ref:str) -> Fraction:
|
73 |
"""
|
74 |
Converts coordinates from degrees, minutes, and seconds to decimal degrees.
|
@@ -96,8 +112,9 @@ def decimal_coords(coords:tuple, ref:str) -> Fraction:
|
|
96 |
return decimal_degrees
|
97 |
|
98 |
|
99 |
-
#def get_image_latlon(image_file: UploadedFile)
|
100 |
-
def get_image_latlon(image_file: UploadedFile) :
|
|
|
101 |
"""
|
102 |
Extracts the latitude and longitude from the EXIF metadata of an uploaded image file.
|
103 |
|
@@ -123,4 +140,6 @@ def get_image_latlon(image_file: UploadedFile) :
|
|
123 |
return lat, lon
|
124 |
|
125 |
except Exception as e: # FIXME: what types of exception?
|
126 |
-
st.warning(f"Could not extract latitude and longitude from image metadata. (file: {str(image_file)}")
|
|
|
|
|
|
1 |
+
from typing import Tuple, Union
|
2 |
import random
|
3 |
import string
|
4 |
import hashlib
|
5 |
import re
|
|
|
6 |
from fractions import Fraction
|
|
|
7 |
from PIL import Image
|
8 |
from PIL import ExifTags
|
9 |
|
10 |
+
import streamlit as st
|
11 |
from streamlit.runtime.uploaded_file_manager import UploadedFile
|
12 |
|
13 |
+
def generate_random_md5(length:int=16) -> str:
    """
    Generate the MD5 hash of a random alphanumeric string.

    Args:
        length (int): The length of the random string to generate. Default is 16.

    Returns:
        str: The MD5 hash (32-char lowercase hex) of the generated random string.
    """
    # Generate a random string of `length` alphanumeric characters.
    # Bug fix: random.choices takes `k=`, not `length=`; the previous call
    # `random.choices(..., length=16)` raised TypeError and also ignored the
    # function's `length` parameter.
    random_string = ''.join(random.choices(string.ascii_letters + string.digits, k=length))
    # Encode the string and compute its MD5 hash
    md5_hash = hashlib.md5(random_string.encode()).hexdigest()
    return md5_hash
|
29 |
|
30 |
+
|
31 |
def is_valid_number(number:str) -> bool:
|
32 |
"""
|
33 |
Check if the given string is a valid number (int or float, sign ok)
|
|
|
41 |
pattern = r'^[-+]?[0-9]*\.?[0-9]+$'
|
42 |
return re.match(pattern, number) is not None
|
43 |
|
44 |
+
|
45 |
# Function to validate email address
|
46 |
def is_valid_email(email:str) -> bool:
|
47 |
"""
|
|
|
53 |
Returns:
|
54 |
bool: True if the email address is valid, False otherwise.
|
55 |
"""
|
56 |
+
#pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
|
57 |
+
# do not allow starting with a +
|
58 |
+
pattern = r'^[a-zA-Z0-9_]+[a-zA-Z0-9._%+-]*@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
|
59 |
return re.match(pattern, email) is not None
|
60 |
|
61 |
+
|
62 |
# Function to extract date and time from image metadata
|
63 |
+
def get_image_datetime(image_file:UploadedFile) -> Union[str, None]:
|
64 |
"""
|
65 |
Extracts the original date and time from the EXIF metadata of an uploaded image file.
|
66 |
|
|
|
84 |
# TODO: add to logger
|
85 |
return None
|
86 |
|
87 |
+
|
88 |
def decimal_coords(coords:tuple, ref:str) -> Fraction:
|
89 |
"""
|
90 |
Converts coordinates from degrees, minutes, and seconds to decimal degrees.
|
|
|
112 |
return decimal_degrees
|
113 |
|
114 |
|
115 |
+
#def get_image_latlon(image_file: UploadedFile) : # if it is still not working
|
116 |
+
#def get_image_latlon(image_file: UploadedFile) -> Tuple[float, float] | None: # Python >=3.10
|
117 |
+
def get_image_latlon(image_file: UploadedFile) -> Union[Tuple[float, float], None]: # 3.6 <= Python < 3.10
|
118 |
"""
|
119 |
Extracts the latitude and longitude from the EXIF metadata of an uploaded image file.
|
120 |
|
|
|
140 |
return lat, lon
|
141 |
|
142 |
except Exception as e: # FIXME: what types of exception?
|
143 |
+
st.warning(f"Could not extract latitude and longitude from image metadata. (file: {str(image_file)}")
|
144 |
+
|
145 |
+
return None, None
|
src/main.py
CHANGED
@@ -9,17 +9,24 @@ from streamlit_folium import st_folium
|
|
9 |
from transformers import pipeline
|
10 |
from transformers import AutoModelForImageClassification
|
11 |
|
12 |
-
from maps.obs_map import
|
|
|
13 |
from datasets import disable_caching
|
14 |
disable_caching()
|
15 |
|
16 |
import whale_gallery as gallery
|
17 |
import whale_viewer as viewer
|
18 |
-
from input.input_handling import setup_input
|
|
|
|
|
|
|
19 |
from maps.alps_map import present_alps_map
|
20 |
from maps.obs_map import present_obs_map
|
21 |
-
from utils.st_logs import
|
22 |
-
from
|
|
|
|
|
|
|
23 |
from classifier.classifier_hotdog import hotdog_classify
|
24 |
|
25 |
|
@@ -34,6 +41,11 @@ data_files = "data/train-00000-of-00001.parquet"
|
|
34 |
USE_BASIC_MAP = False
|
35 |
DEV_SIDEBAR_LIB = True
|
36 |
|
|
|
|
|
|
|
|
|
|
|
37 |
# get a global var for logger accessor in this module
|
38 |
LOG_LEVEL = logging.DEBUG
|
39 |
g_logger = logging.getLogger(__name__)
|
@@ -42,33 +54,13 @@ g_logger.setLevel(LOG_LEVEL)
|
|
42 |
st.set_page_config(layout="wide")
|
43 |
|
44 |
# initialise various session state variables
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
if "observations" not in st.session_state:
|
52 |
-
st.session_state.observations = {}
|
53 |
-
|
54 |
-
if "images" not in st.session_state:
|
55 |
-
st.session_state.images = {}
|
56 |
-
|
57 |
-
if "files" not in st.session_state:
|
58 |
-
st.session_state.files = {}
|
59 |
|
60 |
-
if "public_observation" not in st.session_state:
|
61 |
-
st.session_state.public_observation = {}
|
62 |
-
|
63 |
-
if "classify_whale_done" not in st.session_state:
|
64 |
-
st.session_state.classify_whale_done = False
|
65 |
-
|
66 |
-
if "whale_prediction1" not in st.session_state:
|
67 |
-
st.session_state.whale_prediction1 = None
|
68 |
-
|
69 |
-
if "tab_log" not in st.session_state:
|
70 |
-
st.session_state.tab_log = None
|
71 |
-
|
72 |
|
73 |
def main() -> None:
|
74 |
"""
|
@@ -100,29 +92,22 @@ def main() -> None:
|
|
100 |
# Streamlit app
|
101 |
tab_inference, tab_hotdogs, tab_map, tab_coords, tab_log, tab_gallery = \
|
102 |
st.tabs(["Cetecean classifier", "Hotdog classifier", "Map", "*:gray[Dev:coordinates]*", "Log", "Beautiful cetaceans"])
|
103 |
-
st.session_state.tab_log = tab_log
|
104 |
|
|
|
|
|
105 |
|
106 |
# create a sidebar, and parse all the input (returned as `observations` object)
|
107 |
-
|
|
|
|
|
|
|
|
|
108 |
|
109 |
|
110 |
-
if 0:## WIP
|
111 |
-
# goal of this code is to allow the user to override the ML prediction, before transmitting an observations
|
112 |
-
predicted_class = st.sidebar.selectbox("Predicted Class", viewer.WHALE_CLASSES)
|
113 |
-
override_prediction = st.sidebar.checkbox("Override Prediction")
|
114 |
-
|
115 |
-
if override_prediction:
|
116 |
-
overridden_class = st.sidebar.selectbox("Override Class", viewer.WHALE_CLASSES)
|
117 |
-
st.session_state.observations['class_overriden'] = overridden_class
|
118 |
-
else:
|
119 |
-
st.session_state.observations['class_overriden'] = None
|
120 |
-
|
121 |
-
|
122 |
with tab_map:
|
123 |
# visual structure: a couple of toggles at the top, then the map inlcuding a
|
124 |
# dropdown for tileset selection.
|
125 |
-
|
126 |
tab_map_ui_cols = st.columns(2)
|
127 |
with tab_map_ui_cols[0]:
|
128 |
show_db_points = st.toggle("Show Points from DB", True)
|
@@ -180,43 +165,128 @@ def main() -> None:
|
|
180 |
gallery.render_whale_gallery(n_cols=4)
|
181 |
|
182 |
|
183 |
-
#
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
192 |
|
193 |
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
You can override the prediction by selecting a species from the dropdown.*""")
|
204 |
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
211 |
|
212 |
-
|
213 |
-
#
|
214 |
-
st.
|
215 |
-
|
216 |
-
|
217 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
218 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
219 |
|
|
|
|
|
|
|
220 |
|
221 |
# inside the hotdog tab, on button press we call a 2nd model (totally unrelated at present, just for demo
|
222 |
# purposes, an hotdog image classifier) which will be run locally.
|
@@ -240,6 +310,9 @@ def main() -> None:
|
|
240 |
hotdog_classify(pipeline_hot_dog, tab_hotdogs)
|
241 |
|
242 |
|
|
|
|
|
|
|
243 |
|
244 |
if __name__ == "__main__":
|
245 |
main()
|
|
|
9 |
from transformers import pipeline
|
10 |
from transformers import AutoModelForImageClassification
|
11 |
|
12 |
+
from maps.obs_map import add_obs_map_header
|
13 |
+
from classifier.classifier_image import add_classifier_header
|
14 |
from datasets import disable_caching
|
15 |
disable_caching()
|
16 |
|
17 |
import whale_gallery as gallery
|
18 |
import whale_viewer as viewer
|
19 |
+
from input.input_handling import setup_input, check_inputs_are_set
|
20 |
+
from input.input_handling import init_input_container_states, add_input_UI_elements, init_input_data_session_states
|
21 |
+
from input.input_handling import dbg_show_observation_hashes
|
22 |
+
|
23 |
from maps.alps_map import present_alps_map
|
24 |
from maps.obs_map import present_obs_map
|
25 |
+
from utils.st_logs import parse_log_buffer, init_logging_session_states
|
26 |
+
from utils.workflow_ui import refresh_progress_display, init_workflow_viz, init_workflow_session_states
|
27 |
+
from hf_push_observations import push_all_observations
|
28 |
+
|
29 |
+
from classifier.classifier_image import cetacean_just_classify, cetacean_show_results_and_review, cetacean_show_results, init_classifier_session_states
|
30 |
from classifier.classifier_hotdog import hotdog_classify
|
31 |
|
32 |
|
|
|
41 |
USE_BASIC_MAP = False
|
42 |
DEV_SIDEBAR_LIB = True
|
43 |
|
44 |
+
# one toggle for all the extra debug text
|
45 |
+
if "MODE_DEV_STATEFUL" not in st.session_state:
|
46 |
+
st.session_state.MODE_DEV_STATEFUL = False
|
47 |
+
|
48 |
+
|
49 |
# get a global var for logger accessor in this module
|
50 |
LOG_LEVEL = logging.DEBUG
|
51 |
g_logger = logging.getLogger(__name__)
|
|
|
54 |
st.set_page_config(layout="wide")
|
55 |
|
56 |
# initialise various session state variables
|
57 |
+
init_logging_session_states() # logging init should be early
|
58 |
+
init_workflow_session_states()
|
59 |
+
init_input_data_session_states()
|
60 |
+
init_input_container_states()
|
61 |
+
init_workflow_viz()
|
62 |
+
init_classifier_session_states()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
|
65 |
def main() -> None:
|
66 |
"""
|
|
|
92 |
# Streamlit app
|
93 |
tab_inference, tab_hotdogs, tab_map, tab_coords, tab_log, tab_gallery = \
|
94 |
st.tabs(["Cetecean classifier", "Hotdog classifier", "Map", "*:gray[Dev:coordinates]*", "Log", "Beautiful cetaceans"])
|
|
|
95 |
|
96 |
+
# put this early so the progress indicator is at the top (also refreshed at end)
|
97 |
+
refresh_progress_display()
|
98 |
|
99 |
# create a sidebar, and parse all the input (returned as `observations` object)
|
100 |
+
with st.sidebar:
|
101 |
+
# layout handling
|
102 |
+
add_input_UI_elements()
|
103 |
+
# input elements (file upload, text input, etc)
|
104 |
+
setup_input()
|
105 |
|
106 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
with tab_map:
|
108 |
# visual structure: a couple of toggles at the top, then the map inlcuding a
|
109 |
# dropdown for tileset selection.
|
110 |
+
add_obs_map_header()
|
111 |
tab_map_ui_cols = st.columns(2)
|
112 |
with tab_map_ui_cols[0]:
|
113 |
show_db_points = st.toggle("Show Points from DB", True)
|
|
|
165 |
gallery.render_whale_gallery(n_cols=4)
|
166 |
|
167 |
|
168 |
+
# state handling re data_entry phases
|
169 |
+
# 0. no data entered yet -> display the file uploader thing
|
170 |
+
# 1. we have some images, but not all the metadata fields are done -> validate button shown, disabled
|
171 |
+
# 2. all data entered -> validate button enabled
|
172 |
+
# 3. validation button pressed, validation done -> enable the inference button.
|
173 |
+
# - at this point do we also want to disable changes to the metadata selectors?
|
174 |
+
# anyway, simple first.
|
175 |
+
|
176 |
+
if st.session_state.workflow_fsm.is_in_state('doing_data_entry'):
|
177 |
+
# can we advance state? - only when all inputs are set for all uploaded files
|
178 |
+
all_inputs_set = check_inputs_are_set(debug=True, empty_ok=False)
|
179 |
+
if all_inputs_set:
|
180 |
+
st.session_state.workflow_fsm.complete_current_state()
|
181 |
+
# -> data_entry_complete
|
182 |
+
else:
|
183 |
+
# button, disabled; no state change yet.
|
184 |
+
st.sidebar.button(":gray[*Validate*]", disabled=True, help="Please fill in all fields.")
|
185 |
+
|
186 |
+
|
187 |
+
if st.session_state.workflow_fsm.is_in_state('data_entry_complete'):
|
188 |
+
# can we advance state? - only when the validate button is pressed
|
189 |
+
if st.sidebar.button(":white_check_mark:[**Validate**]"):
|
190 |
+
# create a dictionary with the submitted observation
|
191 |
+
tab_log.info(f"{st.session_state.observations}")
|
192 |
+
df = pd.DataFrame([obs.to_dict() for obs in st.session_state.observations.values()])
|
193 |
+
#df = pd.DataFrame(st.session_state.observations, index=[0])
|
194 |
+
with tab_coords:
|
195 |
+
st.table(df)
|
196 |
+
# there doesn't seem to be any actual validation here?? TODO: find validator function (each element is validated by the input box, but is there something at the whole image level?)
|
197 |
+
# hmm, maybe it should actually just be "I'm done with data entry"
|
198 |
+
st.session_state.workflow_fsm.complete_current_state()
|
199 |
+
# -> data_entry_validated
|
200 |
+
|
201 |
+
# state handling re inference phases (tab_inference)
|
202 |
+
# 3. validation button pressed, validation done -> enable the inference button.
|
203 |
+
# 4. inference button pressed -> ML started. | let's cut this one out, since it would only
|
204 |
+
# make sense if we did it as an async action
|
205 |
+
# 5. ML done -> show results, and manual validation options
|
206 |
+
# 6. manual validation done -> enable the upload buttons
|
207 |
+
#
|
208 |
+
with tab_inference:
|
209 |
+
# inside the inference tab, on button press we call the model (on huggingface hub)
|
210 |
+
# which will be run locally.
|
211 |
+
# - the model predicts the top 3 most likely species from the input image
|
212 |
+
# - these species are shown
|
213 |
+
# - the user can override the species prediction using the dropdown
|
214 |
+
# - an observation is uploaded if the user chooses.
|
215 |
|
216 |
|
217 |
+
if st.session_state.MODE_DEV_STATEFUL:
|
218 |
+
dbg_show_observation_hashes()
|
219 |
+
|
220 |
+
add_classifier_header()
|
221 |
+
# if we are before data_entry_validated, show the button, disabled.
|
222 |
+
if not st.session_state.workflow_fsm.is_in_state_or_beyond('data_entry_validated'):
|
223 |
+
tab_inference.button(":gray[*Identify with cetacean classifier*]", disabled=True,
|
224 |
+
help="Please validate inputs before proceeding",
|
225 |
+
key="button_infer_ceteans")
|
|
|
226 |
|
227 |
+
if st.session_state.workflow_fsm.is_in_state('data_entry_validated'):
|
228 |
+
# show the button, enabled. If pressed, we start the ML model (And advance state)
|
229 |
+
if tab_inference.button("Identify with cetacean classifier"):
|
230 |
+
cetacean_classifier = AutoModelForImageClassification.from_pretrained(
|
231 |
+
"Saving-Willy/cetacean-classifier",
|
232 |
+
revision=classifier_revision,
|
233 |
+
trust_remote_code=True)
|
234 |
+
|
235 |
+
cetacean_just_classify(cetacean_classifier)
|
236 |
+
st.session_state.workflow_fsm.complete_current_state()
|
237 |
+
# trigger a refresh too (refreshhing the prog indicator means the script reruns and
|
238 |
+
# we can enter the next state - visualising the results / review)
|
239 |
+
# ok it doesn't if done programmatically. maybe interacting with teh button? check docs.
|
240 |
+
refresh_progress_display()
|
241 |
+
#TODO: validate this doesn't harm performance adversely.
|
242 |
+
st.rerun()
|
243 |
|
244 |
+
elif st.session_state.workflow_fsm.is_in_state('ml_classification_completed'):
|
245 |
+
# show the results, and allow manual validation
|
246 |
+
st.markdown("""### Inference results and manual validation/adjustment """)
|
247 |
+
if st.session_state.MODE_DEV_STATEFUL:
|
248 |
+
s = ""
|
249 |
+
for k, v in st.session_state.whale_prediction1.items():
|
250 |
+
s += f"* Image {k}: {v}\n"
|
251 |
+
|
252 |
+
st.markdown(s)
|
253 |
+
|
254 |
+
# add a button to advance the state
|
255 |
+
if st.button("Confirm species predictions", help="Confirm that all species are selected correctly"):
|
256 |
+
st.session_state.workflow_fsm.complete_current_state()
|
257 |
+
# -> manual_inspection_completed
|
258 |
+
st.rerun()
|
259 |
+
|
260 |
+
cetacean_show_results_and_review()
|
261 |
+
|
262 |
+
elif st.session_state.workflow_fsm.is_in_state('manual_inspection_completed'):
|
263 |
+
# show the ML results, and allow the user to upload the observation
|
264 |
+
st.markdown("""### Inference Results (after manual validation) """)
|
265 |
+
|
266 |
+
|
267 |
+
if st.button("Upload all observations to THE INTERNET!"):
|
268 |
+
# let this go through to the push_all func, since it just reports to log for now.
|
269 |
+
push_all_observations(enable_push=False)
|
270 |
+
st.session_state.workflow_fsm.complete_current_state()
|
271 |
+
# -> data_uploaded
|
272 |
+
st.rerun()
|
273 |
+
|
274 |
+
cetacean_show_results()
|
275 |
|
276 |
+
elif st.session_state.workflow_fsm.is_in_state('data_uploaded'):
|
277 |
+
# the data has been sent. Lets show the observations again
|
278 |
+
# but no buttons to upload (or greyed out ok)
|
279 |
+
st.markdown("""### Observation(s) uploaded - thank you!""")
|
280 |
+
cetacean_show_results()
|
281 |
+
|
282 |
+
st.divider()
|
283 |
+
#df = pd.DataFrame(st.session_state.observations, index=[0])
|
284 |
+
df = pd.DataFrame([obs.to_dict() for obs in st.session_state.observations.values()])
|
285 |
+
st.table(df)
|
286 |
|
287 |
+
# didn't decide what the next state is here - I think we are in the terminal state.
|
288 |
+
#st.session_state.workflow_fsm.complete_current_state()
|
289 |
+
|
290 |
|
291 |
# inside the hotdog tab, on button press we call a 2nd model (totally unrelated at present, just for demo
|
292 |
# purposes, an hotdog image classifier) which will be run locally.
|
|
|
310 |
hotdog_classify(pipeline_hot_dog, tab_hotdogs)
|
311 |
|
312 |
|
313 |
+
# after all other processing, we can show the stage/state
|
314 |
+
refresh_progress_display()
|
315 |
+
|
316 |
|
317 |
if __name__ == "__main__":
|
318 |
main()
|
src/maps/obs_map.py
CHANGED
@@ -192,8 +192,8 @@ def present_obs_map(dataset_id:str = "Saving-Willy/Happywhale-kaggle",
|
|
192 |
return st_data
|
193 |
|
194 |
|
195 |
-
def
|
196 |
"""
|
197 |
Add brief explainer text to the tab
|
198 |
"""
|
199 |
-
st.write("A map showing the observations in the dataset, with markers colored by species.")
|
|
|
192 |
return st_data
|
193 |
|
194 |
|
195 |
+
def add_obs_map_header() -> None:
    """Render the short explainer text shown at the top of the map tab."""
    st.write("A map showing the observations in the dataset, with markers colored by species.")
|
src/utils/metadata_handler.py
CHANGED
@@ -1,16 +1,26 @@
|
|
1 |
import streamlit as st
|
2 |
|
3 |
-
def metadata2md() -> str:
|
4 |
"""Get metadata from cache and return as markdown-formatted key-value list
|
5 |
|
|
|
|
|
|
|
|
|
6 |
Returns:
|
7 |
str: Markdown-formatted key-value list of metadata
|
8 |
|
9 |
"""
|
10 |
markdown_str = "\n"
|
11 |
-
keys_to_print = ["
|
12 |
-
|
13 |
-
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
return markdown_str
|
16 |
|
|
|
1 |
import streamlit as st
|
2 |
|
3 |
+
def metadata2md(image_hash:str, debug:bool=False) -> str:
    """Get metadata from cache and return as markdown-formatted key-value list

    Args:
        image_hash (str): The hash of the image to get metadata for
        debug (bool, optional): Whether to print additional fields.

    Returns:
        str: Markdown-formatted key-value list of metadata
    """
    markdown_str = "\n"
    keys_to_print = ["author_email", "latitude", "longitude", "date", "time"]
    if debug:
        # fixed typo: was "iamge_md5" — NOTE(review): confirm the observation
        # dict actually stores the hash under "image_md5" (only affects debug output)
        keys_to_print += ["image_md5", "selected_class", "top_prediction", "class_overriden"]

    # missing hash yields an empty dict, so the result is just the leading newline
    observation = st.session_state.public_observations.get(image_hash, {})

    for key, value in observation.items():
        if key in keys_to_print:
            markdown_str += f"- **{key}**: {value}\n"

    return markdown_str
|
26 |
|
src/utils/st_logs.py
CHANGED
@@ -100,6 +100,16 @@ class StreamlitLogHandler(logging.Handler):
|
|
100 |
self.log_area.empty() # Clear previous logs
|
101 |
self.buffer.clear()
|
102 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
# Set up logging to capture all info level logs from the root logger
|
104 |
@st.cache_resource
|
105 |
def setup_logging(level:int=logging.INFO, buffer_len:int=15) -> StreamlitLogHandler:
|
@@ -126,6 +136,7 @@ def setup_logging(level:int=logging.INFO, buffer_len:int=15) -> StreamlitLogHand
|
|
126 |
# st.session_state['handler'] = handler
|
127 |
return handler
|
128 |
|
|
|
129 |
def parse_log_buffer(log_contents: deque) -> List[dict]:
|
130 |
"""
|
131 |
Convert log buffer to a list of dictionaries for use with a streamlit datatable.
|
|
|
100 |
self.log_area.empty() # Clear previous logs
|
101 |
self.buffer.clear()
|
102 |
|
103 |
+
|
104 |
+
def init_logging_session_states():
    """
    Initialise the session state variables for logging.
    """
    # the shared log handler is created once per session; subsequent reruns
    # keep the existing one
    if "handler" in st.session_state:
        return

    st.session_state['handler'] = setup_logging()
|
111 |
+
|
112 |
+
|
113 |
# Set up logging to capture all info level logs from the root logger
|
114 |
@st.cache_resource
|
115 |
def setup_logging(level:int=logging.INFO, buffer_len:int=15) -> StreamlitLogHandler:
|
|
|
136 |
# st.session_state['handler'] = handler
|
137 |
return handler
|
138 |
|
139 |
+
|
140 |
def parse_log_buffer(log_contents: deque) -> List[dict]:
|
141 |
"""
|
142 |
Convert log buffer to a list of dictionaries for use with a streamlit datatable.
|
src/utils/workflow_state.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transitions import Machine
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
OKBLUE = '\033[94m'
|
5 |
+
OKGREEN = '\033[92m'
|
6 |
+
OKCYAN = '\033[96m'
|
7 |
+
FAIL = '\033[91m'
|
8 |
+
ENDC = '\033[0m'
|
9 |
+
|
10 |
+
|
11 |
+
FSM_STATES = ['doing_data_entry', 'data_entry_complete', 'data_entry_validated',
|
12 |
+
#'ml_classification_started',
|
13 |
+
'ml_classification_completed',
|
14 |
+
'manual_inspection_completed', 'data_uploaded']
|
15 |
+
|
16 |
+
|
17 |
+
class WorkflowFSM:
    """
    Linear finite-state machine for the app workflow.

    Wraps a `transitions.Machine` over a fixed, ordered list of states;
    each state can only advance to the next one in the sequence, via
    `complete_current_state`.
    """

    def __init__(self, state_sequence: List[str]):
        """
        Args:
            state_sequence (List[str]): Ordered state names; the first entry
                is the initial state.
        """
        self.state_sequence = state_sequence
        # map state name -> position in the sequence (used for ordering checks)
        self.state_dict = {state: i for i, state in enumerate(state_sequence)}

        # Create state machine
        self.machine = Machine(
            model=self,
            states=state_sequence,
            initial=state_sequence[0],
        )

        # For each state (except the last), add a completion transition to the next state
        for i in range(len(state_sequence) - 1):
            current_state = state_sequence[i]
            next_state = state_sequence[i + 1]

            self.machine.add_transition(
                trigger=f'complete_{current_state}',
                source=current_state,
                dest=next_state,
                conditions=[f'is_in_{current_state}']
            )

            # Dynamically add a condition method for each state.
            # The default arg `s=current_state` binds the value now, avoiding
            # the late-binding closure pitfall.
            setattr(self, f'is_in_{current_state}',
                    lambda s=current_state: self.is_in_state(s))

        # Add callbacks for logging
        self.machine.before_state_change = self._log_transition
        self.machine.after_state_change = self._post_transition

    def is_in_state(self, state_name: str) -> bool:
        """Check if we're currently in the specified state"""
        return self.state == state_name

    def complete_current_state(self) -> bool:
        """
        Signal that the current state is complete.
        Returns True if state transition occurred, False otherwise.
        """
        current_state = self.state
        trigger_name = f'complete_{current_state}'

        if hasattr(self, trigger_name):
            try:
                trigger_func = getattr(self, trigger_name)
                trigger_func()
                return True
            except Exception:
                # narrowed from a bare `except:`, which would also swallow
                # SystemExit / KeyboardInterrupt; a failed/blocked transition
                # is still reported as False
                return False
        return False

    def is_in_state_or_beyond(self, state_name: str) -> bool:
        """Check if we have reached or passed the specified state"""
        if state_name not in self.state_dict:
            raise ValueError(f"Invalid state: {state_name}")

        # sequence index comparison: states are strictly ordered
        return self.state_dict[state_name] <= self.state_dict[self.state]

    @property
    def current_state(self) -> str:
        """Get the current state name"""
        return self.state

    @property
    def current_state_index(self) -> int:
        """Get the current state index"""
        return self.state_dict[self.state]

    @property
    def num_states(self) -> int:
        """Total number of states in the sequence."""
        return len(self.state_sequence)

    def _log_transition(self):
        # TODO: use logger, not printing.
        self._cprint(f"[FSM] -> Transitioning from {self.current_state}")

    def _post_transition(self):
        # TODO: use logger, not printing.
        self._cprint(f"[FSM] -| Transitioned to {self.current_state}")

    def _cprint(self, msg:str, color:str=OKCYAN):
        """Print colored message"""
        print(f"{color}{msg}{ENDC}")
|
src/utils/workflow_ui.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from utils.workflow_state import WorkflowFSM, FSM_STATES
|
3 |
+
|
4 |
+
def init_workflow_session_states():
    """
    Initialise the session state variables for the workflow state machine
    """
    # build the FSM only on the first run of a session; later reruns reuse it
    if "workflow_fsm" in st.session_state:
        return

    st.session_state.workflow_fsm = WorkflowFSM(FSM_STATES)
|
12 |
+
|
13 |
+
def refresh_progress_display() -> None:
    """
    Redraw the workflow progress indicator in the Streamlit sidebar.
    """
    with st.sidebar:
        fsm = st.session_state.workflow_fsm
        # the initial state counts as 0, so the denominator excludes it
        total = fsm.num_states - 1
        idx = fsm.current_state_index
        label = f"*Progress: {idx}/{total}. Current: {fsm.current_state}.*"

        text_placeholder, bar_placeholder = st.session_state.disp_progress
        text_placeholder.markdown(label)
        bar_placeholder.progress(idx/total)
|
25 |
+
|
26 |
+
|
27 |
+
def init_workflow_viz(debug:bool=True) -> None:
    """
    Set up the streamlit elements for visualising the workflow progress.

    Adds placeholders for progress indicators, and adds a button to manually refresh
    the displayed progress. Note: The button is mainly a development aid.

    Args:
        debug (bool): If True, include the manual refresh button. Default is True.

    """


    #Initialise the layout containers used in the input handling
    # add progress indicator to session_state
    # NOTE(review): this guard tests the key "progress", but the code below only
    # ever sets "disp_progress" — so the condition is always True and the
    # placeholders are recreated on every rerun. That may be intentional
    # (Streamlit placeholders from a previous rerun are stale), but the key
    # mismatch looks accidental — confirm which key was meant.
    if "progress" not in st.session_state:
        with st.sidebar:
            # two placeholders: [0] markdown status line, [1] progress bar
            # (consumed by refresh_progress_display)
            st.session_state.disp_progress = [st.empty(), st.empty()]
    if debug:
        # add button to sidebar, with the callback to refesh_progress
        st.sidebar.button("Refresh Progress", on_click=refresh_progress_display)
|
48 |
+
|
src/whale_viewer.py
CHANGED
@@ -115,6 +115,9 @@ def format_whale_name(whale_class:str) -> str:
|
|
115 |
Returns:
|
116 |
str: The formatted whale name with spaces instead of underscores and each word capitalized.
|
117 |
"""
|
|
|
|
|
|
|
118 |
whale_name = whale_class.replace("_", " ").title()
|
119 |
return whale_name
|
120 |
|
|
|
115 |
Returns:
|
116 |
str: The formatted whale name with spaces instead of underscores and each word capitalized.
|
117 |
"""
|
118 |
+
if not isinstance(whale_class, str):
|
119 |
+
raise TypeError("whale_class should be a string.")
|
120 |
+
|
121 |
whale_name = whale_class.replace("_", " ").title()
|
122 |
return whale_name
|
123 |
|
tests/test_input_handling.py
CHANGED
@@ -51,9 +51,6 @@ def test_is_valid_email_invalid():
|
|
51 |
assert not is_valid_email("[email protected].")
|
52 |
assert not is_valid_email("a@[email protected]")
|
53 |
|
54 |
-
# not sure how xfails come through the CI pipeline yet.
|
55 |
-
# maybe better to just comment out this stuff until pipeline is setup, then can check /extend
|
56 |
-
@pytest.mark.xfail(reason="Bug identified, but while setting up CI having failing tests causes more headache")
|
57 |
def test_is_valid_email_invalid_plus():
|
58 |
assert not is_valid_email("[email protected]")
|
59 |
assert not is_valid_email("[email protected]")
|
@@ -143,7 +140,7 @@ def test_get_image_latlon():
|
|
143 |
|
144 |
# missing GPS loc
|
145 |
f2 = test_data_pth / 'cakes_no_exif_gps.jpg'
|
146 |
-
assert get_image_latlon(f2) == None
|
147 |
|
148 |
# missng datetime -> expect gps not affected
|
149 |
f3 = test_data_pth / 'cakes_no_exif_datetime.jpg'
|
@@ -151,7 +148,7 @@ def test_get_image_latlon():
|
|
151 |
|
152 |
# tests for get_image_latlon with empty file
|
153 |
def test_get_image_latlon_empty():
|
154 |
-
assert get_image_latlon("") == None
|
155 |
|
156 |
# tests for decimal_coords
|
157 |
# - without input, py raises TypeError
|
|
|
51 |
assert not is_valid_email("[email protected].")
|
52 |
assert not is_valid_email("a@[email protected]")
|
53 |
|
|
|
|
|
|
|
54 |
def test_is_valid_email_invalid_plus():
|
55 |
assert not is_valid_email("[email protected]")
|
56 |
assert not is_valid_email("[email protected]")
|
|
|
140 |
|
141 |
# missing GPS loc
|
142 |
f2 = test_data_pth / 'cakes_no_exif_gps.jpg'
|
143 |
+
assert get_image_latlon(f2) == (None, None)
|
144 |
|
145 |
# missng datetime -> expect gps not affected
|
146 |
f3 = test_data_pth / 'cakes_no_exif_datetime.jpg'
|
|
|
148 |
|
149 |
# tests for get_image_latlon with empty file
|
150 |
def test_get_image_latlon_empty():
|
151 |
+
assert get_image_latlon("") == (None, None)
|
152 |
|
153 |
# tests for decimal_coords
|
154 |
# - without input, py raises TypeError
|
tests/test_whale_viewer.py
CHANGED
@@ -40,11 +40,9 @@ def test_format_whale_name_empty():
|
|
40 |
assert format_whale_name("") == ""
|
41 |
|
42 |
# testing with the wrong datatype
|
43 |
-
# we should get a TypeError - currently it fails with a AttributeError
|
44 |
-
@pytest.mark.xfail
|
45 |
def test_format_whale_name_none():
|
46 |
with pytest.raises(TypeError):
|
47 |
format_whale_name(None)
|
48 |
|
49 |
|
50 |
-
# display_whale requires UI to test it.
|
|
|
40 |
assert format_whale_name("") == ""
|
41 |
|
42 |
# testing with the wrong datatype
|
|
|
|
|
43 |
def test_format_whale_name_none():
    # non-string input must raise TypeError: format_whale_name now checks
    # isinstance(whale_class, str) explicitly (previously it failed with
    # AttributeError and this test was marked xfail)
    with pytest.raises(TypeError):
        format_whale_name(None)
|
46 |
|
47 |
|
48 |
+
# display_whale requires UI to test it.
|