vancauwe committed on
Commit
8ccb11f
·
unverified ·
2 Parent(s): 41bbd4a 8e4ef44

Merge pull request #29 from sdsc-ordes/feat/stateful-workflow

Browse files
.github/workflows/python-pycov-onPR.yml ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This workflow will install dependencies, create coverage tests and run Pytest Coverage Commentator
2
+ # For more information see: https://github.com/coroo/pytest-coverage-commentator
3
+ name: pytest-coverage-in-PR
4
+ on:
5
+ pull_request:
6
+ branches:
7
+ - '*'
8
+ jobs:
9
+ build:
10
+ runs-on: ubuntu-latest
11
+ permissions:
12
+ contents: write
13
+ pull-requests: write
14
+ steps:
15
+ - uses: actions/checkout@v4
16
+ - name: Set up Python 3.10
17
+ uses: actions/setup-python@v3
18
+ with:
19
+ python-version: "3.10"
20
+ - name: Install dependencies
21
+ run: |
22
+ python -m pip install --upgrade pip
23
+ if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
24
+ if [ -f tests/requirements.txt ]; then pip install -r tests/requirements.txt; fi
25
+
26
+ - name: Build coverage files for mishakav commenter action
27
+ run: |
28
+ pytest --junitxml=pytest.xml --cov-report=term-missing:skip-covered --cov=src tests/ | tee pytest-coverage.txt
29
+ echo "working dir:" && pwd
30
+ echo "files in cwd:" && ls -ltr
31
+
32
+ - name: Pytest coverage comment
33
+ uses: MishaKav/pytest-coverage-comment@main
34
+ with:
35
+ pytest-coverage-path: ./pytest-coverage.txt
36
+ junitxml-path: ./pytest.xml
37
+
38
+ #- name: Comment coverage
39
+ # uses: coroo/[email protected]
requirements.txt CHANGED
@@ -10,7 +10,8 @@ streamlit_folium==0.23.1
10
 
11
  # backend
12
  datasets==3.0.2
13
-
 
14
 
15
  # running ML models
16
 
 
10
 
11
  # backend
12
  datasets==3.0.2
13
+ ## FSM
14
+ transitions==0.9.2
15
 
16
  # running ML models
17
 
src/classifier/classifier_image.py CHANGED
@@ -10,13 +10,207 @@ import whale_viewer as viewer
10
  from hf_push_observations import push_observations
11
  from utils.grid_maker import gridder
12
  from utils.metadata_handler import metadata2md
 
13
 
14
- def cetacean_classify(cetacean_classifier):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  """Cetacean classifier using the saving-willy model from Saving Willy Hugging Face space.
16
  For each image in the session state, classify the image and display the top 3 predictions.
17
  Args:
18
  cetacean_classifier ([type]): saving-willy model from Saving Willy Hugging Face space
19
  """
 
20
  images = st.session_state.images
21
  observations = st.session_state.observations
22
  hashes = st.session_state.image_hashes
@@ -33,25 +227,25 @@ def cetacean_classify(cetacean_classifier):
33
  observation = observations[hash].to_dict()
34
  # run classifier model on `image`, and persistently store the output
35
  out = cetacean_classifier(image) # get top 3 matches
36
- st.session_state.whale_prediction1 = out['predictions'][0]
37
- st.session_state.classify_whale_done = True
38
- msg = f"[D]2 classify_whale_done: {st.session_state.classify_whale_done}, whale_prediction1: {st.session_state.whale_prediction1}"
39
  g_logger.info(msg)
40
 
41
  # dropdown for selecting/overriding the species prediction
42
- if not st.session_state.classify_whale_done:
43
  selected_class = st.sidebar.selectbox("Species", viewer.WHALE_CLASSES,
44
  index=None, placeholder="Species not yet identified...",
45
  disabled=True)
46
  else:
47
- pred1 = st.session_state.whale_prediction1
48
  # get index of pred1 from WHALE_CLASSES, none if not present
49
  print(f"[D] pred1: {pred1}")
50
  ix = viewer.WHALE_CLASSES.index(pred1) if pred1 in viewer.WHALE_CLASSES else None
51
  selected_class = st.selectbox(f"Species for observation {str(o)}", viewer.WHALE_CLASSES, index=ix)
52
 
53
  observation['predicted_class'] = selected_class
54
- if selected_class != st.session_state.whale_prediction1:
55
  observation['class_overriden'] = selected_class
56
 
57
  st.session_state.public_observation = observation
@@ -70,4 +264,4 @@ def cetacean_classify(cetacean_classifier):
70
  for i in range(len(whale_classes)):
71
  viewer.display_whale(whale_classes, i)
72
  o += 1
73
- col = (col + 1) % row_size
 
10
  from hf_push_observations import push_observations
11
  from utils.grid_maker import gridder
12
  from utils.metadata_handler import metadata2md
13
+ from input.input_observation import InputObservation
14
 
15
+ def init_classifier_session_states() -> None:
16
+ '''
17
+ Initialise the session state variables used in classification
18
+ '''
19
+ if "classify_whale_done" not in st.session_state:
20
+ st.session_state.classify_whale_done = {}
21
+
22
+ if "whale_prediction1" not in st.session_state:
23
+ st.session_state.whale_prediction1 = {}
24
+
25
+
26
+ def add_classifier_header() -> None:
27
+ """
28
+ Add brief explainer text about cetacean classification to the tab
29
+ """
30
+ st.markdown("""
31
+ *Run classifer to identify the species of cetean on the uploaded image.
32
+ Once inference is complete, the top three predictions are shown.
33
+ You can override the prediction by selecting a species from the dropdown.*""")
34
+
35
+
36
+ # func to just run classification, store results.
37
+ def cetacean_just_classify(cetacean_classifier):
38
+ """
39
+ Infer cetacean species for all observations in the session state.
40
+
41
+ - this function runs the classifier, and stores results in the session state.
42
+ - the top 3 predictions are stored in the observation object, which is retained
43
+ in st.session_state.observations
44
+ - to display results use cetacean_show_results() or cetacean_show_results_and_review()
45
+
46
+ Args:
47
+ cetacean_classifier ([type]): saving-willy model from Saving Willy Hugging Face space
48
+ """
49
+
50
+ images = st.session_state.images
51
+ #observations = st.session_state.observations
52
+ hashes = st.session_state.image_hashes
53
+
54
+ for hash in hashes:
55
+ image = images[hash]
56
+ # run classifier model on `image`, and persistently store the output
57
+ out = cetacean_classifier(image) # get top 3 matches
58
+ st.session_state.whale_prediction1[hash] = out['predictions'][0]
59
+ st.session_state.classify_whale_done[hash] = True
60
+ st.session_state.observations[hash].set_top_predictions(out['predictions'][:])
61
+
62
+ msg = f"[D]2 classify_whale_done for {hash}: {st.session_state.classify_whale_done[hash]}, whale_prediction1: {st.session_state.whale_prediction1[hash]}"
63
+ g_logger.info(msg)
64
+
65
+ if st.session_state.MODE_DEV_STATEFUL:
66
+ st.write(f"*[D] Observation {hash} classified as {st.session_state.whale_prediction1[hash]}*")
67
+
68
+
69
+ # func to show results and allow review
70
+ def cetacean_show_results_and_review() -> None:
71
+ """
72
+ Present classification results and allow user to review and override the prediction.
73
+
74
+ - for each observation in the session state, displays the image, summarised
75
+ metadata, and the top 3 predictions.
76
+ - allows user to override the prediction by selecting a species from the dropdown.
77
+ - the selected species is stored in the observation object, which is retained in
78
+ st.session_state.observations
79
+
80
+ """
81
+
82
+ images = st.session_state.images
83
+ observations = st.session_state.observations
84
+ hashes = st.session_state.image_hashes
85
+ batch_size, row_size, page = gridder(hashes)
86
+
87
+ grid = st.columns(row_size)
88
+ col = 0
89
+ o = 1
90
+
91
+ for hash in hashes:
92
+ image = images[hash]
93
+ #observation = observations[hash].to_dict()
94
+ _observation:InputObservation = observations[hash]
95
+
96
+ with grid[col]:
97
+ st.image(image, use_column_width=True)
98
+
99
+ # dropdown for selecting/overriding the species prediction
100
+ if not st.session_state.classify_whale_done[hash]:
101
+ selected_class = st.sidebar.selectbox("Species", viewer.WHALE_CLASSES,
102
+ index=None, placeholder="Species not yet identified...",
103
+ disabled=True)
104
+ else:
105
+ pred1 = st.session_state.whale_prediction1[hash]
106
+ # get index of pred1 from WHALE_CLASSES, none if not present
107
+ print(f"[D] {o:3} pred1: {pred1:30} | {hash}")
108
+ ix = viewer.WHALE_CLASSES.index(pred1) if pred1 in viewer.WHALE_CLASSES else None
109
+ selected_class = st.selectbox(f"Species for observation {str(o)}", viewer.WHALE_CLASSES, index=ix)
110
+
111
+ _observation.set_selected_class(selected_class)
112
+ #observation['predicted_class'] = selected_class
113
+ # this logic is now in the InputObservation class automatically
114
+ #if selected_class != st.session_state.whale_prediction1[hash]:
115
+ # observation['class_overriden'] = selected_class # TODO: this should be boolean!
116
+
117
+ # store the elements of the observation that will be transmitted (not image)
118
+ observation = _observation.to_dict()
119
+ st.session_state.public_observations[hash] = observation
120
+
121
+ #st.button(f"Upload observation {str(o)} to THE INTERNET!", on_click=push_observations)
122
+ # TODO: the metadata only fills properly if `validate` was clicked.
123
+ st.markdown(metadata2md(hash, debug=True))
124
+
125
+ msg = f"[D] full observation after inference: {observation}"
126
+ g_logger.debug(msg)
127
+ print(msg)
128
+ # TODO: add a link to more info on the model, next to the button.
129
+
130
+ whale_classes = observations[hash].top_predictions
131
+ # render images for the top 3 (that is what the model api returns)
132
+ n = len(whale_classes)
133
+ st.markdown(f"**Top {n} Predictions for observation {str(o)}**")
134
+ for i in range(n):
135
+ viewer.display_whale(whale_classes, i)
136
+ o += 1
137
+ col = (col + 1) % row_size
138
+
139
+
140
+ # func to just present results
141
+ def cetacean_show_results():
142
+ """
143
+ Present classification results that may be pushed to the online dataset.
144
+
145
+ - for each observation in the session state, displays the image, summarised
146
+ metadata, the top 3 predictions, and the selected species (which may have
147
+ been manually selected, or the top prediction accepted).
148
+
149
+ """
150
+ images = st.session_state.images
151
+ observations = st.session_state.observations
152
+ hashes = st.session_state.image_hashes
153
+ batch_size, row_size, page = gridder(hashes)
154
+
155
+
156
+ grid = st.columns(row_size)
157
+ col = 0
158
+ o = 1
159
+
160
+ for hash in hashes:
161
+ image = images[hash]
162
+ observation = observations[hash].to_dict()
163
+
164
+ with grid[col]:
165
+ st.image(image, use_column_width=True)
166
+
167
+ # # dropdown for selecting/overriding the species prediction
168
+ # if not st.session_state.classify_whale_done[hash]:
169
+ # selected_class = st.sidebar.selectbox("Species", viewer.WHALE_CLASSES,
170
+ # index=None, placeholder="Species not yet identified...",
171
+ # disabled=True)
172
+ # else:
173
+ # pred1 = st.session_state.whale_prediction1[hash]
174
+ # # get index of pred1 from WHALE_CLASSES, none if not present
175
+ # print(f"[D] pred1: {pred1}")
176
+ # ix = viewer.WHALE_CLASSES.index(pred1) if pred1 in viewer.WHALE_CLASSES else None
177
+ # selected_class = st.selectbox(f"Species for observation {str(o)}", viewer.WHALE_CLASSES, index=ix)
178
+
179
+ # observation['predicted_class'] = selected_class
180
+ # if selected_class != st.session_state.whale_prediction1[hash]:
181
+ # observation['class_overriden'] = selected_class # TODO: this should be boolean!
182
+
183
+ # st.session_state.public_observation = observation
184
+
185
+ #st.button(f"Upload observation {str(o)} to THE INTERNET!", on_click=push_observations)
186
+ #
187
+ st.markdown(metadata2md(hash, debug=True))
188
+
189
+ msg = f"[D] full observation after inference: {observation}"
190
+ g_logger.debug(msg)
191
+ print(msg)
192
+ # TODO: add a link to more info on the model, next to the button.
193
+
194
+ whale_classes = observations[hash].top_predictions
195
+ # render images for the top 3 (that is what the model api returns)
196
+ n = len(whale_classes)
197
+ st.markdown(f"**Top {n} Predictions for observation {str(o)}**")
198
+ for i in range(n):
199
+ viewer.display_whale(whale_classes, i)
200
+ o += 1
201
+ col = (col + 1) % row_size
202
+
203
+
204
+
205
+
206
+ # func to do all in one
207
+ def cetacean_classify_show_and_review(cetacean_classifier):
208
  """Cetacean classifier using the saving-willy model from Saving Willy Hugging Face space.
209
  For each image in the session state, classify the image and display the top 3 predictions.
210
  Args:
211
  cetacean_classifier ([type]): saving-willy model from Saving Willy Hugging Face space
212
  """
213
+ raise DeprecationWarning("This function is deprecated. Use individual steps instead")
214
  images = st.session_state.images
215
  observations = st.session_state.observations
216
  hashes = st.session_state.image_hashes
 
227
  observation = observations[hash].to_dict()
228
  # run classifier model on `image`, and persistently store the output
229
  out = cetacean_classifier(image) # get top 3 matches
230
+ st.session_state.whale_prediction1[hash] = out['predictions'][0]
231
+ st.session_state.classify_whale_done[hash] = True
232
+ msg = f"[D]2 classify_whale_done for {hash}: {st.session_state.classify_whale_done[hash]}, whale_prediction1: {st.session_state.whale_prediction1[hash]}"
233
  g_logger.info(msg)
234
 
235
  # dropdown for selecting/overriding the species prediction
236
+ if not st.session_state.classify_whale_done[hash]:
237
  selected_class = st.sidebar.selectbox("Species", viewer.WHALE_CLASSES,
238
  index=None, placeholder="Species not yet identified...",
239
  disabled=True)
240
  else:
241
+ pred1 = st.session_state.whale_prediction1[hash]
242
  # get index of pred1 from WHALE_CLASSES, none if not present
243
  print(f"[D] pred1: {pred1}")
244
  ix = viewer.WHALE_CLASSES.index(pred1) if pred1 in viewer.WHALE_CLASSES else None
245
  selected_class = st.selectbox(f"Species for observation {str(o)}", viewer.WHALE_CLASSES, index=ix)
246
 
247
  observation['predicted_class'] = selected_class
248
+ if selected_class != st.session_state.whale_prediction1[hash]:
249
  observation['class_overriden'] = selected_class
250
 
251
  st.session_state.public_observation = observation
 
264
  for i in range(len(whale_classes)):
265
  viewer.display_whale(whale_classes, i)
266
  o += 1
267
+ col = (col + 1) % row_size
src/classifier_image.py DELETED
@@ -1,70 +0,0 @@
1
- import streamlit as st
2
- import logging
3
- import os
4
-
5
- # get a global var for logger accessor in this module
6
- LOG_LEVEL = logging.DEBUG
7
- g_logger = logging.getLogger(__name__)
8
- g_logger.setLevel(LOG_LEVEL)
9
-
10
- from grid_maker import gridder
11
- import hf_push_observations as sw_push_obs
12
- import utils.metadata_handler as meta_handler
13
- import whale_viewer as sw_wv
14
-
15
- def cetacean_classify(cetacean_classifier, tab_inference):
16
- files = st.session_state.files
17
- images = st.session_state.images
18
- observations = st.session_state.observations
19
-
20
- batch_size, row_size, page = gridder(files)
21
-
22
- grid = st.columns(row_size)
23
- col = 0
24
-
25
- for file in files:
26
- image = images[file.name]
27
-
28
- with grid[col]:
29
- st.image(image, use_column_width=True)
30
- observation = observations[file.name].to_dict()
31
- # run classifier model on `image`, and persistently store the output
32
- out = cetacean_classifier(image) # get top 3 matches
33
- st.session_state.whale_prediction1 = out['predictions'][0]
34
- st.session_state.classify_whale_done = True
35
- msg = f"[D]2 classify_whale_done: {st.session_state.classify_whale_done}, whale_prediction1: {st.session_state.whale_prediction1}"
36
- g_logger.info(msg)
37
-
38
- # dropdown for selecting/overriding the species prediction
39
- if not st.session_state.classify_whale_done:
40
- selected_class = st.sidebar.selectbox("Species", sw_wv.WHALE_CLASSES,
41
- index=None, placeholder="Species not yet identified...",
42
- disabled=True)
43
- else:
44
- pred1 = st.session_state.whale_prediction1
45
- # get index of pred1 from WHALE_CLASSES, none if not present
46
- print(f"[D] pred1: {pred1}")
47
- ix = sw_wv.WHALE_CLASSES.index(pred1) if pred1 in sw_wv.WHALE_CLASSES else None
48
- selected_class = tab_inference.selectbox("Species", sw_wv.WHALE_CLASSES, index=ix)
49
-
50
- observation['predicted_class'] = selected_class
51
- if selected_class != st.session_state.whale_prediction1:
52
- observation['class_overriden'] = selected_class
53
-
54
- st.session_state.public_observation = observation
55
- st.button(f"Upload observation for {file.name} to THE INTERNET!", on_click=sw_push_obs.push_observations)
56
- # TODO: the metadata only fills properly if `validate` was clicked.
57
- st.markdown(meta_handler.metadata2md())
58
-
59
- msg = f"[D] full observation after inference: {observation}"
60
- g_logger.debug(msg)
61
- print(msg)
62
- # TODO: add a link to more info on the model, next to the button.
63
-
64
- whale_classes = out['predictions'][:]
65
- # render images for the top 3 (that is what the model api returns)
66
- #with tab_inference:
67
- st.title(f"Species detected for {file.name}")
68
- for i in range(len(whale_classes)):
69
- sw_wv.display_whale(whale_classes, i)
70
- col = (col + 1) % row_size
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/hf_push_observations.py CHANGED
@@ -1,15 +1,82 @@
1
- from streamlit.delta_generator import DeltaGenerator
2
- import streamlit as st
3
- from huggingface_hub import HfApi
4
  import json
5
  import tempfile
6
  import logging
7
 
 
 
 
 
 
8
  # get a global var for logger accessor in this module
9
  LOG_LEVEL = logging.DEBUG
10
  g_logger = logging.getLogger(__name__)
11
  g_logger.setLevel(LOG_LEVEL)
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  def push_observations(tab_log:DeltaGenerator=None):
14
  """
15
  Push the observations to the Hugging Face dataset
@@ -20,17 +87,16 @@ def push_observations(tab_log:DeltaGenerator=None):
20
  push any observation since generating the logger)
21
 
22
  """
 
 
23
  # we get the observation from session state: 1 is the dict 2 is the image.
24
  # first, lets do an info display (popup)
25
  metadata_str = json.dumps(st.session_state.public_observation)
26
 
27
  st.toast(f"Uploading observations: {metadata_str}", icon="🦭")
28
- tab_log = st.session_state.tab_log
29
- if tab_log is not None:
30
- tab_log.info(f"Uploading observations: {metadata_str}")
31
 
32
  # get huggingface api
33
- import os
34
  token = os.environ.get("HF_TOKEN", None)
35
  api = HfApi(token=token)
36
 
@@ -53,4 +119,4 @@ def push_observations(tab_log:DeltaGenerator=None):
53
  # msg = f"observation attempted tx to repo happy walrus: {rv}"
54
  g_logger.info(msg)
55
  st.info(msg)
56
-
 
1
+ import os
 
 
2
  import json
3
  import tempfile
4
  import logging
5
 
6
+ from streamlit.delta_generator import DeltaGenerator
7
+ import streamlit as st
8
+ from huggingface_hub import HfApi, CommitInfo
9
+
10
+
11
  # get a global var for logger accessor in this module
12
  LOG_LEVEL = logging.DEBUG
13
  g_logger = logging.getLogger(__name__)
14
  g_logger.setLevel(LOG_LEVEL)
15
 
16
+ def push_observation(image_hash:str, api:HfApi, enable_push:False) -> CommitInfo:
17
+ '''
18
+ push one observation to the Hugging Face dataset
19
+
20
+ '''
21
+ # get the observation
22
+ observation = st.session_state.public_observations.get(image_hash)
23
+ if observation is None:
24
+ msg = f"Could not find observation with hash {image_hash}"
25
+ g_logger.error(msg)
26
+ st.error(msg)
27
+ return None
28
+
29
+ # convert to json
30
+ metadata_str = json.dumps(observation) # doesn't work yet, TODO
31
+
32
+ st.toast(f"Uploading observation: {metadata_str}", icon="🦭")
33
+ g_logger.info(f"Uploading observation: {metadata_str}")
34
+
35
+ # write to temp file so we can send it (why is this not using context mgr?)
36
+ f = tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False)
37
+ f.write(metadata_str)
38
+ f.close()
39
+ #st.info(f"temp file: {f.name} with metadata written...")
40
+
41
+ path_in_repo = f"metadata/{observation['author_email']}/{observation['image_md5']}.json"
42
+
43
+ msg = f"fname: {f.name} | path: {path_in_repo}"
44
+ print(msg)
45
+ st.warning(msg)
46
+
47
+ if enable_push:
48
+ rv = api.upload_file(
49
+ path_or_fileobj=f.name,
50
+ path_in_repo=path_in_repo,
51
+ repo_id="Saving-Willy/temp_dataset",
52
+ repo_type="dataset",
53
+ )
54
+ print(rv)
55
+ msg = f"observation attempted tx to repo happy walrus: {rv}"
56
+ g_logger.info(msg)
57
+ st.info(msg)
58
+ else:
59
+ rv = None # temp don't send anything
60
+
61
+ return rv
62
+
63
+
64
+
65
+ def push_all_observations(enable_push:bool=False):
66
+ '''
67
+ open an API connection to Hugging Face, and push all observations one by one
68
+ '''
69
+
70
+ # get huggingface api
71
+ token = os.environ.get("HF_TOKEN", None)
72
+ api = HfApi(token=token)
73
+
74
+ # iterate over the list of observations
75
+ for hash in st.session_state.public_observations.keys():
76
+ rv = push_observation(hash, api, enable_push=enable_push)
77
+
78
+
79
+
80
  def push_observations(tab_log:DeltaGenerator=None):
81
  """
82
  Push the observations to the Hugging Face dataset
 
87
  push any observation since generating the logger)
88
 
89
  """
90
+ raise DeprecationWarning("This function is deprecated. Use push_all_observations instead.")
91
+
92
  # we get the observation from session state: 1 is the dict 2 is the image.
93
  # first, lets do an info display (popup)
94
  metadata_str = json.dumps(st.session_state.public_observation)
95
 
96
  st.toast(f"Uploading observations: {metadata_str}", icon="🦭")
97
+ g_logger.info(f"Uploading observations: {metadata_str}")
 
 
98
 
99
  # get huggingface api
 
100
  token = os.environ.get("HF_TOKEN", None)
101
  api = HfApi(token=token)
102
 
 
119
  # msg = f"observation attempted tx to repo happy walrus: {rv}"
120
  g_logger.info(msg)
121
  st.info(msg)
122
+
src/input/input_handling.py CHANGED
@@ -1,14 +1,17 @@
 
1
  import datetime
2
  import logging
 
3
 
4
  import streamlit as st
5
  from streamlit.delta_generator import DeltaGenerator
 
6
 
7
  import cv2
8
  import numpy as np
9
 
10
  from input.input_observation import InputObservation
11
- from input.input_validator import get_image_datetime, is_valid_email, is_valid_number
12
 
13
  m_logger = logging.getLogger(__name__)
14
  m_logger.setLevel(logging.INFO)
@@ -23,99 +26,394 @@ allowed_image_types = ['jpg', 'jpeg', 'png', 'webp']
23
  # an arbitrary set of defaults so testing is less painful...
24
  # ideally we add in some randomization to the defaults
25
  spoof_metadata = {
26
- "latitude": 23.5,
27
  "longitude": 44,
28
  "author_email": "[email protected]",
29
  "date": None,
30
  "time": None,
31
  }
32
 
33
- def setup_input(
34
- viewcontainer: DeltaGenerator=None,
35
- _allowed_image_types: list=None, ) -> InputObservation:
36
  """
37
- Sets up the input interface for uploading an image and entering metadata.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- It provides input fields for an image upload, lat/lon, author email, and date-time.
40
- In the ideal case, the image metadata will be used to populate location and datetime.
41
 
42
- Parameters:
43
- viewcontainer (DeltaGenerator, optional): The Streamlit container to use for the input interface. Defaults to st.sidebar.
44
- _allowed_image_types (list, optional): List of allowed image file types for upload. Defaults to allowed_image_types.
 
 
 
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  Returns:
47
- InputObservation: An object containing the uploaded image and entered metadata.
 
 
 
 
 
 
 
 
48
 
 
 
 
49
  """
50
-
51
- if viewcontainer is None:
52
- viewcontainer = st.sidebar
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
- if _allowed_image_types is None:
55
- _allowed_image_types = allowed_image_types
 
 
 
 
 
 
 
56
 
57
 
58
- viewcontainer.title("Input image and data")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
- # 1. Input the author email
61
- author_email = viewcontainer.text_input("Author Email", spoof_metadata.get('author_email', ""))
62
- if author_email and not is_valid_email(author_email):
63
- viewcontainer.error("Please enter a valid email address.")
 
 
 
 
 
64
 
65
- # 2. Image Selector
66
- uploaded_files = viewcontainer.file_uploader("Upload an image", type=allowed_image_types, accept_multiple_files=True)
 
 
 
 
 
 
 
 
67
  observations = {}
68
- images = {}
69
- image_hashes =[]
70
- if uploaded_files is not None:
71
- for file in uploaded_files:
72
-
73
- viewcontainer.title(f"Metadata for {file.name}")
74
-
75
- # Display the uploaded image
76
- # load image using cv2 format, so it is compatible with the ML models
77
- file_bytes = np.asarray(bytearray(file.read()), dtype=np.uint8)
78
- filename = file.name
79
- image = cv2.imdecode(file_bytes, 1)
80
- # Extract and display image date-time
81
- image_datetime = None # For storing date-time from image
82
- image_datetime = get_image_datetime(file)
83
- m_logger.debug(f"image date extracted as {image_datetime} (from {uploaded_files})")
84
 
 
85
 
86
- # 3. Latitude Entry Box
87
- latitude = viewcontainer.text_input("Latitude for "+filename, spoof_metadata.get('latitude', ""))
88
- if latitude and not is_valid_number(latitude):
89
- viewcontainer.error("Please enter a valid latitude (numerical only).")
90
- m_logger.error(f"Invalid latitude entered: {latitude}.")
91
- # 4. Longitude Entry Box
92
- longitude = viewcontainer.text_input("Longitude for "+filename, spoof_metadata.get('longitude', ""))
93
- if longitude and not is_valid_number(longitude):
94
- viewcontainer.error("Please enter a valid longitude (numerical only).")
95
- m_logger.error(f"Invalid latitude entered: {latitude}.")
96
- # 5. Date/time
97
- ## first from image metadata
98
- if image_datetime is not None:
99
- time_value = datetime.datetime.strptime(image_datetime, '%Y:%m:%d %H:%M:%S').time()
100
- date_value = datetime.datetime.strptime(image_datetime, '%Y:%m:%d %H:%M:%S').date()
101
- else:
102
- time_value = datetime.datetime.now().time() # Default to current time
103
- date_value = datetime.datetime.now().date()
104
 
105
- ## if not, give user the option to enter manually
106
- date_option = st.sidebar.date_input("Date for "+filename, value=date_value)
107
- time_option = st.sidebar.time_input("Time for "+filename, time_value)
 
 
 
 
 
 
 
108
 
109
- observation = InputObservation(image=file, latitude=latitude, longitude=longitude,
110
- author_email=author_email, date=image_datetime, time=None,
111
- date_option=date_option, time_option=time_option)
112
- image_hash = observation.to_dict()["image_md5"]
113
- observations[image_hash] = observation
114
- images[image_hash] = image
115
- image_hashes.append(image_hash)
 
 
 
 
 
 
 
 
 
116
 
117
- st.session_state.images = images
118
- st.session_state.files = uploaded_files
119
- st.session_state.observations = observations
120
- st.session_state.image_hashes = image_hashes
121
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Tuple
2
  import datetime
3
  import logging
4
+ import hashlib
5
 
6
  import streamlit as st
7
  from streamlit.delta_generator import DeltaGenerator
8
+ from streamlit.runtime.uploaded_file_manager import UploadedFile
9
 
10
  import cv2
11
  import numpy as np
12
 
13
  from input.input_observation import InputObservation
14
+ from input.input_validator import get_image_datetime, is_valid_email, is_valid_number, get_image_latlon
15
 
16
  m_logger = logging.getLogger(__name__)
17
  m_logger.setLevel(logging.INFO)
 
26
  # an arbitrary set of defaults so testing is less painful...
27
  # ideally we add in some randomization to the defaults
28
  spoof_metadata = {
29
+ "latitude": 0.5,
30
  "longitude": 44,
31
  "author_email": "[email protected]",
32
  "date": None,
33
  "time": None,
34
  }
35
 
36
+ def check_inputs_are_set(empty_ok:bool=False, debug:bool=False) -> bool:
 
 
37
  """
38
+ Checks if all expected inputs have been entered
39
+
40
+ Implementation: via the Streamlit session state.
41
+
42
+ Args:
43
+ empty_ok (bool): If True, returns True if no inputs are set. Default is False.
44
+ debug (bool): If True, prints and logs the status of each expected input key. Default is False.
45
+ Returns:
46
+ bool: True if all expected input keys are set, False otherwise.
47
+ """
48
+ image_hashes = st.session_state.image_hashes
49
+ if len(image_hashes) == 0:
50
+ return empty_ok
51
+
52
+ exp_input_key_stubs = ["input_latitude", "input_longitude", "input_date", "input_time"]
53
+ #exp_input_key_stubs = ["input_latitude", "input_longitude", "input_author_email", "input_date", "input_time",
54
+
55
+ vals = []
56
+ # the author_email is global/one-off - no hash extension.
57
+ if "input_author_email" in st.session_state:
58
+ val = st.session_state["input_author_email"]
59
+ vals.append(val)
60
+ if debug:
61
+ msg = f"{'input_author_email':15}, {(val is not None):8}, {val}"
62
+ m_logger.debug(msg)
63
+ print(msg)
64
+
65
+
66
+ for image_hash in image_hashes:
67
+ for stub in exp_input_key_stubs:
68
+ key = f"{stub}_{image_hash}"
69
+ val = None
70
+ if key in st.session_state:
71
+ val = st.session_state[key]
72
+
73
+ # handle cases where it is defined but empty
74
+ # if val is a string and empty, set to None
75
+ if isinstance(val, str) and not val:
76
+ val = None
77
+ # if val is a list and empty, set to None (not sure what UI elements would return a list?)
78
+ if isinstance(val, list) and not val:
79
+ val = None
80
+ # number 0 is ok - possibly. could be on the equator, e.g.
81
+
82
+ vals.append(val)
83
+ if debug:
84
+ msg = f"{key:15}, {(val is not None):8}, {val}"
85
+ m_logger.debug(msg)
86
+ print(msg)
87
+
88
+
89
+
90
+ return all([v is not None for v in vals])
91
+
92
+
93
+ def buffer_uploaded_files():
94
+ """
95
+ Buffers uploaded files to session_state (images, image_hashes, filenames).
96
 
97
+ Buffers uploaded files by extracting and storing filenames, images, and
98
+ image hashes in the session state.
99
 
100
+ Adds the following keys to `st.session_state`:
101
+ - `images`: dict mapping image hashes to image data (numpy arrays)
102
+ - `files`: list of uploaded files
103
+ - `image_hashes`: list of image hashes
104
+ - `image_filenames`: list of filenames
105
+ """
106
 
107
+
108
+ # buffer info from the file_uploader that doesn't require further user input
109
+ # - the image, the hash, the filename
110
+ # a separate function takes care of per-file user inputs for metadata
111
+ # - this is necessary because dynamically producing more widgets should be
112
+ # avoided inside callbacks (tl;dr: they disappear)
113
+
114
+ # - note that the UploadedFile objects have file_ids, which are unique to each file
115
+ # - these file_ids are not persistent between sessions, seem to just be random identifiers.
116
+
117
+
118
+ # get files from state
119
+ uploaded_files = st.session_state.file_uploader_data
120
+
121
+ filenames = []
122
+ images = {}
123
+ image_hashes = []
124
+
125
+ for ix, file in enumerate(uploaded_files):
126
+ filename:str = file.name
127
+ print(f"[D] processing {ix}th file {filename}. {file.file_id} {file.type} {file.size}")
128
+ # image to np and hash both require reading the file so do together
129
+ image, image_hash = load_file_and_hash(file)
130
+
131
+ filenames.append(filename)
132
+ image_hashes.append(image_hash)
133
+
134
+ images[image_hash] = image
135
+
136
+ st.session_state.images = images
137
+ st.session_state.files = uploaded_files
138
+ st.session_state.image_hashes = image_hashes
139
+ st.session_state.image_filenames = filenames
140
+
141
+
142
def load_file_and_hash(file:UploadedFile) -> Tuple[np.ndarray, str]:
    """
    Loads an image file and computes its MD5 hash.

    Since both operations require the full file contents, they are done
    together for efficiency.

    Args:
        file (UploadedFile): The uploaded file to be processed.
    Returns:
        Tuple[np.ndarray, str]: A tuple containing the decoded image as a NumPy
            array and the MD5 hash (hex digest) of the file's contents.
    """
    # getvalue() returns the complete buffer regardless of the current stream
    # position and does not move it. A plain read() would yield nothing (and
    # therefore the hash of an empty payload) if the file had already been
    # consumed by an earlier reader.
    # NOTE(review): relies on UploadedFile being an io.BytesIO subclass - confirm
    # against the Streamlit version in use.
    _bytes = file.getvalue()
    image_hash = hashlib.md5(_bytes).hexdigest()
    # decode as a 3-channel colour image (IMREAD_COLOR is the named form of flag 1)
    image: np.ndarray = cv2.imdecode(np.frombuffer(_bytes, dtype=np.uint8), cv2.IMREAD_COLOR)

    return (image, image_hash)
161
 
162
+
163
+
164
def metadata_inputs_one_file(file:UploadedFile, image_hash:str, dbg_ix:int=0) -> InputObservation:
    """
    Creates and parses metadata inputs for a single file

    Adds latitude, longitude, date and time widgets (inside an expander) for
    the given file, pre-filled from the image metadata where available, and
    bundles the current widget values into an InputObservation.

    Args:
        file (UploadedFile): The uploaded file for which metadata is being handled.
        image_hash (str): The hash of the image; used to look up the buffered
            image and to build unique widget keys.
        dbg_ix (int, optional): Debug index to differentiate data in each input group. Defaults to 0.
    Returns:
        InputObservation: An object containing the metadata and other information for the input file.
    """
    # dbg_ix is a hack to have different data in each input group, checking persistence

    if st.session_state.container_metadata_inputs is not None:
        _viewcontainer = st.session_state.container_metadata_inputs
    else:
        _viewcontainer = st.sidebar
        m_logger.warning("[W] `container_metadata_inputs` is None, using sidebar")

    author_email = st.session_state["input_author_email"]
    filename = file.name
    image_datetime_raw = get_image_datetime(file)
    latitude0, longitude0 = get_image_latlon(file)
    msg = f"[D] {filename}: lat, lon from image metadata: {latitude0}, {longitude0}"
    m_logger.debug(msg)

    if latitude0 is None:  # get some default values if not found in exifdata
        latitude0 = spoof_metadata.get('latitude', 0) + dbg_ix
    if longitude0 is None:
        longitude0 = spoof_metadata.get('longitude', 0) - dbg_ix

    image = st.session_state.images.get(image_hash, None)
    # add the UI elements
    viewcontainer = _viewcontainer.expander(f"Metadata for {file.name}", expanded=True)

    # TODO: use session state so any changes are persisted within session -- currently I think
    #   we are going to take the defaults over and over again -- if the user adjusts coords, or date, it will get lost
    #   - it is a bit complicated, if no values change, they persist (the widget definition: params, name, key, etc)
    #     even if the code is re-run. but if the value changes, it is lost.

    # 3. Latitude Entry Box
    latitude = viewcontainer.text_input(
        "Latitude for " + filename,
        latitude0,
        key=f"input_latitude_{image_hash}")
    if latitude and not is_valid_number(latitude):
        viewcontainer.error("Please enter a valid latitude (numerical only).")
        m_logger.error(f"Invalid latitude entered: {latitude}.")
    # 4. Longitude Entry Box
    longitude = viewcontainer.text_input(
        "Longitude for " + filename,
        longitude0,
        key=f"input_longitude_{image_hash}")
    if longitude and not is_valid_number(longitude):
        viewcontainer.error("Please enter a valid longitude (numerical only).")
        # report the offending longitude (previously logged the latitude by mistake)
        m_logger.error(f"Invalid longitude entered: {longitude}.")

    # 5. Date/time
    ## first from image metadata; parse once and derive both date and time from it
    default_datetime = None
    if image_datetime_raw is not None:
        try:
            default_datetime = datetime.datetime.strptime(image_datetime_raw, '%Y:%m:%d %H:%M:%S')
        except ValueError:
            # a malformed timestamp in the image should not abort the whole input form
            m_logger.warning(
                f"[W] {filename}: could not parse image datetime '{image_datetime_raw}', using current time")
    if default_datetime is None:
        default_datetime = datetime.datetime.now()  # Default to current date and time
    date_value = default_datetime.date()
    time_value = default_datetime.time()

    ## either way, give user the option to enter manually (or correct, e.g. if camera has no rtc clock)
    obs_date = viewcontainer.date_input("Date for "+filename, value=date_value, key=f"input_date_{image_hash}")
    obs_time = viewcontainer.time_input("Time for "+filename, time_value, key=f"input_time_{image_hash}")

    observation = InputObservation(image=image, latitude=latitude, longitude=longitude,
                                   author_email=author_email, image_datetime_raw=image_datetime_raw,
                                   date=obs_date, time=obs_time,
                                   uploaded_file=file, image_md5=image_hash
                                   )

    return observation
245
+
246
 
247
+
248
def _setup_dynamic_inputs() -> None:
    """
    Build the per-file metadata inputs and refresh the stored observations.

    Works entirely from the data buffered in the session state (`files`,
    `image_hashes`, `observations`) and writes the resulting mapping of
    image hash -> observation back to `st.session_state.observations`.

    """
    # For each buffered file:
    #  - render its metadata widgets and build an observation from them
    #  - keep the previously stored observation when it compares equal,
    #    otherwise store the newly built one
    # Hashes that are no longer among the uploaded files are dropped, because
    # the mapping is rebuilt from scratch.

    buffered_files = st.session_state.files
    buffered_hashes = st.session_state.image_hashes
    refreshed = {}

    for ix, file in enumerate(buffered_files):
        image_hash = buffered_hashes[ix]
        candidate = metadata_inputs_one_file(file, image_hash, ix)
        previous = st.session_state.observations.get(image_hash, None)

        if previous is None:
            m_logger.debug(f"[D] {ix}th observation is new (image_hash not seen before). Storing")
            refreshed[image_hash] = candidate
            continue

        if previous == candidate:
            m_logger.debug(f"[D] {ix}th observation is the same as before. retaining")
            refreshed[image_hash] = previous
            continue

        m_logger.debug(f"[D] {ix}th observation is different from before. updating")
        refreshed[image_hash] = candidate
        candidate.show_diff(previous)

    st.session_state.observations = refreshed
286
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
 
288
def _setup_oneoff_inputs() -> None:
    '''
    Render the input widgets that exist once, regardless of how many files
    are uploaded:

    - author email (validated; an error is shown for an invalid address)
    - file uploader (accepts multiple files; buffers them on change)
    '''

    # the container is created ahead of time and kept in the session state
    uploader_area = st.session_state.container_file_uploader

    with uploader_area:
        # 1. Author email, pre-filled from the spoof metadata when present
        email_default = spoof_metadata.get('author_email', "")
        entered_email = st.text_input("Author Email", email_default, key="input_author_email")
        if entered_email:
            if not is_valid_email(entered_email):
                st.error("Please enter a valid email address.")

        # 2. Image selector; the callback copies the uploads into session state
        st.file_uploader(
            "Upload one or more images",
            type=["png", 'jpg', 'jpeg', 'webp'],
            accept_multiple_files=True,
            key="file_uploader_data",
            on_change=buffer_uploaded_files,
        )
311
+
312
+
313
+
314
+
315
 
 
 
 
 
316
 
317
+
318
def setup_input() -> None:
    '''
    Entry point for the user input handling (files and metadata).

    Two stages are run in order:
    1. the one-off inputs: author email and the image uploader (whose
       callback buffers the uploaded files);
    2. the dynamic inputs: for every buffered image, fields for lat/lon and
       date-time. In the ideal case the image metadata populates location
       and datetime.

    All results are kept in the Streamlit session state for downstream
    processing; nothing is returned.

    '''
    stages = (
        _setup_oneoff_inputs,   # author email + file_uploader
        _setup_dynamic_inputs,  # per-file widgets, driven by session_state
    )
    for run_stage in stages:
        run_stage()
336
+
337
+
338
def init_input_container_states() -> None:
    '''
    Make sure the session state has a slot for each layout container used
    by the input handling. Missing slots are created holding None; existing
    values are left untouched.
    '''
    container_slots = (
        "container_file_uploader",
        "container_metadata_inputs",
    )
    for slot in container_slots:
        if slot not in st.session_state:
            st.session_state[slot] = None
350
+
351
def init_input_data_session_states() -> None:
    '''
    Make sure every session state variable used by the input handling
    exists, creating an empty container for each one that is missing.
    Variables that are already present are not modified.
    '''
    # (session state key, factory producing its empty initial value)
    initial_values = (
        ("image_hashes", list),
        # TODO: ideally just use image_hashes, but need a unique key for the ui elements
        # to track the user input phase; and these are created before the hash is generated.
        ("image_filenames", list),
        ("observations", dict),
        ("images", dict),
        # NOTE(review): initialised as a dict, whereas buffer_uploaded_files
        # stores the uploader's list of files under this key - confirm intent.
        ("files", dict),
        ("public_observations", dict),
    )

    for state_key, make_empty in initial_values:
        if state_key not in st.session_state:
            st.session_state[state_key] = make_empty()
375
+
376
+
377
+
378
def add_input_UI_elements() -> None:
    '''
    Lay out the section that holds the user inputs and create its containers.

    Two bordered containers are created and stored in the session state
    (`container_file_uploader`, `container_metadata_inputs`) so that the
    widgets, which are not created in the same order, still appear in a
    consistent order on the page.
    '''
    uploader_css = '<style>.st-key-container_file_uploader_id { border: 1px solid skyblue; border-radius: 5px; }</style>'
    metadata_css = '<style>.st-key-container_metadata_inputs_id { border: 1px solid lightgreen; border-radius: 5px; }</style>'

    # section heading
    st.divider()
    st.title("Input image and data")

    # styled container for the file uploader / other one-off inputs
    st.markdown(uploader_css, unsafe_allow_html=True)
    st.session_state.container_file_uploader = st.container(
        border=True, key="container_file_uploader_id")

    # styled container for the dynamic (per-file) metadata inputs,
    # with a placeholder message shown until files arrive
    st.markdown(metadata_css, unsafe_allow_html=True)
    metadata_box = st.container(border=True, key="container_metadata_inputs_id")
    metadata_box.write("Metadata Inputs... wait for file upload ")
    st.session_state.container_metadata_inputs = metadata_box
398
+
399
+
400
def dbg_show_observation_hashes() -> None:
    """
    Render a summary of the stored observations, one line per image hash.

    - debug usage: keeps track of the hashes and of the persistence of the
      InputObservation instances (instance id and number of predictions).
    - it writes text into the current container; not intended for the final app.

    """

    # a debug: we seem to be losing the whale classes?
    stored = st.session_state.observations
    st.write(f"[D] num observations: {len(stored)}")

    lines = [
        f"- [D] observation {image_hash} ({obs._inst_id}) has {len(obs.top_predictions)} predictions\n"
        for image_hash, obs in stored.items()
    ]
    st.markdown("".join(lines))
src/input/input_observation.py CHANGED
@@ -1,13 +1,18 @@
1
  import hashlib
2
  from input.input_validator import generate_random_md5
3
 
 
 
 
 
 
4
  # autogenerated class to hold the input data
5
  class InputObservation:
6
  """
7
  A class to hold an input observation and associated metadata
8
 
9
  Attributes:
10
- image (Any):
11
  The image associated with the observation.
12
  latitude (float):
13
  The latitude where the observation was made.
@@ -15,16 +20,16 @@ class InputObservation:
15
  The longitude where the observation was made.
16
  author_email (str):
17
  The email of the author of the observation.
18
- date (str):
19
- The date when the observation was made.
20
- time (str):
21
- The time when the observation was made.
22
- date_option (str):
23
- Additional date option for the observation.
24
- time_option (str):
25
- Additional time option for the observation.
26
- uploaded_filename (Any):
27
- The uploaded filename associated with the observation.
28
 
29
  Methods:
30
  __str__():
@@ -35,8 +40,8 @@ class InputObservation:
35
  Checks if two observations are equal.
36
  __ne__(other):
37
  Checks if two observations are not equal.
38
- __hash__():
39
- Returns the hash of the observation.
40
  to_dict():
41
  Converts the observation to a dictionary.
42
  from_dict(data):
@@ -44,66 +49,208 @@ class InputObservation:
44
  from_input(input):
45
  Creates an observation from another input observation.
46
  """
47
- def __init__(self, image=None, latitude=None, longitude=None,
48
- author_email=None, date=None, time=None, date_option=None, time_option=None,
49
- uploaded_filename=None):
 
 
 
 
 
 
 
50
  self.image = image
51
  self.latitude = latitude
52
  self.longitude = longitude
53
  self.author_email = author_email
 
54
  self.date = date
55
  self.time = time
56
- self.date_option = date_option
57
- self.time_option = time_option
58
- self.uploaded_filename = uploaded_filename
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  def __str__(self):
61
- return f"Observation: {self.image}, {self.latitude}, {self.longitude}, {self.author_email}, {self.date}, {self.time}, {self.date_option}, {self.time_option}, {self.uploaded_filename}"
 
 
 
 
 
62
 
63
  def __repr__(self):
64
- return f"Observation: {self.image}, {self.latitude}, {self.longitude}, {self.author_email}, {self.date}, {self.time}, {self.date_option}, {self.time_option}, {self.uploaded_filename}"
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  def __eq__(self, other):
67
- return (self.image == other.image and self.latitude == other.latitude and self.longitude == other.longitude and
68
- self.author_email == other.author_email and self.date == other.date and self.time == other.time and
69
- self.date_option == other.date_option and self.time_option == other.time_option and self.uploaded_filename == other.uploaded_filename)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
  def __ne__(self, other):
72
  return not self.__eq__(other)
73
 
74
- def __hash__(self):
75
- return hash((self.image, self.latitude, self.longitude, self.author_email, self.date, self.time, self.date_option, self.time_option, self.uploaded_filename))
76
-
77
  def to_dict(self):
78
  return {
79
  #"image": self.image,
80
- "image_filename": self.uploaded_filename.name if self.uploaded_filename else None,
81
- "image_md5": hashlib.md5(self.uploaded_filename.read()).hexdigest() if self.uploaded_filename else generate_random_md5(),
 
82
  "latitude": self.latitude,
83
  "longitude": self.longitude,
84
  "author_email": self.author_email,
85
- "date": self.date,
86
- "time": self.time,
87
- "date_option": str(self.date_option),
88
- "time_option": str(self.time_option),
89
- "uploaded_filename": self.uploaded_filename
 
 
 
90
  }
91
 
92
  @classmethod
93
  def from_dict(cls, data):
94
- return cls(data["image"], data["latitude"], data["longitude"], data["author_email"], data["date"], data["time"], data["date_option"], data["time_option"], data["uploaded_filename"])
 
 
 
 
 
 
 
 
 
 
95
 
96
  @classmethod
97
  def from_input(cls, input):
98
- return cls(input.image, input.latitude, input.longitude, input.author_email, input.date, input.time, input.date_option, input.time_option, input.uploaded_filename)
99
-
100
- @staticmethod
101
- def from_input(input):
102
- return InputObservation(input.image, input.latitude, input.longitude, input.author_email, input.date, input.time, input.date_option, input.time_option, input.uploaded_filename)
 
 
 
 
 
 
103
 
104
- @staticmethod
105
- def from_dict(data):
106
- return InputObservation(data["image"], data["latitude"], data["longitude"], data["author_email"], data["date"], data["time"], data["date_option"], data["time_option"], data["uploaded_filename"])
107
 
108
 
109
 
 
1
  import hashlib
2
  from input.input_validator import generate_random_md5
3
 
4
+ from numpy import ndarray
5
+ from streamlit.runtime.uploaded_file_manager import UploadedFile
6
+ import datetime
7
+
8
+
9
  # autogenerated class to hold the input data
10
class InputObservation:
    """
    A class to hold an input observation and associated metadata

    Attributes:
        image (ndarray):
            The image associated with the observation.
        latitude (float):
            The latitude where the observation was made.
        longitude (float):
            The longitude where the observation was made.
        author_email (str):
            The email of the author of the observation.
        image_datetime_raw (str):
            The datetime extracted from the observation file
        date (datetime.date):
            Date of the observation
        time (datetime.time):
            Time of the observation
        uploaded_file (UploadedFile):
            The uploaded file associated with the observation.
        image_md5 (str):
            The MD5 hash of the image associated with the observation.

    Methods:
        __str__():
            Returns a short string representation of the observation.
        __repr__():
            Returns a labelled string representation of the observation.
        __eq__(other):
            Checks if two observations are equal.
        __ne__(other):
            Checks if two observations are not equal.
        show_diff(other):
            Shows the differences between two observations.
        to_dict():
            Converts the observation to a dictionary.
        from_dict(data):
            Creates an observation from a dictionary.
        from_input(input):
            Creates an observation from another input observation.
    """

    # number of instances created so far; used to give each one a debug id
    _inst_count = 0

    def __init__(
            self, image:ndarray=None, latitude:float=None, longitude:float=None,
            author_email:str=None, image_datetime_raw:str=None,
            date:datetime.date=None,
            time:datetime.time=None,
            uploaded_file:"UploadedFile"=None, image_md5:str=None):
        """
        Store the observation data.

        Raises:
            ValueError: if `image_md5` is None (the hash is required).
        """
        self.image = image
        self.latitude = latitude
        self.longitude = longitude
        self.author_email = author_email
        self.image_datetime_raw = image_datetime_raw
        self.date = date
        self.time = time
        self.uploaded_file = uploaded_file
        self.image_md5 = image_md5
        # attributes that get set after predictions/processing
        self._top_predictions = []
        self._selected_class = None
        self._class_overriden = False

        InputObservation._inst_count += 1
        self._inst_id = InputObservation._inst_count

        #dbg - temporarily give up if hash is not provided
        if self.image_md5 is None:
            raise ValueError(f"Image MD5 hash is required - {self._inst_id:3}.")

    def set_top_predictions(self, top_predictions:list):
        """Store the predictions; the first one (if any) becomes the selected class."""
        self._top_predictions = top_predictions
        if len(top_predictions) > 0:
            self.set_selected_class(top_predictions[0])

    def set_selected_class(self, selected_class:str):
        """
        Store the selected class and flag an override when it differs from
        the top prediction.

        When no predictions are stored there is nothing to compare against,
        so the override flag is left unchanged (previously this raised an
        IndexError).
        """
        self._selected_class = selected_class
        if len(self._top_predictions) > 0 and selected_class != self._top_predictions[0]:
            self.set_class_overriden(True)

    def set_class_overriden(self, class_overriden:bool):
        """Set the flag recording that the selected class overrides the prediction."""
        self._class_overriden = class_overriden

    # read-only accessors for the top_predictions, selected_class and class_overriden
    @property
    def top_predictions(self):
        return self._top_predictions

    @property
    def selected_class(self):
        return self._selected_class

    @property
    def class_overriden(self):
        return self._class_overriden

    def assign_image_md5(self):
        """
        Deprecated: the hash is a required constructor argument.

        Raises:
            DeprecationWarning: always.
        """
        raise DeprecationWarning("This method is deprecated. hash is a required constructor argument.")

    @staticmethod
    def _images_equal(first, second) -> bool:
        """Compare two images (ndarray or None) without ever raising."""
        if first is None or second is None:
            # equal only when both are missing; avoids an elementwise `array == None`
            return first is None and second is None
        if first.shape != second.shape:
            # elementwise comparison is not defined for mismatched shapes
            return False
        return bool((first == second).all())

    def __str__(self):
        _im_str = "None" if self.image is None else f"image dims: {self.image.shape}"
        return (
            f"Observation: {_im_str}, {self.latitude}, {self.longitude}, "
            f"{self.author_email}, {self.image_datetime_raw}, {self.date}, "
            f"{self.time}, {self.uploaded_file}, {self.image_md5}"
        )

    def __repr__(self):
        _im_str = "None" if self.image is None else f"image dims: {self.image.shape}"
        return (
            f"Observation: "
            f"Image: {_im_str}, "
            f"Latitude: {self.latitude}, "
            f"Longitude: {self.longitude}, "
            f"Author Email: {self.author_email}, "
            f"raw timestamp: {self.image_datetime_raw}, "
            f"Date: {self.date}, "
            f"Time: {self.time}, "
            f"Uploaded Filename: {self.uploaded_file}, "
            f"Image MD5 hash: {self.image_md5}"
        )

    def __eq__(self, other):
        """
        Compare the input data of two observations.

        The instance id and the prediction-related attributes are not
        compared, and `time` is skipped for now (see below).
        """
        # TODO: ensure this covers all the attributes (some have been added?)
        #  - except inst_id which is unique
        if not isinstance(other, InputObservation):
            return NotImplemented
        return (
            self._images_equal(self.image, other.image) and
            self.latitude == other.latitude and
            self.longitude == other.longitude and
            self.author_email == other.author_email and
            self.image_datetime_raw == other.image_datetime_raw and
            self.date == other.date and
            # temporarily skip time, it is followed by the clock and that is always different
            #self.time == other.time and
            self.uploaded_file == other.uploaded_file and
            self.image_md5 == other.image_md5
        )

    def show_diff(self, other):
        """
        Print the differences between two observations.

        Only attributes that differ are listed, below a summary line stating
        how many differences were found (or that there are none).
        """
        differences = []
        if self.image is None or other.image is None:
            if not self._images_equal(self.image, other.image):
                differences.append(f" Image is different. (types mismatch: {type(self.image)} vs {type(other.image)})")
        elif self.image.shape != other.image.shape:
            differences.append(f" Image is different. (shapes mismatch: {self.image.shape} vs {other.image.shape})")
        else:
            cnt = (self.image != other.image).sum()
            if cnt:
                differences.append(f" Image is different: {cnt} different pixels.")
        if self.latitude != other.latitude:
            differences.append(f" Latitude is different. (self: {self.latitude}, other: {other.latitude})")
        if self.longitude != other.longitude:
            differences.append(f" Longitude is different. (self: {self.longitude}, other: {other.longitude})")
        if self.author_email != other.author_email:
            differences.append(f" Author email is different. (self: {self.author_email}, other: {other.author_email})")
        if self.image_datetime_raw != other.image_datetime_raw:
            differences.append(f" Raw timestamp is different. (self: {self.image_datetime_raw}, other: {other.image_datetime_raw})")
        if self.date != other.date:
            differences.append(f" Date is different. (self: {self.date}, other: {other.date})")
        if self.time != other.time:
            differences.append(f" Time is different. (self: {self.time}, other: {other.time})")
        if self.uploaded_file != other.uploaded_file:
            differences.append(" Uploaded filename is different.")
        if self.image_md5 != other.image_md5:
            differences.append(" Image MD5 hash is different.")

        if differences:
            print(f"Observations have {len(differences)} differences:")
            for diff in differences:
                print(diff)
        else:
            print("Observations are the same.")

    def __ne__(self, other):
        result = self.__eq__(other)
        # propagate NotImplemented instead of negating it
        return result if result is NotImplemented else not result

    def to_dict(self):
        """
        Convert the observation to a dictionary of plain values.

        The image and the uploaded file object are not included; date and
        time are converted with str().
        """
        return {
            #"image": self.image,
            "image_filename": self.uploaded_file.name if self.uploaded_file else None,
            "image_md5": self.image_md5,
            "latitude": self.latitude,
            "longitude": self.longitude,
            "author_email": self.author_email,
            "image_datetime_raw": self.image_datetime_raw,
            "date": str(self.date),
            "time": str(self.time),
            "selected_class": self._selected_class,
            "top_prediction": self._top_predictions[0] if len(self._top_predictions) else None,
            "class_overriden": self._class_overriden,

            #"uploaded_file": self.uploaded_file # can't serialize this in json, not sent to dataset anyway.
        }

    @classmethod
    def from_dict(cls, data):
        """
        Create an observation from a dictionary.

        Missing keys become None, except that a missing "image_md5" makes
        the constructor raise ValueError.
        """
        return cls(
            image=data.get("image"),
            latitude=data.get("latitude"),
            longitude=data.get("longitude"),
            author_email=data.get("author_email"),
            image_datetime_raw=data.get("image_datetime_raw"),
            date=data.get("date"),
            time=data.get("time"),
            uploaded_file=data.get("uploaded_file"),
            # the constructor keyword is `image_md5` (was wrongly `image_hash`)
            image_md5=data.get("image_md5")
        )

    @classmethod
    def from_input(cls, input):
        """Create a new observation carrying the same input data as `input`."""
        return cls(
            image=input.image,
            latitude=input.latitude,
            longitude=input.longitude,
            author_email=input.author_email,
            image_datetime_raw=input.image_datetime_raw,
            date=input.date,
            time=input.time,
            uploaded_file=input.uploaded_file,
            # attribute and keyword are both `image_md5` (was wrongly `image_hash`)
            image_md5=input.image_md5
        )
253
 
 
 
 
254
 
255
 
256
 
src/input/input_validator.py CHANGED
@@ -1,22 +1,33 @@
 
1
  import random
2
  import string
3
  import hashlib
4
  import re
5
- import streamlit as st
6
  from fractions import Fraction
7
-
8
  from PIL import Image
9
  from PIL import ExifTags
10
 
 
11
  from streamlit.runtime.uploaded_file_manager import UploadedFile
12
 
13
- def generate_random_md5():
 
 
 
 
 
 
 
 
 
 
14
  # Generate a random string
15
- random_string = ''.join(random.choices(string.ascii_letters + string.digits, k=16))
16
  # Encode the string and compute its MD5 hash
17
  md5_hash = hashlib.md5(random_string.encode()).hexdigest()
18
  return md5_hash
19
 
 
20
  def is_valid_number(number:str) -> bool:
21
  """
22
  Check if the given string is a valid number (int or float, sign ok)
@@ -30,6 +41,7 @@ def is_valid_number(number:str) -> bool:
30
  pattern = r'^[-+]?[0-9]*\.?[0-9]+$'
31
  return re.match(pattern, number) is not None
32
 
 
33
  # Function to validate email address
34
  def is_valid_email(email:str) -> bool:
35
  """
@@ -41,11 +53,14 @@ def is_valid_email(email:str) -> bool:
41
  Returns:
42
  bool: True if the email address is valid, False otherwise.
43
  """
44
- pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
 
 
45
  return re.match(pattern, email) is not None
46
 
 
47
  # Function to extract date and time from image metadata
48
- def get_image_datetime(image_file):
49
  """
50
  Extracts the original date and time from the EXIF metadata of an uploaded image file.
51
 
@@ -69,6 +84,7 @@ def get_image_datetime(image_file):
69
  # TODO: add to logger
70
  return None
71
 
 
72
  def decimal_coords(coords:tuple, ref:str) -> Fraction:
73
  """
74
  Converts coordinates from degrees, minutes, and seconds to decimal degrees.
@@ -96,8 +112,9 @@ def decimal_coords(coords:tuple, ref:str) -> Fraction:
96
  return decimal_degrees
97
 
98
 
99
- #def get_image_latlon(image_file: UploadedFile) -> tuple[float, float] | None:
100
- def get_image_latlon(image_file: UploadedFile) :
 
101
  """
102
  Extracts the latitude and longitude from the EXIF metadata of an uploaded image file.
103
 
@@ -123,4 +140,6 @@ def get_image_latlon(image_file: UploadedFile) :
123
  return lat, lon
124
 
125
  except Exception as e: # FIXME: what types of exception?
126
- st.warning(f"Could not extract latitude and longitude from image metadata. (file: {str(image_file)}")
 
 
 
1
+ from typing import Tuple, Union
2
  import random
3
  import string
4
  import hashlib
5
  import re
 
6
  from fractions import Fraction
 
7
  from PIL import Image
8
  from PIL import ExifTags
9
 
10
+ import streamlit as st
11
  from streamlit.runtime.uploaded_file_manager import UploadedFile
12
 
13
def generate_random_md5(length:int=16) -> str:
    """
    Generate a random MD5 hash.

    Args:
        length (int): The length of the random string to generate. Default is 16.

    Returns:
        str: The MD5 hash (32 hex characters) of the generated random string.
    """

    # Generate a random alphanumeric string of the requested length.
    # random.choices takes the sample size as `k`; honour `length` here.
    random_string = ''.join(random.choices(string.ascii_letters + string.digits, k=length))
    # Encode the string and compute its MD5 hash
    md5_hash = hashlib.md5(random_string.encode()).hexdigest()
    return md5_hash
29
 
30
+
31
  def is_valid_number(number:str) -> bool:
32
  """
33
  Check if the given string is a valid number (int or float, sign ok)
 
41
  pattern = r'^[-+]?[0-9]*\.?[0-9]+$'
42
  return re.match(pattern, number) is not None
43
 
44
+
45
  # Function to validate email address
46
  def is_valid_email(email:str) -> bool:
47
  """
 
53
  Returns:
54
  bool: True if the email address is valid, False otherwise.
55
  """
56
+ #pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
57
+ # do not allow starting with a +
58
+ pattern = r'^[a-zA-Z0-9_]+[a-zA-Z0-9._%+-]*@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
59
  return re.match(pattern, email) is not None
60
 
61
+
62
  # Function to extract date and time from image metadata
63
+ def get_image_datetime(image_file:UploadedFile) -> Union[str, None]:
64
  """
65
  Extracts the original date and time from the EXIF metadata of an uploaded image file.
66
 
 
84
  # TODO: add to logger
85
  return None
86
 
87
+
88
  def decimal_coords(coords:tuple, ref:str) -> Fraction:
89
  """
90
  Converts coordinates from degrees, minutes, and seconds to decimal degrees.
 
112
  return decimal_degrees
113
 
114
 
115
+ #def get_image_latlon(image_file: UploadedFile) : # if it is still not working
116
+ #def get_image_latlon(image_file: UploadedFile) -> Tuple[float, float] | None: # Python >=3.10
117
+ def get_image_latlon(image_file: UploadedFile) -> Union[Tuple[float, float], None]: # 3.6 <= Python < 3.10
118
  """
119
  Extracts the latitude and longitude from the EXIF metadata of an uploaded image file.
120
 
 
140
  return lat, lon
141
 
142
  except Exception as e: # FIXME: what types of exception?
143
+ st.warning(f"Could not extract latitude and longitude from image metadata. (file: {str(image_file)}")
144
+
145
+ return None, None
src/main.py CHANGED
@@ -9,17 +9,24 @@ from streamlit_folium import st_folium
9
  from transformers import pipeline
10
  from transformers import AutoModelForImageClassification
11
 
12
- from maps.obs_map import add_header_text
 
13
  from datasets import disable_caching
14
  disable_caching()
15
 
16
  import whale_gallery as gallery
17
  import whale_viewer as viewer
18
- from input.input_handling import setup_input
 
 
 
19
  from maps.alps_map import present_alps_map
20
  from maps.obs_map import present_obs_map
21
- from utils.st_logs import setup_logging, parse_log_buffer
22
- from classifier.classifier_image import cetacean_classify
 
 
 
23
  from classifier.classifier_hotdog import hotdog_classify
24
 
25
 
@@ -34,6 +41,11 @@ data_files = "data/train-00000-of-00001.parquet"
34
  USE_BASIC_MAP = False
35
  DEV_SIDEBAR_LIB = True
36
 
 
 
 
 
 
37
  # get a global var for logger accessor in this module
38
  LOG_LEVEL = logging.DEBUG
39
  g_logger = logging.getLogger(__name__)
@@ -42,33 +54,13 @@ g_logger.setLevel(LOG_LEVEL)
42
  st.set_page_config(layout="wide")
43
 
44
  # initialise various session state variables
45
- if "handler" not in st.session_state:
46
- st.session_state['handler'] = setup_logging()
47
-
48
- if "image_hashes" not in st.session_state:
49
- st.session_state.image_hashes = []
50
-
51
- if "observations" not in st.session_state:
52
- st.session_state.observations = {}
53
-
54
- if "images" not in st.session_state:
55
- st.session_state.images = {}
56
-
57
- if "files" not in st.session_state:
58
- st.session_state.files = {}
59
 
60
- if "public_observation" not in st.session_state:
61
- st.session_state.public_observation = {}
62
-
63
- if "classify_whale_done" not in st.session_state:
64
- st.session_state.classify_whale_done = False
65
-
66
- if "whale_prediction1" not in st.session_state:
67
- st.session_state.whale_prediction1 = None
68
-
69
- if "tab_log" not in st.session_state:
70
- st.session_state.tab_log = None
71
-
72
 
73
  def main() -> None:
74
  """
@@ -100,29 +92,22 @@ def main() -> None:
100
  # Streamlit app
101
  tab_inference, tab_hotdogs, tab_map, tab_coords, tab_log, tab_gallery = \
102
  st.tabs(["Cetecean classifier", "Hotdog classifier", "Map", "*:gray[Dev:coordinates]*", "Log", "Beautiful cetaceans"])
103
- st.session_state.tab_log = tab_log
104
 
 
 
105
 
106
  # create a sidebar, and parse all the input (returned as `observations` object)
107
- setup_input(viewcontainer=st.sidebar)
 
 
 
 
108
 
109
 
110
- if 0:## WIP
111
- # goal of this code is to allow the user to override the ML prediction, before transmitting an observations
112
- predicted_class = st.sidebar.selectbox("Predicted Class", viewer.WHALE_CLASSES)
113
- override_prediction = st.sidebar.checkbox("Override Prediction")
114
-
115
- if override_prediction:
116
- overridden_class = st.sidebar.selectbox("Override Class", viewer.WHALE_CLASSES)
117
- st.session_state.observations['class_overriden'] = overridden_class
118
- else:
119
- st.session_state.observations['class_overriden'] = None
120
-
121
-
122
  with tab_map:
123
  # visual structure: a couple of toggles at the top, then the map inlcuding a
124
  # dropdown for tileset selection.
125
- add_header_text()
126
  tab_map_ui_cols = st.columns(2)
127
  with tab_map_ui_cols[0]:
128
  show_db_points = st.toggle("Show Points from DB", True)
@@ -180,43 +165,128 @@ def main() -> None:
180
  gallery.render_whale_gallery(n_cols=4)
181
 
182
 
183
- # Display submitted observation
184
- if st.sidebar.button("Validate"):
185
- # create a dictionary with the submitted observation
186
- tab_log.info(f"{st.session_state.observations}")
187
- df = pd.DataFrame(st.session_state.observations, index=[0])
188
- with tab_coords:
189
- st.table(df)
190
-
191
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
 
193
 
194
- # inside the inference tab, on button press we call the model (on huggingface hub)
195
- # which will be run locally.
196
- # - the model predicts the top 3 most likely species from the input image
197
- # - these species are shown
198
- # - the user can override the species prediction using the dropdown
199
- # - an observation is uploaded if the user chooses.
200
- tab_inference.markdown("""
201
- *Run classifer to identify the species of cetean on the uploaded image.
202
- Once inference is complete, the top three predictions are shown.
203
- You can override the prediction by selecting a species from the dropdown.*""")
204
 
205
- if tab_inference.button("Identify with cetacean classifier"):
206
- #pipe = pipeline("image-classification", model="Saving-Willy/cetacean-classifier", trust_remote_code=True)
207
- cetacean_classifier = AutoModelForImageClassification.from_pretrained("Saving-Willy/cetacean-classifier",
208
- revision=classifier_revision,
209
- trust_remote_code=True)
210
-
 
 
 
 
 
 
 
 
 
 
211
 
212
- if st.session_state.images is None:
213
- # TODO: cleaner design to disable the button until data input done?
214
- st.info("Please upload an image first.")
215
- else:
216
- cetacean_classify(cetacean_classifier)
217
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
 
 
 
 
 
 
 
 
 
 
 
219
 
 
 
 
220
 
221
  # inside the hotdog tab, on button press we call a 2nd model (totally unrelated at present, just for demo
222
  # purposes, an hotdog image classifier) which will be run locally.
@@ -240,6 +310,9 @@ def main() -> None:
240
  hotdog_classify(pipeline_hot_dog, tab_hotdogs)
241
 
242
 
 
 
 
243
 
244
  if __name__ == "__main__":
245
  main()
 
9
  from transformers import pipeline
10
  from transformers import AutoModelForImageClassification
11
 
12
+ from maps.obs_map import add_obs_map_header
13
+ from classifier.classifier_image import add_classifier_header
14
  from datasets import disable_caching
15
  disable_caching()
16
 
17
  import whale_gallery as gallery
18
  import whale_viewer as viewer
19
+ from input.input_handling import setup_input, check_inputs_are_set
20
+ from input.input_handling import init_input_container_states, add_input_UI_elements, init_input_data_session_states
21
+ from input.input_handling import dbg_show_observation_hashes
22
+
23
  from maps.alps_map import present_alps_map
24
  from maps.obs_map import present_obs_map
25
+ from utils.st_logs import parse_log_buffer, init_logging_session_states
26
+ from utils.workflow_ui import refresh_progress_display, init_workflow_viz, init_workflow_session_states
27
+ from hf_push_observations import push_all_observations
28
+
29
+ from classifier.classifier_image import cetacean_just_classify, cetacean_show_results_and_review, cetacean_show_results, init_classifier_session_states
30
  from classifier.classifier_hotdog import hotdog_classify
31
 
32
 
 
41
  USE_BASIC_MAP = False
42
  DEV_SIDEBAR_LIB = True
43
 
44
+ # one toggle for all the extra debug text
45
+ if "MODE_DEV_STATEFUL" not in st.session_state:
46
+ st.session_state.MODE_DEV_STATEFUL = False
47
+
48
+
49
  # get a global var for logger accessor in this module
50
  LOG_LEVEL = logging.DEBUG
51
  g_logger = logging.getLogger(__name__)
 
54
  st.set_page_config(layout="wide")
55
 
56
  # initialise various session state variables
57
+ init_logging_session_states() # logging init should be early
58
+ init_workflow_session_states()
59
+ init_input_data_session_states()
60
+ init_input_container_states()
61
+ init_workflow_viz()
62
+ init_classifier_session_states()
 
 
 
 
 
 
 
 
63
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
  def main() -> None:
66
  """
 
92
  # Streamlit app
93
  tab_inference, tab_hotdogs, tab_map, tab_coords, tab_log, tab_gallery = \
94
  st.tabs(["Cetecean classifier", "Hotdog classifier", "Map", "*:gray[Dev:coordinates]*", "Log", "Beautiful cetaceans"])
 
95
 
96
+ # put this early so the progress indicator is at the top (also refreshed at end)
97
+ refresh_progress_display()
98
 
99
  # create a sidebar, and parse all the input (returned as `observations` object)
100
+ with st.sidebar:
101
+ # layout handling
102
+ add_input_UI_elements()
103
+ # input elements (file upload, text input, etc)
104
+ setup_input()
105
 
106
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  with tab_map:
108
 # visual structure: a couple of toggles at the top, then the map including a
109
  # dropdown for tileset selection.
110
+ add_obs_map_header()
111
  tab_map_ui_cols = st.columns(2)
112
  with tab_map_ui_cols[0]:
113
  show_db_points = st.toggle("Show Points from DB", True)
 
165
  gallery.render_whale_gallery(n_cols=4)
166
 
167
 
168
+ # state handling re data_entry phases
169
+ # 0. no data entered yet -> display the file uploader thing
170
+ # 1. we have some images, but not all the metadata fields are done -> validate button shown, disabled
171
+ # 2. all data entered -> validate button enabled
172
+ # 3. validation button pressed, validation done -> enable the inference button.
173
+ # - at this point do we also want to disable changes to the metadata selectors?
174
+ # anyway, simple first.
175
+
176
+ if st.session_state.workflow_fsm.is_in_state('doing_data_entry'):
177
+ # can we advance state? - only when all inputs are set for all uploaded files
178
+ all_inputs_set = check_inputs_are_set(debug=True, empty_ok=False)
179
+ if all_inputs_set:
180
+ st.session_state.workflow_fsm.complete_current_state()
181
+ # -> data_entry_complete
182
+ else:
183
+ # button, disabled; no state change yet.
184
+ st.sidebar.button(":gray[*Validate*]", disabled=True, help="Please fill in all fields.")
185
+
186
+
187
+ if st.session_state.workflow_fsm.is_in_state('data_entry_complete'):
188
+ # can we advance state? - only when the validate button is pressed
189
+ if st.sidebar.button(":white_check_mark:[**Validate**]"):
190
+ # create a dictionary with the submitted observation
191
+ tab_log.info(f"{st.session_state.observations}")
192
+ df = pd.DataFrame([obs.to_dict() for obs in st.session_state.observations.values()])
193
+ #df = pd.DataFrame(st.session_state.observations, index=[0])
194
+ with tab_coords:
195
+ st.table(df)
196
+ # there doesn't seem to be any actual validation here?? TODO: find validator function (each element is validated by the input box, but is there something at the whole image level?)
197
+ # hmm, maybe it should actually just be "I'm done with data entry"
198
+ st.session_state.workflow_fsm.complete_current_state()
199
+ # -> data_entry_validated
200
+
201
+ # state handling re inference phases (tab_inference)
202
+ # 3. validation button pressed, validation done -> enable the inference button.
203
+ # 4. inference button pressed -> ML started. | let's cut this one out, since it would only
204
+ # make sense if we did it as an async action
205
+ # 5. ML done -> show results, and manual validation options
206
+ # 6. manual validation done -> enable the upload buttons
207
+ #
208
+ with tab_inference:
209
+ # inside the inference tab, on button press we call the model (on huggingface hub)
210
+ # which will be run locally.
211
+ # - the model predicts the top 3 most likely species from the input image
212
+ # - these species are shown
213
+ # - the user can override the species prediction using the dropdown
214
+ # - an observation is uploaded if the user chooses.
215
 
216
 
217
+ if st.session_state.MODE_DEV_STATEFUL:
218
+ dbg_show_observation_hashes()
219
+
220
+ add_classifier_header()
221
+ # if we are before data_entry_validated, show the button, disabled.
222
+ if not st.session_state.workflow_fsm.is_in_state_or_beyond('data_entry_validated'):
223
+ tab_inference.button(":gray[*Identify with cetacean classifier*]", disabled=True,
224
+ help="Please validate inputs before proceeding",
225
+ key="button_infer_ceteans")
 
226
 
227
+ if st.session_state.workflow_fsm.is_in_state('data_entry_validated'):
228
+ # show the button, enabled. If pressed, we start the ML model (And advance state)
229
+ if tab_inference.button("Identify with cetacean classifier"):
230
+ cetacean_classifier = AutoModelForImageClassification.from_pretrained(
231
+ "Saving-Willy/cetacean-classifier",
232
+ revision=classifier_revision,
233
+ trust_remote_code=True)
234
+
235
+ cetacean_just_classify(cetacean_classifier)
236
+ st.session_state.workflow_fsm.complete_current_state()
237
+ # trigger a refresh too (refreshing the prog indicator means the script reruns and
238
+ # we can enter the next state - visualising the results / review)
239
+ # ok it doesn't if done programmatically. maybe interacting with the button? check docs.
240
+ refresh_progress_display()
241
+ #TODO: validate this doesn't harm performance adversely.
242
+ st.rerun()
243
 
244
+ elif st.session_state.workflow_fsm.is_in_state('ml_classification_completed'):
245
+ # show the results, and allow manual validation
246
+ st.markdown("""### Inference results and manual validation/adjustment """)
247
+ if st.session_state.MODE_DEV_STATEFUL:
248
+ s = ""
249
+ for k, v in st.session_state.whale_prediction1.items():
250
+ s += f"* Image {k}: {v}\n"
251
+
252
+ st.markdown(s)
253
+
254
+ # add a button to advance the state
255
+ if st.button("Confirm species predictions", help="Confirm that all species are selected correctly"):
256
+ st.session_state.workflow_fsm.complete_current_state()
257
+ # -> manual_inspection_completed
258
+ st.rerun()
259
+
260
+ cetacean_show_results_and_review()
261
+
262
+ elif st.session_state.workflow_fsm.is_in_state('manual_inspection_completed'):
263
+ # show the ML results, and allow the user to upload the observation
264
+ st.markdown("""### Inference Results (after manual validation) """)
265
+
266
+
267
+ if st.button("Upload all observations to THE INTERNET!"):
268
+ # let this go through to the push_all func, since it just reports to log for now.
269
+ push_all_observations(enable_push=False)
270
+ st.session_state.workflow_fsm.complete_current_state()
271
+ # -> data_uploaded
272
+ st.rerun()
273
+
274
+ cetacean_show_results()
275
 
276
+ elif st.session_state.workflow_fsm.is_in_state('data_uploaded'):
277
+ # the data has been sent. Lets show the observations again
278
+ # but no buttons to upload (or greyed out ok)
279
+ st.markdown("""### Observation(s) uploaded - thank you!""")
280
+ cetacean_show_results()
281
+
282
+ st.divider()
283
+ #df = pd.DataFrame(st.session_state.observations, index=[0])
284
+ df = pd.DataFrame([obs.to_dict() for obs in st.session_state.observations.values()])
285
+ st.table(df)
286
 
287
+ # didn't decide what the next state is here - I think we are in the terminal state.
288
+ #st.session_state.workflow_fsm.complete_current_state()
289
+
290
 
291
  # inside the hotdog tab, on button press we call a 2nd model (totally unrelated at present, just for demo
292
  # purposes, an hotdog image classifier) which will be run locally.
 
310
  hotdog_classify(pipeline_hot_dog, tab_hotdogs)
311
 
312
 
313
+ # after all other processing, we can show the stage/state
314
+ refresh_progress_display()
315
+
316
 
317
  if __name__ == "__main__":
318
  main()
src/maps/obs_map.py CHANGED
@@ -192,8 +192,8 @@ def present_obs_map(dataset_id:str = "Saving-Willy/Happywhale-kaggle",
192
  return st_data
193
 
194
 
195
- def add_header_text() -> None:
196
  """
197
  Add brief explainer text to the tab
198
  """
199
- st.write("A map showing the observations in the dataset, with markers colored by species.")
 
192
  return st_data
193
 
194
 
195
def add_obs_map_header() -> None:
    """
    Write a brief explanatory sentence about the observations map.
    """
    header_text = "A map showing the observations in the dataset, with markers colored by species."
    st.write(header_text)
src/utils/metadata_handler.py CHANGED
@@ -1,16 +1,26 @@
1
  import streamlit as st
2
 
3
- def metadata2md() -> str:
4
  """Get metadata from cache and return as markdown-formatted key-value list
5
 
 
 
 
 
6
  Returns:
7
  str: Markdown-formatted key-value list of metadata
8
 
9
  """
10
  markdown_str = "\n"
11
- keys_to_print = ["latitude","longitude","author_email","date","time"]
12
- for key, value in st.session_state.public_observation.items():
13
- if key in keys_to_print:
14
- markdown_str += f"- **{key}**: {value}\n"
 
 
 
 
 
 
15
  return markdown_str
16
 
 
1
  import streamlit as st
2
 
3
def metadata2md(image_hash:str, debug:bool=False) -> str:
    """Get metadata from cache and return as markdown-formatted key-value list

    The observation is looked up in `st.session_state.public_observations`
    by `image_hash`; if no entry exists, an empty observation is used and the
    result is just a single newline.

    Args:
        image_hash (str): The hash of the image to get metadata for
        debug (bool, optional): Whether to print additional fields.

    Returns:
        str: Markdown-formatted key-value list of metadata, one bullet per
            selected key, in the order the keys appear in the observation.

    """
    keys_to_print = ["author_email", "latitude", "longitude", "date", "time"]
    if debug:
        # The first key was previously spelled "iamge_md5".
        # NOTE(review): assumes the observation dict uses the key "image_md5" —
        # confirm against the observation schema.
        keys_to_print += ["image_md5", "selected_class", "top_prediction", "class_overriden"]

    observation = st.session_state.public_observations.get(image_hash, {})

    # Keep the observation's own key order; only filter by the selected keys.
    bullet_lines = [
        f"- **{key}**: {value}\n"
        for key, value in observation.items()
        if key in keys_to_print
    ]
    return "\n" + "".join(bullet_lines)
26
 
src/utils/st_logs.py CHANGED
@@ -100,6 +100,16 @@ class StreamlitLogHandler(logging.Handler):
100
  self.log_area.empty() # Clear previous logs
101
  self.buffer.clear()
102
 
 
 
 
 
 
 
 
 
 
 
103
  # Set up logging to capture all info level logs from the root logger
104
  @st.cache_resource
105
  def setup_logging(level:int=logging.INFO, buffer_len:int=15) -> StreamlitLogHandler:
@@ -126,6 +136,7 @@ def setup_logging(level:int=logging.INFO, buffer_len:int=15) -> StreamlitLogHand
126
  # st.session_state['handler'] = handler
127
  return handler
128
 
 
129
  def parse_log_buffer(log_contents: deque) -> List[dict]:
130
  """
131
  Convert log buffer to a list of dictionaries for use with a streamlit datatable.
 
100
  self.log_area.empty() # Clear previous logs
101
  self.buffer.clear()
102
 
103
+
104
def init_logging_session_states():
    """
    Make sure the logging handler is present in the session state.

    The handler is created via `setup_logging()` only if the key "handler"
    is not already stored in `st.session_state`.
    """
    if "handler" in st.session_state:
        # already initialised, nothing to do
        return

    st.session_state['handler'] = setup_logging()
111
+
112
+
113
  # Set up logging to capture all info level logs from the root logger
114
  @st.cache_resource
115
  def setup_logging(level:int=logging.INFO, buffer_len:int=15) -> StreamlitLogHandler:
 
136
  # st.session_state['handler'] = handler
137
  return handler
138
 
139
+
140
  def parse_log_buffer(log_contents: deque) -> List[dict]:
141
  """
142
  Convert log buffer to a list of dictionaries for use with a streamlit datatable.
src/utils/workflow_state.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transitions import Machine
2
+ from typing import List
3
+
4
# ANSI escape sequences used to colour console output (see WorkflowFSM._cprint).
OKBLUE = '\033[94m'
OKGREEN = '\033[92m'
OKCYAN = '\033[96m'
FAIL = '\033[91m'
ENDC = '\033[0m'  # appended after a message to end the coloured section


# Ordered list of workflow state names. WorkflowFSM starts in the first entry
# and advances through the list one state at a time.
FSM_STATES = ['doing_data_entry', 'data_entry_complete', 'data_entry_validated',
              #'ml_classification_started',
              'ml_classification_completed',
              'manual_inspection_completed', 'data_uploaded']
15
+
16
+
17
class WorkflowFSM:
    """Strictly linear state machine built on `transitions.Machine`.

    The machine starts in the first entry of `state_sequence`. For every state
    except the last, a trigger named ``complete_<state>`` is registered which
    moves the machine to the following state. The last state has no outgoing
    transition.
    """

    def __init__(self, state_sequence: List[str]):
        """Create the machine and register one forward transition per state.

        Args:
            state_sequence: Ordered state names. The first entry is the
                initial state.
        """
        self.state_sequence = state_sequence
        # state name -> position in the sequence, used for ordering checks
        self.state_dict = {state: i for i, state in enumerate(state_sequence)}

        # Create state machine. This object is passed as the model; the code
        # below relies on `self.state` and the `complete_<state>` triggers
        # being available on it afterwards.
        self.machine = Machine(
            model=self,
            states=state_sequence,
            initial=state_sequence[0],
        )

        # For each state (except the last), add a completion transition to the next state
        for current_state, next_state in zip(state_sequence, state_sequence[1:]):
            self.machine.add_transition(
                trigger=f'complete_{current_state}',
                source=current_state,
                dest=next_state,
                conditions=[f'is_in_{current_state}']
            )

            # Dynamically add the condition method named above. The state is
            # bound through a default argument so each lambda keeps its own
            # state name rather than the loop's final value.
            setattr(self, f'is_in_{current_state}',
                    lambda s=current_state: self.is_in_state(s))

        # Add callbacks for logging
        self.machine.before_state_change = self._log_transition
        self.machine.after_state_change = self._post_transition

    def is_in_state(self, state_name: str) -> bool:
        """Check if we're currently in the specified state"""
        return self.state == state_name

    def complete_current_state(self) -> bool:
        """
        Signal that the current state is complete.

        Returns:
            True if a state transition occurred. False if the current state
            has no completion trigger (the last state), if the trigger
            reported that no transition took place, or if the trigger raised.
        """
        current_state = self.state
        trigger_func = getattr(self, f'complete_{current_state}', None)
        if trigger_func is None:
            # no outgoing transition registered for this state
            return False

        try:
            # report the trigger's own result rather than assuming success
            return bool(trigger_func())
        except Exception as err:
            # Best effort: a failed transition must not crash the caller, but
            # it should be visible (and Ctrl-C / SystemExit must propagate).
            self._cprint(f"[FSM] !! Could not complete {current_state}: {err}", FAIL)
            return False

    def is_in_state_or_beyond(self, state_name: str) -> bool:
        """Check if we have reached or passed the specified state

        Compares the position of `state_name` in the state sequence with the
        position of the current state.

        Raises:
            ValueError: If `state_name` is not one of the known states.
        """
        if state_name not in self.state_dict:
            raise ValueError(f"Invalid state: {state_name}")

        return self.state_dict[state_name] <= self.state_dict[self.state]

    @property
    def current_state(self) -> str:
        """Get the current state name"""
        return self.state

    @property
    def current_state_index(self) -> int:
        """Get the current state index"""
        return self.state_dict[self.state]

    @property
    def num_states(self) -> int:
        """Get the total number of states in the sequence"""
        return len(self.state_sequence)

    def _log_transition(self):
        """Callback registered as `before_state_change`: print the state being left."""
        # TODO: use logger, not printing.
        self._cprint(f"[FSM] -> Transitioning from {self.current_state}")

    def _post_transition(self):
        """Callback registered as `after_state_change`: print the state just entered."""
        # TODO: use logger, not printing.
        self._cprint(f"[FSM] -| Transitioned to {self.current_state}")

    def _cprint(self, msg:str, color:str=OKCYAN):
        """Print colored message"""
        print(f"{color}{msg}{ENDC}")
src/utils/workflow_ui.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from utils.workflow_state import WorkflowFSM, FSM_STATES
3
+
4
def init_workflow_session_states():
    """
    Make sure the workflow state machine is present in the session state.

    A new `WorkflowFSM` over `FSM_STATES` is created only if the key
    "workflow_fsm" is not already stored in `st.session_state`.
    """
    if "workflow_fsm" in st.session_state:
        # state machine already exists for this session
        return

    st.session_state.workflow_fsm = WorkflowFSM(FSM_STATES)
12
+
13
def refresh_progress_display() -> None:
    """
    Updates the workflow progress display in the Streamlit sidebar.

    Reads the state machine from `st.session_state.workflow_fsm` and writes a
    status line and a progress bar into the two placeholders stored in
    `st.session_state.disp_progress`.
    """
    with st.sidebar:
        fsm = st.session_state.workflow_fsm
        # progress is measured in steps between states, hence states - 1
        num_states = fsm.num_states - 1
        current_state_index = fsm.current_state_index
        current_state_name = fsm.current_state
        status = f"*Progress: {current_state_index}/{num_states}. Current: {current_state_name}.*"

        # A single-state sequence has no steps; show it as complete instead
        # of dividing by zero.
        fraction = current_state_index / num_states if num_states > 0 else 1.0

        st.session_state.disp_progress[0].markdown(status)
        st.session_state.disp_progress[1].progress(fraction)
25
+
26
+
27
def init_workflow_viz(debug:bool=True) -> None:
    """
    Set up the streamlit elements for visualising the workflow progress.

    Adds placeholders for progress indicators, and adds a button to manually refresh
    the displayed progress. Note: The button is mainly a development aid.

    The two placeholders are stored as a list in `st.session_state.disp_progress`.

    Args:
        debug (bool): If True, include the manual refresh button. Default is True.

    """


    #Initialise the layout containers used in the input handling
    # add progress indicator to session_state
    # NOTE(review): this guard tests the key "progress", but the value below is
    # stored under "disp_progress" and nothing in this function sets "progress",
    # so the branch appears to run on every call. This may be what re-creates
    # the placeholders on each script rerun — confirm before changing the key.
    if "progress" not in st.session_state:
        with st.sidebar:
            st.session_state.disp_progress = [st.empty(), st.empty()]
            if debug:
                # add button to sidebar, with refresh_progress_display as the callback
                st.sidebar.button("Refresh Progress", on_click=refresh_progress_display)
48
+
src/whale_viewer.py CHANGED
@@ -115,6 +115,9 @@ def format_whale_name(whale_class:str) -> str:
115
  Returns:
116
  str: The formatted whale name with spaces instead of underscores and each word capitalized.
117
  """
 
 
 
118
  whale_name = whale_class.replace("_", " ").title()
119
  return whale_name
120
 
 
115
  Returns:
116
  str: The formatted whale name with spaces instead of underscores and each word capitalized.
117
  """
118
+ if not isinstance(whale_class, str):
119
+ raise TypeError("whale_class should be a string.")
120
+
121
  whale_name = whale_class.replace("_", " ").title()
122
  return whale_name
123
 
tests/test_input_handling.py CHANGED
@@ -51,9 +51,6 @@ def test_is_valid_email_invalid():
51
  assert not is_valid_email("[email protected].")
52
  assert not is_valid_email("a@[email protected]")
53
 
54
- # not sure how xfails come through the CI pipeline yet.
55
- # maybe better to just comment out this stuff until pipeline is setup, then can check /extend
56
- @pytest.mark.xfail(reason="Bug identified, but while setting up CI having failing tests causes more headache")
57
  def test_is_valid_email_invalid_plus():
58
  assert not is_valid_email("[email protected]")
59
  assert not is_valid_email("[email protected]")
@@ -143,7 +140,7 @@ def test_get_image_latlon():
143
 
144
  # missing GPS loc
145
  f2 = test_data_pth / 'cakes_no_exif_gps.jpg'
146
- assert get_image_latlon(f2) == None
147
 
148
  # missng datetime -> expect gps not affected
149
  f3 = test_data_pth / 'cakes_no_exif_datetime.jpg'
@@ -151,7 +148,7 @@ def test_get_image_latlon():
151
 
152
  # tests for get_image_latlon with empty file
153
  def test_get_image_latlon_empty():
154
- assert get_image_latlon("") == None
155
 
156
  # tests for decimal_coords
157
  # - without input, py raises TypeError
 
51
  assert not is_valid_email("[email protected].")
52
  assert not is_valid_email("a@[email protected]")
53
 
 
 
 
54
  def test_is_valid_email_invalid_plus():
55
  assert not is_valid_email("[email protected]")
56
  assert not is_valid_email("[email protected]")
 
140
 
141
  # missing GPS loc
142
  f2 = test_data_pth / 'cakes_no_exif_gps.jpg'
143
+ assert get_image_latlon(f2) == (None, None)
144
 
145
 # missing datetime -> expect gps not affected
146
  f3 = test_data_pth / 'cakes_no_exif_datetime.jpg'
 
148
 
149
  # tests for get_image_latlon with empty file
150
def test_get_image_latlon_empty():
    """An empty file argument yields the (None, None) pair."""
    latlon = get_image_latlon("")
    assert latlon == (None, None)
152
 
153
  # tests for decimal_coords
154
  # - without input, py raises TypeError
tests/test_whale_viewer.py CHANGED
@@ -40,11 +40,9 @@ def test_format_whale_name_empty():
40
  assert format_whale_name("") == ""
41
 
42
  # testing with the wrong datatype
43
- # we should get a TypeError - currently it fails with a AttributeError
44
- @pytest.mark.xfail
45
  def test_format_whale_name_none():
46
  with pytest.raises(TypeError):
47
  format_whale_name(None)
48
 
49
 
50
- # display_whale requires UI to test it.
 
40
  assert format_whale_name("") == ""
41
 
42
  # testing with the wrong datatype
 
 
43
def test_format_whale_name_none():
    """Passing None instead of a string must raise TypeError."""
    pytest.raises(TypeError, format_whale_name, None)
46
 
47
 
48
+ # display_whale requires UI to test it.