vancauwe committed on
Commit
1c0e2a5
·
1 Parent(s): 0e9c1b3

feat: hash used as identifier

Browse files
src/classifier/classifier_image.py CHANGED
@@ -12,21 +12,20 @@ from utils.grid_maker import gridder
12
  from utils.metadata_handler import metadata2md
13
 
14
  def cetacean_classify(cetacean_classifier):
15
- files = st.session_state.files
16
  images = st.session_state.images
17
  observations = st.session_state.observations
18
-
19
- batch_size, row_size, page = gridder(files)
20
 
21
  grid = st.columns(row_size)
22
  col = 0
23
-
24
- for file in files:
25
- image = images[file.name]
26
 
27
  with grid[col]:
28
  st.image(image, use_column_width=True)
29
- observation = observations[file.name].to_dict()
30
  # run classifier model on `image`, and persistently store the output
31
  out = cetacean_classifier(image) # get top 3 matches
32
  st.session_state.whale_prediction1 = out['predictions'][0]
@@ -44,14 +43,14 @@ def cetacean_classify(cetacean_classifier):
44
  # get index of pred1 from WHALE_CLASSES, none if not present
45
  print(f"[D] pred1: {pred1}")
46
  ix = viewer.WHALE_CLASSES.index(pred1) if pred1 in viewer.WHALE_CLASSES else None
47
- selected_class = st.selectbox(f"Species for {file.name}", viewer.WHALE_CLASSES, index=ix)
48
 
49
  observation['predicted_class'] = selected_class
50
  if selected_class != st.session_state.whale_prediction1:
51
  observation['class_overriden'] = selected_class
52
 
53
  st.session_state.public_observation = observation
54
- st.button(f"Upload observation for {file.name} to THE INTERNET!", on_click=push_observations)
55
  # TODO: the metadata only fills properly if `validate` was clicked.
56
  st.markdown(metadata2md())
57
 
@@ -62,7 +61,8 @@ def cetacean_classify(cetacean_classifier):
62
 
63
  whale_classes = out['predictions'][:]
64
  # render images for the top 3 (that is what the model api returns)
65
- st.markdown(f"Top 3 Predictions for {file.name}")
66
  for i in range(len(whale_classes)):
67
  viewer.display_whale(whale_classes, i)
 
68
  col = (col + 1) % row_size
 
12
  from utils.metadata_handler import metadata2md
13
 
14
  def cetacean_classify(cetacean_classifier):
 
15
  images = st.session_state.images
16
  observations = st.session_state.observations
17
+ hashes = st.session_state.image_hashes
18
+ batch_size, row_size, page = gridder(hashes)
19
 
20
  grid = st.columns(row_size)
21
  col = 0
22
+ o=1
23
+ for hash in hashes:
24
+ image = images[hash]
25
 
26
  with grid[col]:
27
  st.image(image, use_column_width=True)
28
+ observation = observations[hash].to_dict()
29
  # run classifier model on `image`, and persistently store the output
30
  out = cetacean_classifier(image) # get top 3 matches
31
  st.session_state.whale_prediction1 = out['predictions'][0]
 
43
  # get index of pred1 from WHALE_CLASSES, none if not present
44
  print(f"[D] pred1: {pred1}")
45
  ix = viewer.WHALE_CLASSES.index(pred1) if pred1 in viewer.WHALE_CLASSES else None
46
+ selected_class = st.selectbox(f"Species for observation {str(o)}", viewer.WHALE_CLASSES, index=ix)
47
 
48
  observation['predicted_class'] = selected_class
49
  if selected_class != st.session_state.whale_prediction1:
50
  observation['class_overriden'] = selected_class
51
 
52
  st.session_state.public_observation = observation
53
+ st.button(f"Upload observation {str(o)} to THE INTERNET!", on_click=push_observations)
54
  # TODO: the metadata only fills properly if `validate` was clicked.
55
  st.markdown(metadata2md())
56
 
 
61
 
62
  whale_classes = out['predictions'][:]
63
  # render images for the top 3 (that is what the model api returns)
64
+ st.markdown(f"Top 3 Predictions for observation {str(o)}")
65
  for i in range(len(whale_classes)):
66
  viewer.display_whale(whale_classes, i)
67
+ o += 1
68
  col = (col + 1) % row_size
src/input/input_handling.py CHANGED
@@ -66,6 +66,7 @@ def setup_input(
66
  uploaded_files = viewcontainer.file_uploader("Upload an image", type=allowed_image_types, accept_multiple_files=True)
67
  observations = {}
68
  images = {}
 
69
  if uploaded_files is not None:
70
  for file in uploaded_files:
71
 
@@ -108,11 +109,13 @@ def setup_input(
108
  observation = InputObservation(image=file, latitude=latitude, longitude=longitude,
109
  author_email=author_email, date=image_datetime, time=None,
110
  date_option=date_option, time_option=time_option)
111
- observations[file.name] = observation
112
- images[file.name] = image
 
 
113
 
114
  st.session_state.images = images
115
  st.session_state.files = uploaded_files
116
-
117
- return observations
118
 
 
66
  uploaded_files = viewcontainer.file_uploader("Upload an image", type=allowed_image_types, accept_multiple_files=True)
67
  observations = {}
68
  images = {}
69
+ image_hashes =[]
70
  if uploaded_files is not None:
71
  for file in uploaded_files:
72
 
 
109
  observation = InputObservation(image=file, latitude=latitude, longitude=longitude,
110
  author_email=author_email, date=image_datetime, time=None,
111
  date_option=date_option, time_option=time_option)
112
+ image_hash = observation.to_dict()["image_md5"]
113
+ observations[image_hash] = observation
114
+ images[image_hash] = image
115
+ image_hashes.append(image_hash)
116
 
117
  st.session_state.images = images
118
  st.session_state.files = uploaded_files
119
+ st.session_state.observations = observations
120
+ st.session_state.image_hashes = image_hashes
121
 
src/input/input_validator.py CHANGED
@@ -96,7 +96,8 @@ def decimal_coords(coords:tuple, ref:str) -> Fraction:
96
  return decimal_degrees
97
 
98
 
99
- def get_image_latlon(image_file: UploadedFile) -> tuple[float, float] | None:
 
100
  """
101
  Extracts the latitude and longitude from the EXIF metadata of an uploaded image file.
102
 
 
96
  return decimal_degrees
97
 
98
 
99
+ #def get_image_latlon(image_file: UploadedFile) -> tuple[float, float] | None:
100
+ def get_image_latlon(image_file: UploadedFile) :
101
  """
102
  Extracts the latitude and longitude from the EXIF metadata of an uploaded image file.
103
 
src/main.py CHANGED
@@ -9,6 +9,7 @@ from streamlit_folium import st_folium
9
  from transformers import pipeline
10
  from transformers import AutoModelForImageClassification
11
 
 
12
  from datasets import disable_caching
13
  disable_caching()
14
 
@@ -44,6 +45,9 @@ st.set_page_config(layout="wide")
44
  if "handler" not in st.session_state:
45
  st.session_state['handler'] = setup_logging()
46
 
 
 
 
47
  if "observations" not in st.session_state:
48
  st.session_state.observations = {}
49
 
@@ -100,7 +104,7 @@ def main() -> None:
100
 
101
 
102
  # create a sidebar, and parse all the input (returned as `observations` object)
103
- observations = setup_input(viewcontainer=st.sidebar)
104
 
105
 
106
  if 0:## WIP
@@ -118,7 +122,7 @@ def main() -> None:
118
  with tab_map:
119
  # visual structure: a couple of toggles at the top, then the map including a
120
  # dropdown for tileset selection.
121
- sw_map.add_header_text()
122
  tab_map_ui_cols = st.columns(2)
123
  with tab_map_ui_cols[0]:
124
  show_db_points = st.toggle("Show Points from DB", True)
@@ -179,12 +183,8 @@ def main() -> None:
179
  # Display submitted observation
180
  if st.sidebar.button("Validate"):
181
  # create a dictionary with the submitted observation
182
- submitted_data = observations
183
- st.session_state.observations = observations
184
-
185
  tab_log.info(f"{st.session_state.observations}")
186
-
187
- df = pd.DataFrame(submitted_data, index=[0])
188
  with tab_coords:
189
  st.table(df)
190
 
 
9
  from transformers import pipeline
10
  from transformers import AutoModelForImageClassification
11
 
12
+ from maps.obs_map import add_header_text
13
  from datasets import disable_caching
14
  disable_caching()
15
 
 
45
  if "handler" not in st.session_state:
46
  st.session_state['handler'] = setup_logging()
47
 
48
+ if "image_hashes" not in st.session_state:
49
+ st.session_state.image_hashes = []
50
+
51
  if "observations" not in st.session_state:
52
  st.session_state.observations = {}
53
 
 
104
 
105
 
106
  # create a sidebar, and parse all the input (returned as `observations` object)
107
+ setup_input(viewcontainer=st.sidebar)
108
 
109
 
110
  if 0:## WIP
 
122
  with tab_map:
123
  # visual structure: a couple of toggles at the top, then the map inlcuding a
124
  # dropdown for tileset selection.
125
+ add_header_text()
126
  tab_map_ui_cols = st.columns(2)
127
  with tab_map_ui_cols[0]:
128
  show_db_points = st.toggle("Show Points from DB", True)
 
183
  # Display submitted observation
184
  if st.sidebar.button("Validate"):
185
  # create a dictionary with the submitted observation
 
 
 
186
  tab_log.info(f"{st.session_state.observations}")
187
+ df = pd.DataFrame(st.session_state.observations, index=[0])
 
188
  with tab_coords:
189
  st.table(df)
190
 
src/utils/grid_maker.py CHANGED
@@ -1,13 +1,13 @@
1
  import streamlit as st
2
  import math
3
 
4
- def gridder(files):
5
  cols = st.columns(3)
6
  with cols[0]:
7
  batch_size = st.select_slider("Batch size:",range(10,110,10), value=10)
8
  with cols[1]:
9
  row_size = st.select_slider("Row size:", range(1,6), value = 5)
10
- num_batches = math.ceil(len(files)/batch_size)
11
  with cols[2]:
12
  page = st.selectbox("Page", range(1,num_batches+1))
13
  return batch_size, row_size, page
 
1
  import streamlit as st
2
  import math
3
 
4
def gridder(items):
    """Render the grid-control widgets and return pagination settings.

    Draws three Streamlit controls side by side (batch size slider, row size
    slider, page selector) and computes how many pages are needed to show
    all `items`.

    Args:
        items: Sized collection of things to display (e.g. image hashes);
            only its length is used.

    Returns:
        tuple: (batch_size, row_size, page) — the selected batch size,
        number of columns per row, and the 1-based page number.
    """
    cols = st.columns(3)
    with cols[0]:
        batch_size = st.select_slider("Batch size:", range(10, 110, 10), value=10)
    with cols[1]:
        row_size = st.select_slider("Row size:", range(1, 6), value=5)
    # Guard against an empty `items`: ceil(0 / batch_size) == 0 would make the
    # page range empty and st.selectbox would have no options to offer.
    num_batches = max(1, math.ceil(len(items) / batch_size))
    with cols[2]:
        page = st.selectbox("Page", range(1, num_batches + 1))
    return batch_size, row_size, page