vancauwe committed
Commit 0e8c927 · 1 Parent(s): 54319e9

feat: refactor and multi image classification

src/classifier/classifier_hotdog.py ADDED
@@ -0,0 +1,26 @@
+ import streamlit as st
+ import json
+ from PIL import Image
+
+
+ def hotdog_classify(pipeline_hot_dog, tab_hotdogs):
+     col1, col2 = tab_hotdogs.columns(2)
+     for file in st.session_state.files:
+         image = st.session_state.images[file.name]
+         observation = st.session_state.observations[file.name].to_dict()
+         # display the image (use cached version, no need to reread)
+         col1.image(image, use_column_width=True)
+         # and then run inference on the image
+         hotdog_image = Image.fromarray(image)
+         predictions = pipeline_hot_dog(hotdog_image)
+
+         col2.header("Probabilities")
+         first = True
+         for p in predictions:
+             col2.subheader(f"{ p['label'] }: { round(p['score'] * 100, 1)}%")
+             if first:
+                 observation['predicted_class'] = p['label']
+                 observation['predicted_score'] = round(p['score'] * 100, 1)
+                 first = False
+
+         tab_hotdogs.write(f"Session observation: {json.dumps(observation)}")
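
For orientation, a minimal sketch of how this helper might be driven from the app. The model id and the tab setup below are assumptions for illustration, not part of this commit:

import streamlit as st
from transformers import pipeline
from classifier.classifier_hotdog import hotdog_classify

# hypothetical checkpoint; any image-classification pipeline fits the signature
pipeline_hot_dog = pipeline("image-classification", model="julien-c/hotdog-not-hotdog")
tab_hotdogs, = st.tabs(["Hotdog classifier"])
if st.session_state.get("files"):  # populated by the input-handling sidebar
    hotdog_classify(pipeline_hot_dog, tab_hotdogs)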
src/classifier/classifier_image.py ADDED
@@ -0,0 +1,69 @@
+ import streamlit as st
+ import logging
+
+ # get a global var for logger accessor in this module
+ LOG_LEVEL = logging.DEBUG
+ g_logger = logging.getLogger(__name__)
+ g_logger.setLevel(LOG_LEVEL)
+
+ import whale_viewer as viewer
+ from hf_push_observations import push_observations
+ from utils.grid_maker import gridder
+ from utils.metadata_handler import metadata2md
+
+ def cetacean_classify(cetacean_classifier, tab_inference):
+     files = st.session_state.files
+     images = st.session_state.images
+     observations = st.session_state.observations
+
+     batch_size, row_size, page = gridder(files)
+
+     grid = st.columns(row_size)
+     col = 0
+
+     for file in files:
+         image = images[file.name]
+
+         with grid[col]:
+             st.image(image, use_column_width=True)
+             observation = observations[file.name].to_dict()
+             # run classifier model on `image`, and persistently store the output
+             out = cetacean_classifier(image) # get top 3 matches
+             st.session_state.whale_prediction1 = out['predictions'][0]
+             st.session_state.classify_whale_done = True
+             msg = f"[D]2 classify_whale_done: {st.session_state.classify_whale_done}, whale_prediction1: {st.session_state.whale_prediction1}"
+             g_logger.info(msg)
+
+             # dropdown for selecting/overriding the species prediction
+             if not st.session_state.classify_whale_done:
+                 selected_class = st.sidebar.selectbox("Species", viewer.WHALE_CLASSES,
+                                                       index=None, placeholder="Species not yet identified...",
+                                                       disabled=True)
+             else:
+                 pred1 = st.session_state.whale_prediction1
+                 # get index of pred1 from WHALE_CLASSES, none if not present
+                 print(f"[D] pred1: {pred1}")
+                 ix = viewer.WHALE_CLASSES.index(pred1) if pred1 in viewer.WHALE_CLASSES else None
+                 selected_class = st.selectbox(f"Species for {file.name}", viewer.WHALE_CLASSES, index=ix)
+
+             observation['predicted_class'] = selected_class
+             if selected_class != st.session_state.whale_prediction1:
+                 observation['class_overriden'] = selected_class
+
+             st.session_state.public_observation = observation
+             st.button(f"Upload observation for {file.name} to THE INTERNET!", on_click=push_observations)
+             # TODO: the metadata only fills properly if `validate` was clicked.
+             st.markdown(metadata2md())
+
+             msg = f"[D] full observation after inference: {observation}"
+             g_logger.debug(msg)
+             print(msg)
+             # TODO: add a link to more info on the model, next to the button.
+
+             whale_classes = out['predictions'][:]
+             # render images for the top 3 (that is what the model api returns)
+             #with tab_inference:
+             st.markdown(f"Top 3 Predictions for {file.name}")
+             for i in range(len(whale_classes)):
+                 viewer.display_whale(whale_classes, i)
+         col = (col + 1) % row_size
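
A sketch of the call site this function expects. The model id appears in src/main.py further down; the condensed button/tab wiring here is an assumption for illustration:

import streamlit as st
from transformers import AutoModelForImageClassification
from classifier.classifier_image import cetacean_classify

# model id taken from this repo's main.py; revision pinning omitted for brevity
cetacean_classifier = AutoModelForImageClassification.from_pretrained(
    "Saving-Willy/cetacean-classifier", trust_remote_code=True)
tab_inference, = st.tabs(["Cetecean classifier"])
if tab_inference.button("Identify with cetacean classifier"):
    cetacean_classify(cetacean_classifier, tab_inference)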
src/classifier_image.py ADDED
@@ -0,0 +1,70 @@
+ import streamlit as st
+ import logging
+ import os
+
+ # get a global var for logger accessor in this module
+ LOG_LEVEL = logging.DEBUG
+ g_logger = logging.getLogger(__name__)
+ g_logger.setLevel(LOG_LEVEL)
+
+ from grid_maker import gridder
+ import hf_push_observations as sw_push_obs
+ import utils.metadata_handler as meta_handler
+ import whale_viewer as sw_wv
+
+ def cetacean_classify(cetacean_classifier, tab_inference):
+     files = st.session_state.files
+     images = st.session_state.images
+     observations = st.session_state.observations
+
+     batch_size, row_size, page = gridder(files)
+
+     grid = st.columns(row_size)
+     col = 0
+
+     for file in files:
+         image = images[file.name]
+
+         with grid[col]:
+             st.image(image, use_column_width=True)
+             observation = observations[file.name].to_dict()
+             # run classifier model on `image`, and persistently store the output
+             out = cetacean_classifier(image) # get top 3 matches
+             st.session_state.whale_prediction1 = out['predictions'][0]
+             st.session_state.classify_whale_done = True
+             msg = f"[D]2 classify_whale_done: {st.session_state.classify_whale_done}, whale_prediction1: {st.session_state.whale_prediction1}"
+             g_logger.info(msg)
+
+             # dropdown for selecting/overriding the species prediction
+             if not st.session_state.classify_whale_done:
+                 selected_class = st.sidebar.selectbox("Species", sw_wv.WHALE_CLASSES,
+                                                       index=None, placeholder="Species not yet identified...",
+                                                       disabled=True)
+             else:
+                 pred1 = st.session_state.whale_prediction1
+                 # get index of pred1 from WHALE_CLASSES, none if not present
+                 print(f"[D] pred1: {pred1}")
+                 ix = sw_wv.WHALE_CLASSES.index(pred1) if pred1 in sw_wv.WHALE_CLASSES else None
+                 selected_class = tab_inference.selectbox("Species", sw_wv.WHALE_CLASSES, index=ix)
+
+             observation['predicted_class'] = selected_class
+             if selected_class != st.session_state.whale_prediction1:
+                 observation['class_overriden'] = selected_class
+
+             st.session_state.public_observation = observation
+             st.button(f"Upload observation for {file.name} to THE INTERNET!", on_click=sw_push_obs.push_observations)
+             # TODO: the metadata only fills properly if `validate` was clicked.
+             st.markdown(meta_handler.metadata2md())
+
+             msg = f"[D] full observation after inference: {observation}"
+             g_logger.debug(msg)
+             print(msg)
+             # TODO: add a link to more info on the model, next to the button.
+
+             whale_classes = out['predictions'][:]
+             # render images for the top 3 (that is what the model api returns)
+             #with tab_inference:
+             st.title(f"Species detected for {file.name}")
+             for i in range(len(whale_classes)):
+                 sw_wv.display_whale(whale_classes, i)
+         col = (col + 1) % row_size
src/hf_push_observations.py ADDED
@@ -0,0 +1,56 @@
+ from streamlit.delta_generator import DeltaGenerator
+ import streamlit as st
+ from huggingface_hub import HfApi
+ import json
+ import tempfile
+ import logging
+
+ # get a global var for logger accessor in this module
+ LOG_LEVEL = logging.DEBUG
+ g_logger = logging.getLogger(__name__)
+ g_logger.setLevel(LOG_LEVEL)
+
+ def push_observations(tab_log:DeltaGenerator=None):
+     """
+     Push the observations to the Hugging Face dataset
+
+     Args:
+         tab_log (streamlit.container): The container to log messages to. If not provided,
+             log messages are in any case written to the global logger (TODO: test - didn't
+             push any observation since generating the logger)
+
+     """
+     # we get the observation from session state: 1 is the dict 2 is the image.
+     # first, lets do an info display (popup)
+     metadata_str = json.dumps(st.session_state.public_observation)
+
+     st.toast(f"Uploading observations: {metadata_str}", icon="🦭")
+     tab_log = st.session_state.tab_log
+     if tab_log is not None:
+         tab_log.info(f"Uploading observations: {metadata_str}")
+
+     # get huggingface api
+     import os
+     token = os.environ.get("HF_TOKEN", None)
+     api = HfApi(token=token)
+
+     f = tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False)
+     f.write(metadata_str)
+     f.close()
+     st.info(f"temp file: {f.name} with metadata written...")
+
+     path_in_repo= f"metadata/{st.session_state.public_observation['author_email']}/{st.session_state.public_observation['image_md5']}.json"
+     msg = f"fname: {f.name} | path: {path_in_repo}"
+     print(msg)
+     st.warning(msg)
+     # rv = api.upload_file(
+     #     path_or_fileobj=f.name,
+     #     path_in_repo=path_in_repo,
+     #     repo_id="Saving-Willy/temp_dataset",
+     #     repo_type="dataset",
+     # )
+     # print(rv)
+     # msg = f"observation attempted tx to repo happy walrus: {rv}"
+     g_logger.info(msg)
+     st.info(msg)
+
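Note that the actual api.upload_file(...) call is commented out in this version, so the function currently stops after writing the temp file. A sketch of what a caller needs in session state before invoking it (all sample values here are hypothetical):

import streamlit as st
from hf_push_observations import push_observations

# push_observations reads everything from session state, so populate it first
st.session_state.public_observation = {
    "author_email": "observer@example.com",              # hypothetical
    "image_md5": "d41d8cd98f00b204e9800998ecf8427e",      # hypothetical
}
st.session_state.tab_log = None  # or a container from st.tabs(...)
push_observations()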
src/{input_handling.py → input/input_handling.py} RENAMED
@@ -1,19 +1,14 @@
- from PIL import Image
- from PIL import ExifTags
- import re
  import datetime
- import hashlib
  import logging

  import streamlit as st
- from streamlit.runtime.uploaded_file_manager import UploadedFile # for type hinting
  from streamlit.delta_generator import DeltaGenerator

  import cv2
  import numpy as np

- import random
- import string
+ from input.input_observation import InputObservation
+ from input.input_validator import get_image_datetime, is_valid_email, is_valid_number

  m_logger = logging.getLogger(__name__)
  m_logger.setLevel(logging.INFO)
@@ -25,172 +20,6 @@ both the UI elements (setup_input_UI) and the validation functions.
  '''
  allowed_image_types = ['jpg', 'jpeg', 'png', 'webp']

- def generate_random_md5():
-     # Generate a random string
-     random_string = ''.join(random.choices(string.ascii_letters + string.digits, k=16))
-     # Encode the string and compute its MD5 hash
-     md5_hash = hashlib.md5(random_string.encode()).hexdigest()
-     return md5_hash
-
- # autogenerated class to hold the input data
- class InputObservation:
-     """
-     A class to hold an input observation and associated metadata
-
-     Attributes:
-         image (Any):
-             The image associated with the observation.
-         latitude (float):
-             The latitude where the observation was made.
-         longitude (float):
-             The longitude where the observation was made.
-         author_email (str):
-             The email of the author of the observation.
-         date (str):
-             The date when the observation was made.
-         time (str):
-             The time when the observation was made.
-         date_option (str):
-             Additional date option for the observation.
-         time_option (str):
-             Additional time option for the observation.
-         uploaded_filename (Any):
-             The uploaded filename associated with the observation.
-
-     Methods:
-         __str__():
-             Returns a string representation of the observation.
-         __repr__():
-             Returns a string representation of the observation.
-         __eq__(other):
-             Checks if two observations are equal.
-         __ne__(other):
-             Checks if two observations are not equal.
-         __hash__():
-             Returns the hash of the observation.
-         to_dict():
-             Converts the observation to a dictionary.
-         from_dict(data):
-             Creates an observation from a dictionary.
-         from_input(input):
-             Creates an observation from another input observation.
-     """
-     def __init__(self, image=None, latitude=None, longitude=None, author_email=None, date=None, time=None, date_option=None, time_option=None, uploaded_filename=None):
-         self.image = image
-         self.latitude = latitude
-         self.longitude = longitude
-         self.author_email = author_email
-         self.date = date
-         self.time = time
-         self.date_option = date_option
-         self.time_option = time_option
-         self.uploaded_filename = uploaded_filename
-
-     def __str__(self):
-         return f"Observation: {self.image}, {self.latitude}, {self.longitude}, {self.author_email}, {self.date}, {self.time}, {self.date_option}, {self.time_option}, {self.uploaded_filename}"
-
-     def __repr__(self):
-         return f"Observation: {self.image}, {self.latitude}, {self.longitude}, {self.author_email}, {self.date}, {self.time}, {self.date_option}, {self.time_option}, {self.uploaded_filename}"
-
-     def __eq__(self, other):
-         return (self.image == other.image and self.latitude == other.latitude and self.longitude == other.longitude and
-                 self.author_email == other.author_email and self.date == other.date and self.time == other.time and
-                 self.date_option == other.date_option and self.time_option == other.time_option and self.uploaded_filename == other.uploaded_filename)
-
-     def __ne__(self, other):
-         return not self.__eq__(other)
-
-     def __hash__(self):
-         return hash((self.image, self.latitude, self.longitude, self.author_email, self.date, self.time, self.date_option, self.time_option, self.uploaded_filename))
-
-     def to_dict(self):
-         return {
-             #"image": self.image,
-             "image_filename": self.uploaded_filename.name if self.uploaded_filename else None,
-             "image_md5": hashlib.md5(self.uploaded_filename.read()).hexdigest() if self.uploaded_filename else generate_random_md5(),
-             "latitude": self.latitude,
-             "longitude": self.longitude,
-             "author_email": self.author_email,
-             "date": self.date,
-             "time": self.time,
-             "date_option": str(self.date_option),
-             "time_option": str(self.time_option),
-             "uploaded_filename": self.uploaded_filename
-         }
-
-     @classmethod
-     def from_dict(cls, data):
-         return cls(data["image"], data["latitude"], data["longitude"], data["author_email"], data["date"], data["time"], data["date_option"], data["time_option"], data["uploaded_filename"])
-
-     @classmethod
-     def from_input(cls, input):
-         return cls(input.image, input.latitude, input.longitude, input.author_email, input.date, input.time, input.date_option, input.time_option, input.uploaded_filename)
-
-     @staticmethod
-     def from_input(input):
-         return InputObservation(input.image, input.latitude, input.longitude, input.author_email, input.date, input.time, input.date_option, input.time_option, input.uploaded_filename)
-
-     @staticmethod
-     def from_dict(data):
-         return InputObservation(data["image"], data["latitude"], data["longitude"], data["author_email"], data["date"], data["time"], data["date_option"], data["time_option"], data["uploaded_filename"])
-
-
- def is_valid_number(number:str) -> bool:
-     """
-     Check if the given string is a valid number (int or float, sign ok)
-
-     Args:
-         number (str): The string to be checked.
-
-     Returns:
-         bool: True if the string is a valid number, False otherwise.
-     """
-     pattern = r'^[-+]?[0-9]*\.?[0-9]+$'
-     return re.match(pattern, number) is not None
-
-
- # Function to validate email address
- def is_valid_email(email:str) -> bool:
-     """
-     Validates if the provided email address is in a correct format.
-
-     Args:
-         email (str): The email address to validate.
-
-     Returns:
-         bool: True if the email address is valid, False otherwise.
-     """
-     pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
-     return re.match(pattern, email) is not None
-
- # Function to extract date and time from image metadata
- # def get_image_datetime(image_file: UploadedFile) -> str | None:
- def get_image_datetime(image_file):
-     """
-     Extracts the original date and time from the EXIF metadata of an uploaded image file.
-
-     Args:
-         image_file (UploadedFile): The uploaded image file from which to extract the date and time.
-
-     Returns:
-         str: The original date and time as a string if available, otherwise None.
-
-     Raises:
-         Warning: If the date and time could not be extracted from the image metadata.
-     """
-     try:
-         image = Image.open(image_file)
-         exif_data = image._getexif()
-         if exif_data is not None:
-             for tag, value in exif_data.items():
-                 if ExifTags.TAGS.get(tag) == 'DateTimeOriginal':
-                     return value
-     except Exception as e: # FIXME: what types of exception?
-         st.warning(f"Could not extract date from image metadata. (file: {image_file.name})")
-         # TODO: add to logger
-     return None
-
-
  # an arbitrary set of defaults so testing is less painful...
  # ideally we add in some randomization to the defaults
  spoof_metadata = {
@@ -282,7 +111,7 @@ def setup_input(
          observations[file.name] = observation
          images[file.name] = image

-     st.session_state.image = images
+     st.session_state.images = images
      st.session_state.files = uploaded_files

      return observations
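
After this rename the module mainly wires the sidebar UI to the classes and validators split out into the files below; a minimal sketch of the new import layout (paths taken from this diff; running with src/ on sys.path is an assumption):

from input.input_handling import setup_input
from input.input_observation import InputObservation
from input.input_validator import is_valid_email

# as in src/main.py below: parse the sidebar into per-file observations
# observations = setup_input(viewcontainer=st.sidebar)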
src/input/input_observation.py ADDED
@@ -0,0 +1,110 @@
+ import hashlib
+ from input.input_validator import generate_random_md5
+
+ # autogenerated class to hold the input data
+ class InputObservation:
+     """
+     A class to hold an input observation and associated metadata
+
+     Attributes:
+         image (Any):
+             The image associated with the observation.
+         latitude (float):
+             The latitude where the observation was made.
+         longitude (float):
+             The longitude where the observation was made.
+         author_email (str):
+             The email of the author of the observation.
+         date (str):
+             The date when the observation was made.
+         time (str):
+             The time when the observation was made.
+         date_option (str):
+             Additional date option for the observation.
+         time_option (str):
+             Additional time option for the observation.
+         uploaded_filename (Any):
+             The uploaded filename associated with the observation.
+
+     Methods:
+         __str__():
+             Returns a string representation of the observation.
+         __repr__():
+             Returns a string representation of the observation.
+         __eq__(other):
+             Checks if two observations are equal.
+         __ne__(other):
+             Checks if two observations are not equal.
+         __hash__():
+             Returns the hash of the observation.
+         to_dict():
+             Converts the observation to a dictionary.
+         from_dict(data):
+             Creates an observation from a dictionary.
+         from_input(input):
+             Creates an observation from another input observation.
+     """
+     def __init__(self, image=None, latitude=None, longitude=None,
+                  author_email=None, date=None, time=None, date_option=None, time_option=None,
+                  uploaded_filename=None):
+         self.image = image
+         self.latitude = latitude
+         self.longitude = longitude
+         self.author_email = author_email
+         self.date = date
+         self.time = time
+         self.date_option = date_option
+         self.time_option = time_option
+         self.uploaded_filename = uploaded_filename
+
+     def __str__(self):
+         return f"Observation: {self.image}, {self.latitude}, {self.longitude}, {self.author_email}, {self.date}, {self.time}, {self.date_option}, {self.time_option}, {self.uploaded_filename}"
+
+     def __repr__(self):
+         return f"Observation: {self.image}, {self.latitude}, {self.longitude}, {self.author_email}, {self.date}, {self.time}, {self.date_option}, {self.time_option}, {self.uploaded_filename}"
+
+     def __eq__(self, other):
+         return (self.image == other.image and self.latitude == other.latitude and self.longitude == other.longitude and
+                 self.author_email == other.author_email and self.date == other.date and self.time == other.time and
+                 self.date_option == other.date_option and self.time_option == other.time_option and self.uploaded_filename == other.uploaded_filename)
+
+     def __ne__(self, other):
+         return not self.__eq__(other)
+
+     def __hash__(self):
+         return hash((self.image, self.latitude, self.longitude, self.author_email, self.date, self.time, self.date_option, self.time_option, self.uploaded_filename))
+
+     def to_dict(self):
+         return {
+             #"image": self.image,
+             "image_filename": self.uploaded_filename.name if self.uploaded_filename else None,
+             "image_md5": hashlib.md5(self.uploaded_filename.read()).hexdigest() if self.uploaded_filename else generate_random_md5(),
+             "latitude": self.latitude,
+             "longitude": self.longitude,
+             "author_email": self.author_email,
+             "date": self.date,
+             "time": self.time,
+             "date_option": str(self.date_option),
+             "time_option": str(self.time_option),
+             "uploaded_filename": self.uploaded_filename
+         }
+
+     @classmethod
+     def from_dict(cls, data):
+         return cls(data["image"], data["latitude"], data["longitude"], data["author_email"], data["date"], data["time"], data["date_option"], data["time_option"], data["uploaded_filename"])
+
+     @classmethod
+     def from_input(cls, input):
+         return cls(input.image, input.latitude, input.longitude, input.author_email, input.date, input.time, input.date_option, input.time_option, input.uploaded_filename)
+
+     @staticmethod
+     def from_input(input):
+         return InputObservation(input.image, input.latitude, input.longitude, input.author_email, input.date, input.time, input.date_option, input.time_option, input.uploaded_filename)
+
+     @staticmethod
+     def from_dict(data):
+         return InputObservation(data["image"], data["latitude"], data["longitude"], data["author_email"], data["date"], data["time"], data["date_option"], data["time_option"], data["uploaded_filename"])
+
+
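A quick sketch of the serialization path (field names from to_dict above; the sample values are made up). With no uploaded file, image_filename is None and image_md5 falls back to generate_random_md5():

from input.input_observation import InputObservation

obs = InputObservation(latitude=46.5, longitude=6.6,
                       author_email="observer@example.com")  # hypothetical values
d = obs.to_dict()
print(d["image_filename"])  # None (no uploaded file)
print(len(d["image_md5"]))  # 32 hex chars from the random fallback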
src/input/input_validator.py ADDED
@@ -0,0 +1,68 @@
+ import random
+ import string
+ import hashlib
+ import re
+ import streamlit as st
+
+ from PIL import Image
+ from PIL import ExifTags
+
+ def generate_random_md5():
+     # Generate a random string
+     random_string = ''.join(random.choices(string.ascii_letters + string.digits, k=16))
+     # Encode the string and compute its MD5 hash
+     md5_hash = hashlib.md5(random_string.encode()).hexdigest()
+     return md5_hash
+
+ def is_valid_number(number:str) -> bool:
+     """
+     Check if the given string is a valid number (int or float, sign ok)
+
+     Args:
+         number (str): The string to be checked.
+
+     Returns:
+         bool: True if the string is a valid number, False otherwise.
+     """
+     pattern = r'^[-+]?[0-9]*\.?[0-9]+$'
+     return re.match(pattern, number) is not None
+
+ # Function to validate email address
+ def is_valid_email(email:str) -> bool:
+     """
+     Validates if the provided email address is in a correct format.
+
+     Args:
+         email (str): The email address to validate.
+
+     Returns:
+         bool: True if the email address is valid, False otherwise.
+     """
+     pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
+     return re.match(pattern, email) is not None
+
+ # Function to extract date and time from image metadata
+ def get_image_datetime(image_file):
+     """
+     Extracts the original date and time from the EXIF metadata of an uploaded image file.
+
+     Args:
+         image_file (UploadedFile): The uploaded image file from which to extract the date and time.
+
+     Returns:
+         str: The original date and time as a string if available, otherwise None.
+
+     Raises:
+         Warning: If the date and time could not be extracted from the image metadata.
+     """
+     try:
+         image = Image.open(image_file)
+         exif_data = image._getexif()
+         if exif_data is not None:
+             for tag, value in exif_data.items():
+                 if ExifTags.TAGS.get(tag) == 'DateTimeOriginal':
+                     return value
+     except Exception as e: # FIXME: what types of exception?
+         st.warning(f"Could not extract date from image metadata. (file: {image_file.name})")
+         # TODO: add to logger
+     return None
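
A few illustrative checks against the regexes above (the inputs are made up; the behaviour follows directly from the patterns):

from input.input_validator import is_valid_number, is_valid_email

assert is_valid_number("-46.2345")       # sign and decimal point accepted
assert not is_valid_number("46.2.3")     # only one decimal point allowed
assert is_valid_email("observer@example.com")
assert not is_valid_email("observer@")   # missing domain part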
src/main.py CHANGED
@@ -1,31 +1,25 @@
- #import datetime
- from PIL import Image
-
- import json
  import logging
  import os
- import tempfile

  import pandas as pd
  import streamlit as st
- from streamlit.delta_generator import DeltaGenerator # for type hinting
  import folium
  from streamlit_folium import st_folium
- from huggingface_hub import HfApi
+
  from transformers import pipeline
  from transformers import AutoModelForImageClassification

  from datasets import disable_caching
  disable_caching()

- import alps_map as sw_am
- import input_handling as sw_inp
- import obs_map as sw_map
- import st_logs as sw_logs
- import whale_gallery as sw_wg
- import whale_viewer as sw_wv
-
-
+ import whale_gallery as gallery
+ import whale_viewer as viewer
+ from input.input_handling import setup_input
+ from maps.alps_map import present_alps_map
+ from maps.obs_map import present_obs_map
+ from utils.st_logs import setup_logging, parse_log_buffer
+ from classifier.classifier_image import cetacean_classify
+ from classifier.classifier_hotdog import hotdog_classify


  # setup for the ML model on huggingface (our wrapper)
@@ -45,96 +39,40 @@ g_logger = logging.getLogger(__name__)
  g_logger.setLevel(LOG_LEVEL)

  st.set_page_config(layout="wide")
- #sw_logs.setup_logging(level=LOG_LEVEL, buffer_len=40)
-
-

  # initialise various session state variables
  if "handler" not in st.session_state:
-     st.session_state['handler'] = sw_logs.setup_logging()
+     st.session_state['handler'] = setup_logging()
+
+ if "observations" not in st.session_state:
+     st.session_state.observations = {}
+
+ if "images" not in st.session_state:
+     st.session_state.images = {}

- if "full_data" not in st.session_state:
-     st.session_state.full_data = {}
+ if "files" not in st.session_state:
+     st.session_state.files = {}
+
+ if "public_observation" not in st.session_state:
+     st.session_state.public_observation = {}

  if "classify_whale_done" not in st.session_state:
      st.session_state.classify_whale_done = False

  if "whale_prediction1" not in st.session_state:
      st.session_state.whale_prediction1 = None
-
- if "image" not in st.session_state:
-     st.session_state.image = None

  if "tab_log" not in st.session_state:
      st.session_state.tab_log = None


- def metadata2md() -> str:
-     """Get metadata from cache and return as markdown-formatted key-value list
-
-     Returns:
-         str: Markdown-formatted key-value list of metadata
-
-     """
-     markdown_str = "\n"
-     for key, value in st.session_state.public_observation.items():
-         markdown_str += f"- **{key}**: {value}\n"
-     return markdown_str
-
-
- def push_observations(tab_log:DeltaGenerator=None):
-     """
-     Push the observations to the Hugging Face dataset
-
-     Args:
-         tab_log (streamlit.container): The container to log messages to. If not provided,
-             log messages are in any case written to the global logger (TODO: test - didn't
-             push any data since generating the logger)
-
-     """
-     # we get the data from session state: 1 is the dict 2 is the image.
-     # first, lets do an info display (popup)
-     metadata_str = json.dumps(st.session_state.public_observation)
-
-     st.toast(f"Uploading observations: {metadata_str}", icon="🦭")
-     tab_log = st.session_state.tab_log
-     if tab_log is not None:
-         tab_log.info(f"Uploading observations: {metadata_str}")
-
-     # get huggingface api
-     import os
-     token = os.environ.get("HF_TOKEN", None)
-     api = HfApi(token=token)
-
-     f = tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False)
-     f.write(metadata_str)
-     f.close()
-     st.info(f"temp file: {f.name} with metadata written...")
-
-     path_in_repo= f"metadata/{st.session_state.public_observation['author_email']}/{st.session_state.public_observation['image_md5']}.json"
-     msg = f"fname: {f.name} | path: {path_in_repo}"
-     print(msg)
-     st.warning(msg)
-     rv = api.upload_file(
-         path_or_fileobj=f.name,
-         path_in_repo=path_in_repo,
-         repo_id="Saving-Willy/temp_dataset",
-         repo_type="dataset",
-     )
-     print(rv)
-     msg = f"data attempted tx to repo happy walrus: {rv}"
-     g_logger.info(msg)
-     st.info(msg)
-
-
-
  def main() -> None:
      """
      Main entry point to set up the streamlit UI and run the application.

      The organisation is as follows:

-     1. data input (a new observations) is handled in the sidebar
+     1. observation input (a new observations) is handled in the sidebar
      2. the rest of the interface is organised in tabs:

          - cetean classifier
@@ -156,25 +94,25 @@ def main() -> None:
      #g_logger.warning("warning message")

      # Streamlit app
-     #tab_gallery, tab_inference, tab_hotdogs, tab_map, tab_data, tab_log = st.tabs(["Cetecean classifier", "Hotdog classifier", "Map", "Data", "Log", "Beautiful cetaceans"])
-     tab_inference, tab_hotdogs, tab_map, tab_data, tab_log, tab_gallery = st.tabs(["Cetecean classifier", "Hotdog classifier", "Map", "Data", "Log", "Beautiful cetaceans"])
+     #tab_gallery, tab_inference, tab_hotdogs, tab_map, tab_data, tab_log = st.tabs(["Cetecean classifier", "Hotdog classifier", "Map", "observation", "Log", "Beautiful cetaceans"])
+     tab_inference, tab_hotdogs, tab_map, tab_data, tab_log, tab_gallery = st.tabs(["Cetecean classifier", "Hotdog classifier", "Map", "observation", "Log", "Beautiful cetaceans"])
      st.session_state.tab_log = tab_log


      # create a sidebar, and parse all the input (returned as `observations` object)
-     observations = sw_inp.setup_input(viewcontainer=st.sidebar)
+     observations = setup_input(viewcontainer=st.sidebar)


      if 0:## WIP
          # goal of this code is to allow the user to override the ML prediction, before transmitting an observations
-         predicted_class = st.sidebar.selectbox("Predicted Class", sw_wv.WHALE_CLASSES)
+         predicted_class = st.sidebar.selectbox("Predicted Class", viewer.WHALE_CLASSES)
          override_prediction = st.sidebar.checkbox("Override Prediction")

          if override_prediction:
-             overridden_class = st.sidebar.selectbox("Override Class", sw_wv.WHALE_CLASSES)
-             st.session_state.full_data['class_overriden'] = overridden_class
+             overridden_class = st.sidebar.selectbox("Override Class", viewer.WHALE_CLASSES)
+             st.session_state.observations['class_overriden'] = overridden_class
          else:
-             st.session_state.full_data['class_overriden'] = None
+             st.session_state.observations['class_overriden'] = None


      with tab_map:
@@ -188,19 +126,19 @@ def main() -> None:

          if show_db_points:
              # show a nicer map, observations marked, tileset selectable.
-             st_data = sw_map.present_obs_map(
+             st_observation = present_obs_map(
                  dataset_id=dataset_id, data_files=data_files,
                  dbg_show_extra=dbg_show_extra)

          else:
              # development map.
-             st_data = sw_am.present_alps_map()
+             st_observation = present_alps_map()


      with tab_log:
          handler = st.session_state['handler']
          if handler is not None:
-             records = sw_logs.parse_log_buffer(handler.buffer)
+             records = parse_log_buffer(handler.buffer)
              st.dataframe(records[::-1], use_container_width=True,)
              st.info(f"Length of records: {len(records)}")
          else:
@@ -230,19 +168,18 @@ def main() -> None:
          # specific to the gallery (otherwise we get side effects)
          tg_cont = st.container(key="swgallery")
          with tg_cont:
-             sw_wg.render_whale_gallery(n_cols=4)
+             gallery.render_whale_gallery(n_cols=4)


-     # Display submitted data
+     # Display submitted observation
      if st.sidebar.button("Validate"):
-         # create a dictionary with the submitted data
+         # create a dictionary with the submitted observation
          submitted_data = observations
-         st.session_state.full_data = observations
+         st.session_state.observations = observations

-         tab_log.info(f"{st.session_state.full_data}")
+         tab_log.info(f"{st.session_state.observations}")

-         df = pd.DataFrame(submitted_data)
-         print("Dataframe Shape: ", df.shape)
+         df = pd.DataFrame(submitted_data, index=[0])
          with tab_data:
              st.table(df)

@@ -254,7 +191,7 @@
      # - the model predicts the top 3 most likely species from the input image
      # - these species are shown
      # - the user can override the species prediction using the dropdown
-     # - an observations is uploaded if the user chooses.
+     # - an observation is uploaded if the user chooses.

      if tab_inference.button("Identify with cetacean classifier"):
          #pipe = pipeline("image-classification", model="Saving-Willy/cetacean-classifier", trust_remote_code=True)
@@ -262,58 +199,12 @@
                                                                revision=classifier_revision,
                                                                trust_remote_code=True)

-         if st.session_state.image is None:
+         if st.session_state.images is None:
              # TODO: cleaner design to disable the button until data input done?
              st.info("Please upload an image first.")
          else:
-             files = st.session_state.files
-             images = st.session_state.images
-             full_data = st.session_state.full_data
-             for file in files:
-                 image = images[file]
-                 data = full_data[file]
-                 # run classifier model on `image`, and persistently store the output
-                 out = cetacean_classifier(image) # get top 3 matches
-                 st.session_state.whale_prediction1 = out['predictions'][0]
-                 st.session_state.classify_whale_done = True
-                 msg = f"[D]2 classify_whale_done: {st.session_state.classify_whale_done}, whale_prediction1: {st.session_state.whale_prediction1}"
-                 # st.info(msg)
-                 g_logger.info(msg)
-
-                 # dropdown for selecting/overriding the species prediction
-                 #st.info(f"[D] classify_whale_done: {st.session_state.classify_whale_done}, whale_prediction1: {st.session_state.whale_prediction1}")
-                 if not st.session_state.classify_whale_done:
-                     selected_class = tab_inference.sidebar.selectbox("Species", sw_wv.WHALE_CLASSES,
-                                                                      index=None, placeholder="Species not yet identified...",
-                                                                      disabled=True)
-                 else:
-                     pred1 = st.session_state.whale_prediction1
-                     # get index of pred1 from WHALE_CLASSES, none if not present
-                     print(f"[D] pred1: {pred1}")
-                     ix = sw_wv.WHALE_CLASSES.index(pred1) if pred1 in sw_wv.WHALE_CLASSES else None
-                     selected_class = tab_inference.selectbox("Species", sw_wv.WHALE_CLASSES, index=ix)

-                 data['predicted_class'] = selected_class
-                 if selected_class != st.session_state.whale_prediction1:
-                     data['class_overriden'] = selected_class
-
-                 st.session_state.public_observation = data
-                 st.button("Upload observations to THE INTERNET!", on_click=push_observations)
-                 # TODO: the metadata only fills properly if `validate` was clicked.
-                 tab_inference.markdown(metadata2md())
-
-                 msg = f"[D] full data after inference: {data}"
-                 g_logger.debug(msg)
-                 print(msg)
-                 # TODO: add a link to more info on the model, next to the button.
-
-                 whale_classes = out['predictions'][:]
-                 # render images for the top 3 (that is what the model api returns)
-                 with tab_inference:
-                     st.markdown("## Species detected")
-                     for i in range(len(whale_classes)):
-                         sw_wv.display_whale(whale_classes, i)
-
+             cetacean_classify(cetacean_classifier, tab_inference)



@@ -329,29 +220,10 @@ def main() -> None:

          if st.session_state.image is None:
              st.info("Please upload an image first.")
-             st.info(str(observations.to_dict()))
+             #st.info(str(observations.to_dict()))

          else:
-             col1, col2 = tab_hotdogs.columns(2)
-             for file in st.session_state.files:
-                 image = st.session_state.images[file]
-                 data = st.session_state.full_data[file]
-                 # display the image (use cached version, no need to reread)
-                 col1.image(image, use_column_width=True)
-                 # and then run inference on the image
-                 hotdog_image = Image.fromarray(image)
-                 predictions = pipeline_hot_dog(hotdog_image)
-
-                 col2.header("Probabilities")
-                 first = True
-                 for p in predictions:
-                     col2.subheader(f"{ p['label'] }: { round(p['score'] * 100, 1)}%")
-                     if first:
-                         data['predicted_class'] = p['label']
-                         data['predicted_score'] = round(p['score'] * 100, 1)
-                         first = False
-
-                 tab_hotdogs.write(f"Session Data: {json.dumps(data)}")
+             hotdog_classify(pipeline_hot_dog, tab_hotdogs)

src/{alps_map.py → maps/alps_map.py} RENAMED
File without changes
src/{obs_map.py → maps/obs_map.py} RENAMED
@@ -7,8 +7,8 @@ import streamlit as st
  import folium
  from streamlit_folium import st_folium

- import whale_viewer as sw_wv
- from fix_tabrender import js_show_zeroheight_iframe
+ import whale_viewer as viewer
+ from utils.fix_tabrender import js_show_zeroheight_iframe

  m_logger = logging.getLogger(__name__)
  # we can set the log level locally for funcs in this module
@@ -60,7 +60,7 @@ _colors = [
      "#778899" # Light Slate Gray
  ]

- whale2color = {k: v for k, v in zip(sw_wv.WHALE_CLASSES, _colors)}
+ whale2color = {k: v for k, v in zip(viewer.WHALE_CLASSES, _colors)}

  def create_map(tile_name:str, location:Tuple[float], zoom_start: int = 7) -> folium.Map:
      """
src/{fix_tabrender.py → utils/fix_tabrender.py} RENAMED
File without changes
src/utils/grid_maker.py ADDED
@@ -0,0 +1,13 @@
+ import streamlit as st
+ import math
+
+ def gridder(files):
+     cols = st.columns(3)
+     with cols[0]:
+         batch_size = st.select_slider("Batch size:",range(10,110,10), value=10)
+     with cols[1]:
+         row_size = st.select_slider("Row size:", range(1,6), value = 5)
+     num_batches = math.ceil(len(files)/batch_size)
+     with cols[2]:
+         page = st.selectbox("Page", range(1,num_batches+1))
+     return batch_size, row_size, page
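
The classifier loops above currently iterate over all files and ignore the returned batch_size and page; a sketch of how those values could drive paging (this slicing is an assumption, not something this commit does):

from utils.grid_maker import gridder

batch_size, row_size, page = gridder(files)
start = (page - 1) * batch_size           # pages are 1-indexed in the selectbox
visible_files = files[start:start + batch_size]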
src/utils/metadata_handler.py ADDED
@@ -0,0 +1,16 @@
+ import streamlit as st
+
+ def metadata2md() -> str:
+     """Get metadata from cache and return as markdown-formatted key-value list
+
+     Returns:
+         str: Markdown-formatted key-value list of metadata
+
+     """
+     markdown_str = "\n"
+     keys_to_print = ["latitude","longitude","author_email","date","time"]
+     for key, value in st.session_state.public_observation.items():
+         if key in keys_to_print:
+             markdown_str += f"- **{key}**: {value}\n"
+     return markdown_str
+
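For a sense of the output, with a populated public_observation (sample values made up), only the whitelisted keys survive:

import streamlit as st
from utils.metadata_handler import metadata2md

st.session_state.public_observation = {
    "latitude": 46.5, "longitude": 6.6,
    "author_email": "observer@example.com",
    "predicted_class": "humpback_whale",  # dropped: not in keys_to_print
}
print(metadata2md())
# - **latitude**: 46.5
# - **longitude**: 6.6
# - **author_email**: observer@example.com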
src/{st_logs.py → utils/st_logs.py} RENAMED
File without changes
src/whale_viewer.py CHANGED
@@ -1,5 +1,5 @@
  from typing import List
-
+ import streamlit as st
  from PIL import Image
  import pandas as pd
  import os
@@ -134,7 +134,7 @@ def display_whale(whale_classes:List[str], i:int, viewcontainer=None):
      TODO: how to find the object type of viewcontainer.? they are just "deltagenerators" but
      we want the result of the generator.. In any case, it works ok with either call signature.
      """
-     import streamlit as st
+
      if viewcontainer is None:
          viewcontainer = st

@@ -148,11 +148,10 @@ def display_whale(whale_classes:List[str], i:int, viewcontainer=None):


      viewcontainer.markdown(
-         "### :whale: #" + str(i + 1) + ": " + format_whale_name(whale_classes[i])
+         ":whale: #" + str(i + 1) + ": " + format_whale_name(whale_classes[i])
      )
      current_dir = os.getcwd()
      image_path = os.path.join(current_dir, "src/images/references/")
      image = Image.open(image_path + df_whale_img_ref.loc[whale_classes[i], "WHALE_IMAGES"])

-     viewcontainer.image(image, caption=df_whale_img_ref.loc[whale_classes[i], "WHALE_REFERENCES"])
-     # link st.markdown(f"[{df.loc[whale_classes[i], 'WHALE_REFERENCES']}]({df.loc[whale_classes[i], 'WHALE_REFERENCES']})")
+     viewcontainer.image(image, caption=df_whale_img_ref.loc[whale_classes[i], "WHALE_REFERENCES"], use_column_width=True)
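
Finally, a minimal sketch of calling the updated display_whale directly. The class names below are hypothetical and must exist in viewer.WHALE_CLASSES and the reference image tables:

import streamlit as st
import whale_viewer as viewer

whale_classes = ["humpback_whale", "blue_whale", "fin_whale"]  # hypothetical top-3
for i in range(len(whale_classes)):
    viewer.display_whale(whale_classes, i, viewcontainer=st)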