Spaces:
Sleeping
Sleeping
feat: refactor and multi image classification
Browse files- src/classifier/classifier_hotdog.py +26 -0
- src/classifier/classifier_image.py +69 -0
- src/classifier_image.py +70 -0
- src/hf_push_observations.py +56 -0
- src/{input_handling.py β input/input_handling.py} +3 -174
- src/input/input_observation.py +110 -0
- src/input/input_validator.py +68 -0
- src/main.py +43 -171
- src/{alps_map.py β maps/alps_map.py} +0 -0
- src/{obs_map.py β maps/obs_map.py} +3 -3
- src/{fix_tabrender.py β utils/fix_tabrender.py} +0 -0
- src/utils/grid_maker.py +13 -0
- src/utils/metadata_handler.py +16 -0
- src/{st_logs.py β utils/st_logs.py} +0 -0
- src/whale_viewer.py +4 -5
src/classifier/classifier_hotdog.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import json
|
3 |
+
from PIL import Image
|
4 |
+
|
5 |
+
|
6 |
+
def hotdog_classify(pipeline_hot_dog, tab_hotdogs):
|
7 |
+
col1, col2 = tab_hotdogs.columns(2)
|
8 |
+
for file in st.session_state.files:
|
9 |
+
image = st.session_state.images[file.name]
|
10 |
+
observation = st.session_state.observations[file.name].to_dict()
|
11 |
+
# display the image (use cached version, no need to reread)
|
12 |
+
col1.image(image, use_column_width=True)
|
13 |
+
# and then run inference on the image
|
14 |
+
hotdog_image = Image.fromarray(image)
|
15 |
+
predictions = pipeline_hot_dog(hotdog_image)
|
16 |
+
|
17 |
+
col2.header("Probabilities")
|
18 |
+
first = True
|
19 |
+
for p in predictions:
|
20 |
+
col2.subheader(f"{ p['label'] }: { round(p['score'] * 100, 1)}%")
|
21 |
+
if first:
|
22 |
+
observation['predicted_class'] = p['label']
|
23 |
+
observation['predicted_score'] = round(p['score'] * 100, 1)
|
24 |
+
first = False
|
25 |
+
|
26 |
+
tab_hotdogs.write(f"Session observation: {json.dumps(observation)}")
|
src/classifier/classifier_image.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import logging
|
3 |
+
|
4 |
+
# get a global var for logger accessor in this module
|
5 |
+
LOG_LEVEL = logging.DEBUG
|
6 |
+
g_logger = logging.getLogger(__name__)
|
7 |
+
g_logger.setLevel(LOG_LEVEL)
|
8 |
+
|
9 |
+
import whale_viewer as viewer
|
10 |
+
from hf_push_observations import push_observations
|
11 |
+
from utils.grid_maker import gridder
|
12 |
+
from utils.metadata_handler import metadata2md
|
13 |
+
|
14 |
+
def cetacean_classify(cetacean_classifier, tab_inference):
|
15 |
+
files = st.session_state.files
|
16 |
+
images = st.session_state.images
|
17 |
+
observations = st.session_state.observations
|
18 |
+
|
19 |
+
batch_size, row_size, page = gridder(files)
|
20 |
+
|
21 |
+
grid = st.columns(row_size)
|
22 |
+
col = 0
|
23 |
+
|
24 |
+
for file in files:
|
25 |
+
image = images[file.name]
|
26 |
+
|
27 |
+
with grid[col]:
|
28 |
+
st.image(image, use_column_width=True)
|
29 |
+
observation = observations[file.name].to_dict()
|
30 |
+
# run classifier model on `image`, and persistently store the output
|
31 |
+
out = cetacean_classifier(image) # get top 3 matches
|
32 |
+
st.session_state.whale_prediction1 = out['predictions'][0]
|
33 |
+
st.session_state.classify_whale_done = True
|
34 |
+
msg = f"[D]2 classify_whale_done: {st.session_state.classify_whale_done}, whale_prediction1: {st.session_state.whale_prediction1}"
|
35 |
+
g_logger.info(msg)
|
36 |
+
|
37 |
+
# dropdown for selecting/overriding the species prediction
|
38 |
+
if not st.session_state.classify_whale_done:
|
39 |
+
selected_class = st.sidebar.selectbox("Species", viewer.WHALE_CLASSES,
|
40 |
+
index=None, placeholder="Species not yet identified...",
|
41 |
+
disabled=True)
|
42 |
+
else:
|
43 |
+
pred1 = st.session_state.whale_prediction1
|
44 |
+
# get index of pred1 from WHALE_CLASSES, none if not present
|
45 |
+
print(f"[D] pred1: {pred1}")
|
46 |
+
ix = viewer.WHALE_CLASSES.index(pred1) if pred1 in viewer.WHALE_CLASSES else None
|
47 |
+
selected_class = st.selectbox(f"Species for {file.name}", viewer.WHALE_CLASSES, index=ix)
|
48 |
+
|
49 |
+
observation['predicted_class'] = selected_class
|
50 |
+
if selected_class != st.session_state.whale_prediction1:
|
51 |
+
observation['class_overriden'] = selected_class
|
52 |
+
|
53 |
+
st.session_state.public_observation = observation
|
54 |
+
st.button(f"Upload observation for {file.name} to THE INTERNET!", on_click=push_observations)
|
55 |
+
# TODO: the metadata only fills properly if `validate` was clicked.
|
56 |
+
st.markdown(metadata2md())
|
57 |
+
|
58 |
+
msg = f"[D] full observation after inference: {observation}"
|
59 |
+
g_logger.debug(msg)
|
60 |
+
print(msg)
|
61 |
+
# TODO: add a link to more info on the model, next to the button.
|
62 |
+
|
63 |
+
whale_classes = out['predictions'][:]
|
64 |
+
# render images for the top 3 (that is what the model api returns)
|
65 |
+
#with tab_inference:
|
66 |
+
st.markdown(f"Top 3 Predictions for {file.name}")
|
67 |
+
for i in range(len(whale_classes)):
|
68 |
+
viewer.display_whale(whale_classes, i)
|
69 |
+
col = (col + 1) % row_size
|
src/classifier_image.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
|
5 |
+
# get a global var for logger accessor in this module
|
6 |
+
LOG_LEVEL = logging.DEBUG
|
7 |
+
g_logger = logging.getLogger(__name__)
|
8 |
+
g_logger.setLevel(LOG_LEVEL)
|
9 |
+
|
10 |
+
from grid_maker import gridder
|
11 |
+
import hf_push_observations as sw_push_obs
|
12 |
+
import utils.metadata_handler as meta_handler
|
13 |
+
import whale_viewer as sw_wv
|
14 |
+
|
15 |
+
def cetacean_classify(cetacean_classifier, tab_inference):
|
16 |
+
files = st.session_state.files
|
17 |
+
images = st.session_state.images
|
18 |
+
observations = st.session_state.observations
|
19 |
+
|
20 |
+
batch_size, row_size, page = gridder(files)
|
21 |
+
|
22 |
+
grid = st.columns(row_size)
|
23 |
+
col = 0
|
24 |
+
|
25 |
+
for file in files:
|
26 |
+
image = images[file.name]
|
27 |
+
|
28 |
+
with grid[col]:
|
29 |
+
st.image(image, use_column_width=True)
|
30 |
+
observation = observations[file.name].to_dict()
|
31 |
+
# run classifier model on `image`, and persistently store the output
|
32 |
+
out = cetacean_classifier(image) # get top 3 matches
|
33 |
+
st.session_state.whale_prediction1 = out['predictions'][0]
|
34 |
+
st.session_state.classify_whale_done = True
|
35 |
+
msg = f"[D]2 classify_whale_done: {st.session_state.classify_whale_done}, whale_prediction1: {st.session_state.whale_prediction1}"
|
36 |
+
g_logger.info(msg)
|
37 |
+
|
38 |
+
# dropdown for selecting/overriding the species prediction
|
39 |
+
if not st.session_state.classify_whale_done:
|
40 |
+
selected_class = st.sidebar.selectbox("Species", sw_wv.WHALE_CLASSES,
|
41 |
+
index=None, placeholder="Species not yet identified...",
|
42 |
+
disabled=True)
|
43 |
+
else:
|
44 |
+
pred1 = st.session_state.whale_prediction1
|
45 |
+
# get index of pred1 from WHALE_CLASSES, none if not present
|
46 |
+
print(f"[D] pred1: {pred1}")
|
47 |
+
ix = sw_wv.WHALE_CLASSES.index(pred1) if pred1 in sw_wv.WHALE_CLASSES else None
|
48 |
+
selected_class = tab_inference.selectbox("Species", sw_wv.WHALE_CLASSES, index=ix)
|
49 |
+
|
50 |
+
observation['predicted_class'] = selected_class
|
51 |
+
if selected_class != st.session_state.whale_prediction1:
|
52 |
+
observation['class_overriden'] = selected_class
|
53 |
+
|
54 |
+
st.session_state.public_observation = observation
|
55 |
+
st.button(f"Upload observation for {file.name} to THE INTERNET!", on_click=sw_push_obs.push_observations)
|
56 |
+
# TODO: the metadata only fills properly if `validate` was clicked.
|
57 |
+
st.markdown(meta_handler.metadata2md())
|
58 |
+
|
59 |
+
msg = f"[D] full observation after inference: {observation}"
|
60 |
+
g_logger.debug(msg)
|
61 |
+
print(msg)
|
62 |
+
# TODO: add a link to more info on the model, next to the button.
|
63 |
+
|
64 |
+
whale_classes = out['predictions'][:]
|
65 |
+
# render images for the top 3 (that is what the model api returns)
|
66 |
+
#with tab_inference:
|
67 |
+
st.title(f"Species detected for {file.name}")
|
68 |
+
for i in range(len(whale_classes)):
|
69 |
+
sw_wv.display_whale(whale_classes, i)
|
70 |
+
col = (col + 1) % row_size
|
src/hf_push_observations.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from streamlit.delta_generator import DeltaGenerator
|
2 |
+
import streamlit as st
|
3 |
+
from huggingface_hub import HfApi
|
4 |
+
import json
|
5 |
+
import tempfile
|
6 |
+
import logging
|
7 |
+
|
8 |
+
# get a global var for logger accessor in this module
|
9 |
+
LOG_LEVEL = logging.DEBUG
|
10 |
+
g_logger = logging.getLogger(__name__)
|
11 |
+
g_logger.setLevel(LOG_LEVEL)
|
12 |
+
|
13 |
+
def push_observations(tab_log:DeltaGenerator=None):
|
14 |
+
"""
|
15 |
+
Push the observations to the Hugging Face dataset
|
16 |
+
|
17 |
+
Args:
|
18 |
+
tab_log (streamlit.container): The container to log messages to. If not provided,
|
19 |
+
log messages are in any case written to the global logger (TODO: test - didn't
|
20 |
+
push any observation since generating the logger)
|
21 |
+
|
22 |
+
"""
|
23 |
+
# we get the observation from session state: 1 is the dict 2 is the image.
|
24 |
+
# first, lets do an info display (popup)
|
25 |
+
metadata_str = json.dumps(st.session_state.public_observation)
|
26 |
+
|
27 |
+
st.toast(f"Uploading observations: {metadata_str}", icon="π¦")
|
28 |
+
tab_log = st.session_state.tab_log
|
29 |
+
if tab_log is not None:
|
30 |
+
tab_log.info(f"Uploading observations: {metadata_str}")
|
31 |
+
|
32 |
+
# get huggingface api
|
33 |
+
import os
|
34 |
+
token = os.environ.get("HF_TOKEN", None)
|
35 |
+
api = HfApi(token=token)
|
36 |
+
|
37 |
+
f = tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False)
|
38 |
+
f.write(metadata_str)
|
39 |
+
f.close()
|
40 |
+
st.info(f"temp file: {f.name} with metadata written...")
|
41 |
+
|
42 |
+
path_in_repo= f"metadata/{st.session_state.public_observation['author_email']}/{st.session_state.public_observation['image_md5']}.json"
|
43 |
+
msg = f"fname: {f.name} | path: {path_in_repo}"
|
44 |
+
print(msg)
|
45 |
+
st.warning(msg)
|
46 |
+
# rv = api.upload_file(
|
47 |
+
# path_or_fileobj=f.name,
|
48 |
+
# path_in_repo=path_in_repo,
|
49 |
+
# repo_id="Saving-Willy/temp_dataset",
|
50 |
+
# repo_type="dataset",
|
51 |
+
# )
|
52 |
+
# print(rv)
|
53 |
+
# msg = f"observation attempted tx to repo happy walrus: {rv}"
|
54 |
+
g_logger.info(msg)
|
55 |
+
st.info(msg)
|
56 |
+
|
src/{input_handling.py β input/input_handling.py}
RENAMED
@@ -1,19 +1,14 @@
|
|
1 |
-
from PIL import Image
|
2 |
-
from PIL import ExifTags
|
3 |
-
import re
|
4 |
import datetime
|
5 |
-
import hashlib
|
6 |
import logging
|
7 |
|
8 |
import streamlit as st
|
9 |
-
from streamlit.runtime.uploaded_file_manager import UploadedFile # for type hinting
|
10 |
from streamlit.delta_generator import DeltaGenerator
|
11 |
|
12 |
import cv2
|
13 |
import numpy as np
|
14 |
|
15 |
-
import
|
16 |
-
import
|
17 |
|
18 |
m_logger = logging.getLogger(__name__)
|
19 |
m_logger.setLevel(logging.INFO)
|
@@ -25,172 +20,6 @@ both the UI elements (setup_input_UI) and the validation functions.
|
|
25 |
'''
|
26 |
allowed_image_types = ['jpg', 'jpeg', 'png', 'webp']
|
27 |
|
28 |
-
def generate_random_md5():
|
29 |
-
# Generate a random string
|
30 |
-
random_string = ''.join(random.choices(string.ascii_letters + string.digits, k=16))
|
31 |
-
# Encode the string and compute its MD5 hash
|
32 |
-
md5_hash = hashlib.md5(random_string.encode()).hexdigest()
|
33 |
-
return md5_hash
|
34 |
-
|
35 |
-
# autogenerated class to hold the input data
|
36 |
-
class InputObservation:
|
37 |
-
"""
|
38 |
-
A class to hold an input observation and associated metadata
|
39 |
-
|
40 |
-
Attributes:
|
41 |
-
image (Any):
|
42 |
-
The image associated with the observation.
|
43 |
-
latitude (float):
|
44 |
-
The latitude where the observation was made.
|
45 |
-
longitude (float):
|
46 |
-
The longitude where the observation was made.
|
47 |
-
author_email (str):
|
48 |
-
The email of the author of the observation.
|
49 |
-
date (str):
|
50 |
-
The date when the observation was made.
|
51 |
-
time (str):
|
52 |
-
The time when the observation was made.
|
53 |
-
date_option (str):
|
54 |
-
Additional date option for the observation.
|
55 |
-
time_option (str):
|
56 |
-
Additional time option for the observation.
|
57 |
-
uploaded_filename (Any):
|
58 |
-
The uploaded filename associated with the observation.
|
59 |
-
|
60 |
-
Methods:
|
61 |
-
__str__():
|
62 |
-
Returns a string representation of the observation.
|
63 |
-
__repr__():
|
64 |
-
Returns a string representation of the observation.
|
65 |
-
__eq__(other):
|
66 |
-
Checks if two observations are equal.
|
67 |
-
__ne__(other):
|
68 |
-
Checks if two observations are not equal.
|
69 |
-
__hash__():
|
70 |
-
Returns the hash of the observation.
|
71 |
-
to_dict():
|
72 |
-
Converts the observation to a dictionary.
|
73 |
-
from_dict(data):
|
74 |
-
Creates an observation from a dictionary.
|
75 |
-
from_input(input):
|
76 |
-
Creates an observation from another input observation.
|
77 |
-
"""
|
78 |
-
def __init__(self, image=None, latitude=None, longitude=None, author_email=None, date=None, time=None, date_option=None, time_option=None, uploaded_filename=None):
|
79 |
-
self.image = image
|
80 |
-
self.latitude = latitude
|
81 |
-
self.longitude = longitude
|
82 |
-
self.author_email = author_email
|
83 |
-
self.date = date
|
84 |
-
self.time = time
|
85 |
-
self.date_option = date_option
|
86 |
-
self.time_option = time_option
|
87 |
-
self.uploaded_filename = uploaded_filename
|
88 |
-
|
89 |
-
def __str__(self):
|
90 |
-
return f"Observation: {self.image}, {self.latitude}, {self.longitude}, {self.author_email}, {self.date}, {self.time}, {self.date_option}, {self.time_option}, {self.uploaded_filename}"
|
91 |
-
|
92 |
-
def __repr__(self):
|
93 |
-
return f"Observation: {self.image}, {self.latitude}, {self.longitude}, {self.author_email}, {self.date}, {self.time}, {self.date_option}, {self.time_option}, {self.uploaded_filename}"
|
94 |
-
|
95 |
-
def __eq__(self, other):
|
96 |
-
return (self.image == other.image and self.latitude == other.latitude and self.longitude == other.longitude and
|
97 |
-
self.author_email == other.author_email and self.date == other.date and self.time == other.time and
|
98 |
-
self.date_option == other.date_option and self.time_option == other.time_option and self.uploaded_filename == other.uploaded_filename)
|
99 |
-
|
100 |
-
def __ne__(self, other):
|
101 |
-
return not self.__eq__(other)
|
102 |
-
|
103 |
-
def __hash__(self):
|
104 |
-
return hash((self.image, self.latitude, self.longitude, self.author_email, self.date, self.time, self.date_option, self.time_option, self.uploaded_filename))
|
105 |
-
|
106 |
-
def to_dict(self):
|
107 |
-
return {
|
108 |
-
#"image": self.image,
|
109 |
-
"image_filename": self.uploaded_filename.name if self.uploaded_filename else None,
|
110 |
-
"image_md5": hashlib.md5(self.uploaded_filename.read()).hexdigest() if self.uploaded_filename else generate_random_md5(),
|
111 |
-
"latitude": self.latitude,
|
112 |
-
"longitude": self.longitude,
|
113 |
-
"author_email": self.author_email,
|
114 |
-
"date": self.date,
|
115 |
-
"time": self.time,
|
116 |
-
"date_option": str(self.date_option),
|
117 |
-
"time_option": str(self.time_option),
|
118 |
-
"uploaded_filename": self.uploaded_filename
|
119 |
-
}
|
120 |
-
|
121 |
-
@classmethod
|
122 |
-
def from_dict(cls, data):
|
123 |
-
return cls(data["image"], data["latitude"], data["longitude"], data["author_email"], data["date"], data["time"], data["date_option"], data["time_option"], data["uploaded_filename"])
|
124 |
-
|
125 |
-
@classmethod
|
126 |
-
def from_input(cls, input):
|
127 |
-
return cls(input.image, input.latitude, input.longitude, input.author_email, input.date, input.time, input.date_option, input.time_option, input.uploaded_filename)
|
128 |
-
|
129 |
-
@staticmethod
|
130 |
-
def from_input(input):
|
131 |
-
return InputObservation(input.image, input.latitude, input.longitude, input.author_email, input.date, input.time, input.date_option, input.time_option, input.uploaded_filename)
|
132 |
-
|
133 |
-
@staticmethod
|
134 |
-
def from_dict(data):
|
135 |
-
return InputObservation(data["image"], data["latitude"], data["longitude"], data["author_email"], data["date"], data["time"], data["date_option"], data["time_option"], data["uploaded_filename"])
|
136 |
-
|
137 |
-
|
138 |
-
def is_valid_number(number:str) -> bool:
|
139 |
-
"""
|
140 |
-
Check if the given string is a valid number (int or float, sign ok)
|
141 |
-
|
142 |
-
Args:
|
143 |
-
number (str): The string to be checked.
|
144 |
-
|
145 |
-
Returns:
|
146 |
-
bool: True if the string is a valid number, False otherwise.
|
147 |
-
"""
|
148 |
-
pattern = r'^[-+]?[0-9]*\.?[0-9]+$'
|
149 |
-
return re.match(pattern, number) is not None
|
150 |
-
|
151 |
-
|
152 |
-
# Function to validate email address
|
153 |
-
def is_valid_email(email:str) -> bool:
|
154 |
-
"""
|
155 |
-
Validates if the provided email address is in a correct format.
|
156 |
-
|
157 |
-
Args:
|
158 |
-
email (str): The email address to validate.
|
159 |
-
|
160 |
-
Returns:
|
161 |
-
bool: True if the email address is valid, False otherwise.
|
162 |
-
"""
|
163 |
-
pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
|
164 |
-
return re.match(pattern, email) is not None
|
165 |
-
|
166 |
-
# Function to extract date and time from image metadata
|
167 |
-
# def get_image_datetime(image_file: UploadedFile) -> str | None:
|
168 |
-
def get_image_datetime(image_file):
|
169 |
-
"""
|
170 |
-
Extracts the original date and time from the EXIF metadata of an uploaded image file.
|
171 |
-
|
172 |
-
Args:
|
173 |
-
image_file (UploadedFile): The uploaded image file from which to extract the date and time.
|
174 |
-
|
175 |
-
Returns:
|
176 |
-
str: The original date and time as a string if available, otherwise None.
|
177 |
-
|
178 |
-
Raises:
|
179 |
-
Warning: If the date and time could not be extracted from the image metadata.
|
180 |
-
"""
|
181 |
-
try:
|
182 |
-
image = Image.open(image_file)
|
183 |
-
exif_data = image._getexif()
|
184 |
-
if exif_data is not None:
|
185 |
-
for tag, value in exif_data.items():
|
186 |
-
if ExifTags.TAGS.get(tag) == 'DateTimeOriginal':
|
187 |
-
return value
|
188 |
-
except Exception as e: # FIXME: what types of exception?
|
189 |
-
st.warning(f"Could not extract date from image metadata. (file: {image_file.name})")
|
190 |
-
# TODO: add to logger
|
191 |
-
return None
|
192 |
-
|
193 |
-
|
194 |
# an arbitrary set of defaults so testing is less painful...
|
195 |
# ideally we add in some randomization to the defaults
|
196 |
spoof_metadata = {
|
@@ -282,7 +111,7 @@ def setup_input(
|
|
282 |
observations[file.name] = observation
|
283 |
images[file.name] = image
|
284 |
|
285 |
-
st.session_state.
|
286 |
st.session_state.files = uploaded_files
|
287 |
|
288 |
return observations
|
|
|
|
|
|
|
|
|
1 |
import datetime
|
|
|
2 |
import logging
|
3 |
|
4 |
import streamlit as st
|
|
|
5 |
from streamlit.delta_generator import DeltaGenerator
|
6 |
|
7 |
import cv2
|
8 |
import numpy as np
|
9 |
|
10 |
+
from input.input_observation import InputObservation
|
11 |
+
from input.input_validator import get_image_datetime, is_valid_email, is_valid_number
|
12 |
|
13 |
m_logger = logging.getLogger(__name__)
|
14 |
m_logger.setLevel(logging.INFO)
|
|
|
20 |
'''
|
21 |
allowed_image_types = ['jpg', 'jpeg', 'png', 'webp']
|
22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
# an arbitrary set of defaults so testing is less painful...
|
24 |
# ideally we add in some randomization to the defaults
|
25 |
spoof_metadata = {
|
|
|
111 |
observations[file.name] = observation
|
112 |
images[file.name] = image
|
113 |
|
114 |
+
st.session_state.images = images
|
115 |
st.session_state.files = uploaded_files
|
116 |
|
117 |
return observations
|
src/input/input_observation.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import hashlib
|
2 |
+
from input.input_validator import generate_random_md5
|
3 |
+
|
4 |
+
# autogenerated class to hold the input data
|
5 |
+
class InputObservation:
|
6 |
+
"""
|
7 |
+
A class to hold an input observation and associated metadata
|
8 |
+
|
9 |
+
Attributes:
|
10 |
+
image (Any):
|
11 |
+
The image associated with the observation.
|
12 |
+
latitude (float):
|
13 |
+
The latitude where the observation was made.
|
14 |
+
longitude (float):
|
15 |
+
The longitude where the observation was made.
|
16 |
+
author_email (str):
|
17 |
+
The email of the author of the observation.
|
18 |
+
date (str):
|
19 |
+
The date when the observation was made.
|
20 |
+
time (str):
|
21 |
+
The time when the observation was made.
|
22 |
+
date_option (str):
|
23 |
+
Additional date option for the observation.
|
24 |
+
time_option (str):
|
25 |
+
Additional time option for the observation.
|
26 |
+
uploaded_filename (Any):
|
27 |
+
The uploaded filename associated with the observation.
|
28 |
+
|
29 |
+
Methods:
|
30 |
+
__str__():
|
31 |
+
Returns a string representation of the observation.
|
32 |
+
__repr__():
|
33 |
+
Returns a string representation of the observation.
|
34 |
+
__eq__(other):
|
35 |
+
Checks if two observations are equal.
|
36 |
+
__ne__(other):
|
37 |
+
Checks if two observations are not equal.
|
38 |
+
__hash__():
|
39 |
+
Returns the hash of the observation.
|
40 |
+
to_dict():
|
41 |
+
Converts the observation to a dictionary.
|
42 |
+
from_dict(data):
|
43 |
+
Creates an observation from a dictionary.
|
44 |
+
from_input(input):
|
45 |
+
Creates an observation from another input observation.
|
46 |
+
"""
|
47 |
+
def __init__(self, image=None, latitude=None, longitude=None,
|
48 |
+
author_email=None, date=None, time=None, date_option=None, time_option=None,
|
49 |
+
uploaded_filename=None):
|
50 |
+
self.image = image
|
51 |
+
self.latitude = latitude
|
52 |
+
self.longitude = longitude
|
53 |
+
self.author_email = author_email
|
54 |
+
self.date = date
|
55 |
+
self.time = time
|
56 |
+
self.date_option = date_option
|
57 |
+
self.time_option = time_option
|
58 |
+
self.uploaded_filename = uploaded_filename
|
59 |
+
|
60 |
+
def __str__(self):
|
61 |
+
return f"Observation: {self.image}, {self.latitude}, {self.longitude}, {self.author_email}, {self.date}, {self.time}, {self.date_option}, {self.time_option}, {self.uploaded_filename}"
|
62 |
+
|
63 |
+
def __repr__(self):
|
64 |
+
return f"Observation: {self.image}, {self.latitude}, {self.longitude}, {self.author_email}, {self.date}, {self.time}, {self.date_option}, {self.time_option}, {self.uploaded_filename}"
|
65 |
+
|
66 |
+
def __eq__(self, other):
|
67 |
+
return (self.image == other.image and self.latitude == other.latitude and self.longitude == other.longitude and
|
68 |
+
self.author_email == other.author_email and self.date == other.date and self.time == other.time and
|
69 |
+
self.date_option == other.date_option and self.time_option == other.time_option and self.uploaded_filename == other.uploaded_filename)
|
70 |
+
|
71 |
+
def __ne__(self, other):
|
72 |
+
return not self.__eq__(other)
|
73 |
+
|
74 |
+
def __hash__(self):
|
75 |
+
return hash((self.image, self.latitude, self.longitude, self.author_email, self.date, self.time, self.date_option, self.time_option, self.uploaded_filename))
|
76 |
+
|
77 |
+
def to_dict(self):
|
78 |
+
return {
|
79 |
+
#"image": self.image,
|
80 |
+
"image_filename": self.uploaded_filename.name if self.uploaded_filename else None,
|
81 |
+
"image_md5": hashlib.md5(self.uploaded_filename.read()).hexdigest() if self.uploaded_filename else generate_random_md5(),
|
82 |
+
"latitude": self.latitude,
|
83 |
+
"longitude": self.longitude,
|
84 |
+
"author_email": self.author_email,
|
85 |
+
"date": self.date,
|
86 |
+
"time": self.time,
|
87 |
+
"date_option": str(self.date_option),
|
88 |
+
"time_option": str(self.time_option),
|
89 |
+
"uploaded_filename": self.uploaded_filename
|
90 |
+
}
|
91 |
+
|
92 |
+
@classmethod
|
93 |
+
def from_dict(cls, data):
|
94 |
+
return cls(data["image"], data["latitude"], data["longitude"], data["author_email"], data["date"], data["time"], data["date_option"], data["time_option"], data["uploaded_filename"])
|
95 |
+
|
96 |
+
@classmethod
|
97 |
+
def from_input(cls, input):
|
98 |
+
return cls(input.image, input.latitude, input.longitude, input.author_email, input.date, input.time, input.date_option, input.time_option, input.uploaded_filename)
|
99 |
+
|
100 |
+
@staticmethod
|
101 |
+
def from_input(input):
|
102 |
+
return InputObservation(input.image, input.latitude, input.longitude, input.author_email, input.date, input.time, input.date_option, input.time_option, input.uploaded_filename)
|
103 |
+
|
104 |
+
@staticmethod
|
105 |
+
def from_dict(data):
|
106 |
+
return InputObservation(data["image"], data["latitude"], data["longitude"], data["author_email"], data["date"], data["time"], data["date_option"], data["time_option"], data["uploaded_filename"])
|
107 |
+
|
108 |
+
|
109 |
+
|
110 |
+
|
src/input/input_validator.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import string
|
3 |
+
import hashlib
|
4 |
+
import re
|
5 |
+
import streamlit as st
|
6 |
+
|
7 |
+
from PIL import Image
|
8 |
+
from PIL import ExifTags
|
9 |
+
|
10 |
+
def generate_random_md5():
|
11 |
+
# Generate a random string
|
12 |
+
random_string = ''.join(random.choices(string.ascii_letters + string.digits, k=16))
|
13 |
+
# Encode the string and compute its MD5 hash
|
14 |
+
md5_hash = hashlib.md5(random_string.encode()).hexdigest()
|
15 |
+
return md5_hash
|
16 |
+
|
17 |
+
def is_valid_number(number:str) -> bool:
|
18 |
+
"""
|
19 |
+
Check if the given string is a valid number (int or float, sign ok)
|
20 |
+
|
21 |
+
Args:
|
22 |
+
number (str): The string to be checked.
|
23 |
+
|
24 |
+
Returns:
|
25 |
+
bool: True if the string is a valid number, False otherwise.
|
26 |
+
"""
|
27 |
+
pattern = r'^[-+]?[0-9]*\.?[0-9]+$'
|
28 |
+
return re.match(pattern, number) is not None
|
29 |
+
|
30 |
+
# Function to validate email address
|
31 |
+
def is_valid_email(email:str) -> bool:
|
32 |
+
"""
|
33 |
+
Validates if the provided email address is in a correct format.
|
34 |
+
|
35 |
+
Args:
|
36 |
+
email (str): The email address to validate.
|
37 |
+
|
38 |
+
Returns:
|
39 |
+
bool: True if the email address is valid, False otherwise.
|
40 |
+
"""
|
41 |
+
pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
|
42 |
+
return re.match(pattern, email) is not None
|
43 |
+
|
44 |
+
# Function to extract date and time from image metadata
|
45 |
+
def get_image_datetime(image_file):
|
46 |
+
"""
|
47 |
+
Extracts the original date and time from the EXIF metadata of an uploaded image file.
|
48 |
+
|
49 |
+
Args:
|
50 |
+
image_file (UploadedFile): The uploaded image file from which to extract the date and time.
|
51 |
+
|
52 |
+
Returns:
|
53 |
+
str: The original date and time as a string if available, otherwise None.
|
54 |
+
|
55 |
+
Raises:
|
56 |
+
Warning: If the date and time could not be extracted from the image metadata.
|
57 |
+
"""
|
58 |
+
try:
|
59 |
+
image = Image.open(image_file)
|
60 |
+
exif_data = image._getexif()
|
61 |
+
if exif_data is not None:
|
62 |
+
for tag, value in exif_data.items():
|
63 |
+
if ExifTags.TAGS.get(tag) == 'DateTimeOriginal':
|
64 |
+
return value
|
65 |
+
except Exception as e: # FIXME: what types of exception?
|
66 |
+
st.warning(f"Could not extract date from image metadata. (file: {image_file.name})")
|
67 |
+
# TODO: add to logger
|
68 |
+
return None
|
src/main.py
CHANGED
@@ -1,31 +1,25 @@
|
|
1 |
-
#import datetime
|
2 |
-
from PIL import Image
|
3 |
-
|
4 |
-
import json
|
5 |
import logging
|
6 |
import os
|
7 |
-
import tempfile
|
8 |
|
9 |
import pandas as pd
|
10 |
import streamlit as st
|
11 |
-
from streamlit.delta_generator import DeltaGenerator # for type hinting
|
12 |
import folium
|
13 |
from streamlit_folium import st_folium
|
14 |
-
|
15 |
from transformers import pipeline
|
16 |
from transformers import AutoModelForImageClassification
|
17 |
|
18 |
from datasets import disable_caching
|
19 |
disable_caching()
|
20 |
|
21 |
-
import
|
22 |
-
import
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
import
|
27 |
-
|
28 |
-
|
29 |
|
30 |
|
31 |
# setup for the ML model on huggingface (our wrapper)
|
@@ -45,96 +39,40 @@ g_logger = logging.getLogger(__name__)
|
|
45 |
g_logger.setLevel(LOG_LEVEL)
|
46 |
|
47 |
st.set_page_config(layout="wide")
|
48 |
-
#sw_logs.setup_logging(level=LOG_LEVEL, buffer_len=40)
|
49 |
-
|
50 |
-
|
51 |
|
52 |
# initialise various session state variables
|
53 |
if "handler" not in st.session_state:
|
54 |
-
st.session_state['handler'] =
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
|
56 |
-
if "
|
57 |
-
st.session_state.
|
|
|
|
|
|
|
58 |
|
59 |
if "classify_whale_done" not in st.session_state:
|
60 |
st.session_state.classify_whale_done = False
|
61 |
|
62 |
if "whale_prediction1" not in st.session_state:
|
63 |
st.session_state.whale_prediction1 = None
|
64 |
-
|
65 |
-
if "image" not in st.session_state:
|
66 |
-
st.session_state.image = None
|
67 |
|
68 |
if "tab_log" not in st.session_state:
|
69 |
st.session_state.tab_log = None
|
70 |
|
71 |
|
72 |
-
def metadata2md() -> str:
|
73 |
-
"""Get metadata from cache and return as markdown-formatted key-value list
|
74 |
-
|
75 |
-
Returns:
|
76 |
-
str: Markdown-formatted key-value list of metadata
|
77 |
-
|
78 |
-
"""
|
79 |
-
markdown_str = "\n"
|
80 |
-
for key, value in st.session_state.public_observation.items():
|
81 |
-
markdown_str += f"- **{key}**: {value}\n"
|
82 |
-
return markdown_str
|
83 |
-
|
84 |
-
|
85 |
-
def push_observations(tab_log:DeltaGenerator=None):
|
86 |
-
"""
|
87 |
-
Push the observations to the Hugging Face dataset
|
88 |
-
|
89 |
-
Args:
|
90 |
-
tab_log (streamlit.container): The container to log messages to. If not provided,
|
91 |
-
log messages are in any case written to the global logger (TODO: test - didn't
|
92 |
-
push any data since generating the logger)
|
93 |
-
|
94 |
-
"""
|
95 |
-
# we get the data from session state: 1 is the dict 2 is the image.
|
96 |
-
# first, lets do an info display (popup)
|
97 |
-
metadata_str = json.dumps(st.session_state.public_observation)
|
98 |
-
|
99 |
-
st.toast(f"Uploading observations: {metadata_str}", icon="π¦")
|
100 |
-
tab_log = st.session_state.tab_log
|
101 |
-
if tab_log is not None:
|
102 |
-
tab_log.info(f"Uploading observations: {metadata_str}")
|
103 |
-
|
104 |
-
# get huggingface api
|
105 |
-
import os
|
106 |
-
token = os.environ.get("HF_TOKEN", None)
|
107 |
-
api = HfApi(token=token)
|
108 |
-
|
109 |
-
f = tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False)
|
110 |
-
f.write(metadata_str)
|
111 |
-
f.close()
|
112 |
-
st.info(f"temp file: {f.name} with metadata written...")
|
113 |
-
|
114 |
-
path_in_repo= f"metadata/{st.session_state.public_observation['author_email']}/{st.session_state.public_observation['image_md5']}.json"
|
115 |
-
msg = f"fname: {f.name} | path: {path_in_repo}"
|
116 |
-
print(msg)
|
117 |
-
st.warning(msg)
|
118 |
-
rv = api.upload_file(
|
119 |
-
path_or_fileobj=f.name,
|
120 |
-
path_in_repo=path_in_repo,
|
121 |
-
repo_id="Saving-Willy/temp_dataset",
|
122 |
-
repo_type="dataset",
|
123 |
-
)
|
124 |
-
print(rv)
|
125 |
-
msg = f"data attempted tx to repo happy walrus: {rv}"
|
126 |
-
g_logger.info(msg)
|
127 |
-
st.info(msg)
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
def main() -> None:
|
132 |
"""
|
133 |
Main entry point to set up the streamlit UI and run the application.
|
134 |
|
135 |
The organisation is as follows:
|
136 |
|
137 |
-
1.
|
138 |
2. the rest of the interface is organised in tabs:
|
139 |
|
140 |
- cetean classifier
|
@@ -156,25 +94,25 @@ def main() -> None:
|
|
156 |
#g_logger.warning("warning message")
|
157 |
|
158 |
# Streamlit app
|
159 |
-
#tab_gallery, tab_inference, tab_hotdogs, tab_map, tab_data, tab_log = st.tabs(["Cetecean classifier", "Hotdog classifier", "Map", "
|
160 |
-
tab_inference, tab_hotdogs, tab_map, tab_data, tab_log, tab_gallery = st.tabs(["Cetecean classifier", "Hotdog classifier", "Map", "
|
161 |
st.session_state.tab_log = tab_log
|
162 |
|
163 |
|
164 |
# create a sidebar, and parse all the input (returned as `observations` object)
|
165 |
-
observations =
|
166 |
|
167 |
|
168 |
if 0:## WIP
|
169 |
# goal of this code is to allow the user to override the ML prediction, before transmitting an observations
|
170 |
-
predicted_class = st.sidebar.selectbox("Predicted Class",
|
171 |
override_prediction = st.sidebar.checkbox("Override Prediction")
|
172 |
|
173 |
if override_prediction:
|
174 |
-
overridden_class = st.sidebar.selectbox("Override Class",
|
175 |
-
st.session_state.
|
176 |
else:
|
177 |
-
st.session_state.
|
178 |
|
179 |
|
180 |
with tab_map:
|
@@ -188,19 +126,19 @@ def main() -> None:
|
|
188 |
|
189 |
if show_db_points:
|
190 |
# show a nicer map, observations marked, tileset selectable.
|
191 |
-
|
192 |
dataset_id=dataset_id, data_files=data_files,
|
193 |
dbg_show_extra=dbg_show_extra)
|
194 |
|
195 |
else:
|
196 |
# development map.
|
197 |
-
|
198 |
|
199 |
|
200 |
with tab_log:
|
201 |
handler = st.session_state['handler']
|
202 |
if handler is not None:
|
203 |
-
records =
|
204 |
st.dataframe(records[::-1], use_container_width=True,)
|
205 |
st.info(f"Length of records: {len(records)}")
|
206 |
else:
|
@@ -230,19 +168,18 @@ def main() -> None:
|
|
230 |
# specific to the gallery (otherwise we get side effects)
|
231 |
tg_cont = st.container(key="swgallery")
|
232 |
with tg_cont:
|
233 |
-
|
234 |
|
235 |
|
236 |
-
# Display submitted
|
237 |
if st.sidebar.button("Validate"):
|
238 |
-
# create a dictionary with the submitted
|
239 |
submitted_data = observations
|
240 |
-
st.session_state.
|
241 |
|
242 |
-
tab_log.info(f"{st.session_state.
|
243 |
|
244 |
-
df = pd.DataFrame(submitted_data)
|
245 |
-
print("Dataframe Shape: ", df.shape)
|
246 |
with tab_data:
|
247 |
st.table(df)
|
248 |
|
@@ -254,7 +191,7 @@ def main() -> None:
|
|
254 |
# - the model predicts the top 3 most likely species from the input image
|
255 |
# - these species are shown
|
256 |
# - the user can override the species prediction using the dropdown
|
257 |
-
# - an
|
258 |
|
259 |
if tab_inference.button("Identify with cetacean classifier"):
|
260 |
#pipe = pipeline("image-classification", model="Saving-Willy/cetacean-classifier", trust_remote_code=True)
|
@@ -262,58 +199,12 @@ def main() -> None:
|
|
262 |
revision=classifier_revision,
|
263 |
trust_remote_code=True)
|
264 |
|
265 |
-
if st.session_state.
|
266 |
# TODO: cleaner design to disable the button until data input done?
|
267 |
st.info("Please upload an image first.")
|
268 |
else:
|
269 |
-
|
270 |
-
images = st.session_state.images
|
271 |
-
full_data = st.session_state.full_data
|
272 |
-
for file in files:
|
273 |
-
image = images[file]
|
274 |
-
data = full_data[file]
|
275 |
-
# run classifier model on `image`, and persistently store the output
|
276 |
-
out = cetacean_classifier(image) # get top 3 matches
|
277 |
-
st.session_state.whale_prediction1 = out['predictions'][0]
|
278 |
-
st.session_state.classify_whale_done = True
|
279 |
-
msg = f"[D]2 classify_whale_done: {st.session_state.classify_whale_done}, whale_prediction1: {st.session_state.whale_prediction1}"
|
280 |
-
# st.info(msg)
|
281 |
-
g_logger.info(msg)
|
282 |
-
|
283 |
-
# dropdown for selecting/overriding the species prediction
|
284 |
-
#st.info(f"[D] classify_whale_done: {st.session_state.classify_whale_done}, whale_prediction1: {st.session_state.whale_prediction1}")
|
285 |
-
if not st.session_state.classify_whale_done:
|
286 |
-
selected_class = tab_inference.sidebar.selectbox("Species", sw_wv.WHALE_CLASSES,
|
287 |
-
index=None, placeholder="Species not yet identified...",
|
288 |
-
disabled=True)
|
289 |
-
else:
|
290 |
-
pred1 = st.session_state.whale_prediction1
|
291 |
-
# get index of pred1 from WHALE_CLASSES, none if not present
|
292 |
-
print(f"[D] pred1: {pred1}")
|
293 |
-
ix = sw_wv.WHALE_CLASSES.index(pred1) if pred1 in sw_wv.WHALE_CLASSES else None
|
294 |
-
selected_class = tab_inference.selectbox("Species", sw_wv.WHALE_CLASSES, index=ix)
|
295 |
|
296 |
-
data['predicted_class'] = selected_class
|
297 |
-
if selected_class != st.session_state.whale_prediction1:
|
298 |
-
data['class_overriden'] = selected_class
|
299 |
-
|
300 |
-
st.session_state.public_observation = data
|
301 |
-
st.button("Upload observations to THE INTERNET!", on_click=push_observations)
|
302 |
-
# TODO: the metadata only fills properly if `validate` was clicked.
|
303 |
-
tab_inference.markdown(metadata2md())
|
304 |
-
|
305 |
-
msg = f"[D] full data after inference: {data}"
|
306 |
-
g_logger.debug(msg)
|
307 |
-
print(msg)
|
308 |
-
# TODO: add a link to more info on the model, next to the button.
|
309 |
-
|
310 |
-
whale_classes = out['predictions'][:]
|
311 |
-
# render images for the top 3 (that is what the model api returns)
|
312 |
-
with tab_inference:
|
313 |
-
st.markdown("## Species detected")
|
314 |
-
for i in range(len(whale_classes)):
|
315 |
-
sw_wv.display_whale(whale_classes, i)
|
316 |
-
|
317 |
|
318 |
|
319 |
|
@@ -329,29 +220,10 @@ def main() -> None:
|
|
329 |
|
330 |
if st.session_state.image is None:
|
331 |
st.info("Please upload an image first.")
|
332 |
-
st.info(str(observations.to_dict()))
|
333 |
|
334 |
else:
|
335 |
-
|
336 |
-
for file in st.session_state.files:
|
337 |
-
image = st.session_state.images[file]
|
338 |
-
data = st.session_state.full_data[file]
|
339 |
-
# display the image (use cached version, no need to reread)
|
340 |
-
col1.image(image, use_column_width=True)
|
341 |
-
# and then run inference on the image
|
342 |
-
hotdog_image = Image.fromarray(image)
|
343 |
-
predictions = pipeline_hot_dog(hotdog_image)
|
344 |
-
|
345 |
-
col2.header("Probabilities")
|
346 |
-
first = True
|
347 |
-
for p in predictions:
|
348 |
-
col2.subheader(f"{ p['label'] }: { round(p['score'] * 100, 1)}%")
|
349 |
-
if first:
|
350 |
-
data['predicted_class'] = p['label']
|
351 |
-
data['predicted_score'] = round(p['score'] * 100, 1)
|
352 |
-
first = False
|
353 |
-
|
354 |
-
tab_hotdogs.write(f"Session Data: {json.dumps(data)}")
|
355 |
|
356 |
|
357 |
|
|
|
|
|
|
|
|
|
|
|
1 |
import logging
|
2 |
import os
|
|
|
3 |
|
4 |
import pandas as pd
|
5 |
import streamlit as st
|
|
|
6 |
import folium
|
7 |
from streamlit_folium import st_folium
|
8 |
+
|
9 |
from transformers import pipeline
|
10 |
from transformers import AutoModelForImageClassification
|
11 |
|
12 |
from datasets import disable_caching
|
13 |
disable_caching()
|
14 |
|
15 |
+
import whale_gallery as gallery
|
16 |
+
import whale_viewer as viewer
|
17 |
+
from input.input_handling import setup_input
|
18 |
+
from maps.alps_map import present_alps_map
|
19 |
+
from maps.obs_map import present_obs_map
|
20 |
+
from utils.st_logs import setup_logging, parse_log_buffer
|
21 |
+
from classifier.classifier_image import cetacean_classify
|
22 |
+
from classifier.classifier_hotdog import hotdog_classify
|
23 |
|
24 |
|
25 |
# setup for the ML model on huggingface (our wrapper)
|
|
|
39 |
g_logger.setLevel(LOG_LEVEL)
|
40 |
|
41 |
st.set_page_config(layout="wide")
|
|
|
|
|
|
|
42 |
|
43 |
# initialise various session state variables
|
44 |
if "handler" not in st.session_state:
|
45 |
+
st.session_state['handler'] = setup_logging()
|
46 |
+
|
47 |
+
if "observations" not in st.session_state:
|
48 |
+
st.session_state.observations = {}
|
49 |
+
|
50 |
+
if "images" not in st.session_state:
|
51 |
+
st.session_state.images = {}
|
52 |
|
53 |
+
if "files" not in st.session_state:
|
54 |
+
st.session_state.files = {}
|
55 |
+
|
56 |
+
if "public_observation" not in st.session_state:
|
57 |
+
st.session_state.public_observation = {}
|
58 |
|
59 |
if "classify_whale_done" not in st.session_state:
|
60 |
st.session_state.classify_whale_done = False
|
61 |
|
62 |
if "whale_prediction1" not in st.session_state:
|
63 |
st.session_state.whale_prediction1 = None
|
|
|
|
|
|
|
64 |
|
65 |
if "tab_log" not in st.session_state:
|
66 |
st.session_state.tab_log = None
|
67 |
|
68 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
def main() -> None:
|
70 |
"""
|
71 |
Main entry point to set up the streamlit UI and run the application.
|
72 |
|
73 |
The organisation is as follows:
|
74 |
|
75 |
+
1. observation input (a new observations) is handled in the sidebar
|
76 |
2. the rest of the interface is organised in tabs:
|
77 |
|
78 |
- cetean classifier
|
|
|
94 |
#g_logger.warning("warning message")
|
95 |
|
96 |
# Streamlit app
|
97 |
+
#tab_gallery, tab_inference, tab_hotdogs, tab_map, tab_data, tab_log = st.tabs(["Cetecean classifier", "Hotdog classifier", "Map", "observation", "Log", "Beautiful cetaceans"])
|
98 |
+
tab_inference, tab_hotdogs, tab_map, tab_data, tab_log, tab_gallery = st.tabs(["Cetecean classifier", "Hotdog classifier", "Map", "observation", "Log", "Beautiful cetaceans"])
|
99 |
st.session_state.tab_log = tab_log
|
100 |
|
101 |
|
102 |
# create a sidebar, and parse all the input (returned as `observations` object)
|
103 |
+
observations = setup_input(viewcontainer=st.sidebar)
|
104 |
|
105 |
|
106 |
if 0:## WIP
|
107 |
# goal of this code is to allow the user to override the ML prediction, before transmitting an observations
|
108 |
+
predicted_class = st.sidebar.selectbox("Predicted Class", viewer.WHALE_CLASSES)
|
109 |
override_prediction = st.sidebar.checkbox("Override Prediction")
|
110 |
|
111 |
if override_prediction:
|
112 |
+
overridden_class = st.sidebar.selectbox("Override Class", viewer.WHALE_CLASSES)
|
113 |
+
st.session_state.observations['class_overriden'] = overridden_class
|
114 |
else:
|
115 |
+
st.session_state.observations['class_overriden'] = None
|
116 |
|
117 |
|
118 |
with tab_map:
|
|
|
126 |
|
127 |
if show_db_points:
|
128 |
# show a nicer map, observations marked, tileset selectable.
|
129 |
+
st_observation = present_obs_map(
|
130 |
dataset_id=dataset_id, data_files=data_files,
|
131 |
dbg_show_extra=dbg_show_extra)
|
132 |
|
133 |
else:
|
134 |
# development map.
|
135 |
+
st_observation = present_alps_map()
|
136 |
|
137 |
|
138 |
with tab_log:
|
139 |
handler = st.session_state['handler']
|
140 |
if handler is not None:
|
141 |
+
records = parse_log_buffer(handler.buffer)
|
142 |
st.dataframe(records[::-1], use_container_width=True,)
|
143 |
st.info(f"Length of records: {len(records)}")
|
144 |
else:
|
|
|
168 |
# specific to the gallery (otherwise we get side effects)
|
169 |
tg_cont = st.container(key="swgallery")
|
170 |
with tg_cont:
|
171 |
+
gallery.render_whale_gallery(n_cols=4)
|
172 |
|
173 |
|
174 |
+
# Display submitted observation
|
175 |
if st.sidebar.button("Validate"):
|
176 |
+
# create a dictionary with the submitted observation
|
177 |
submitted_data = observations
|
178 |
+
st.session_state.observations = observations
|
179 |
|
180 |
+
tab_log.info(f"{st.session_state.observations}")
|
181 |
|
182 |
+
df = pd.DataFrame(submitted_data, index=[0])
|
|
|
183 |
with tab_data:
|
184 |
st.table(df)
|
185 |
|
|
|
191 |
# - the model predicts the top 3 most likely species from the input image
|
192 |
# - these species are shown
|
193 |
# - the user can override the species prediction using the dropdown
|
194 |
+
# - an observation is uploaded if the user chooses.
|
195 |
|
196 |
if tab_inference.button("Identify with cetacean classifier"):
|
197 |
#pipe = pipeline("image-classification", model="Saving-Willy/cetacean-classifier", trust_remote_code=True)
|
|
|
199 |
revision=classifier_revision,
|
200 |
trust_remote_code=True)
|
201 |
|
202 |
+
if st.session_state.images is None:
|
203 |
# TODO: cleaner design to disable the button until data input done?
|
204 |
st.info("Please upload an image first.")
|
205 |
else:
|
206 |
+
cetacean_classify(cetacean_classifier, tab_inference)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
207 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
208 |
|
209 |
|
210 |
|
|
|
220 |
|
221 |
if st.session_state.image is None:
|
222 |
st.info("Please upload an image first.")
|
223 |
+
#st.info(str(observations.to_dict()))
|
224 |
|
225 |
else:
|
226 |
+
hotdog_classify(pipeline_hot_dog, tab_hotdogs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
227 |
|
228 |
|
229 |
|
src/{alps_map.py β maps/alps_map.py}
RENAMED
File without changes
|
src/{obs_map.py β maps/obs_map.py}
RENAMED
@@ -7,8 +7,8 @@ import streamlit as st
|
|
7 |
import folium
|
8 |
from streamlit_folium import st_folium
|
9 |
|
10 |
-
import whale_viewer as
|
11 |
-
from fix_tabrender import js_show_zeroheight_iframe
|
12 |
|
13 |
m_logger = logging.getLogger(__name__)
|
14 |
# we can set the log level locally for funcs in this module
|
@@ -60,7 +60,7 @@ _colors = [
|
|
60 |
"#778899" # Light Slate Gray
|
61 |
]
|
62 |
|
63 |
-
whale2color = {k: v for k, v in zip(
|
64 |
|
65 |
def create_map(tile_name:str, location:Tuple[float], zoom_start: int = 7) -> folium.Map:
|
66 |
"""
|
|
|
7 |
import folium
|
8 |
from streamlit_folium import st_folium
|
9 |
|
10 |
+
import whale_viewer as viewer
|
11 |
+
from utils.fix_tabrender import js_show_zeroheight_iframe
|
12 |
|
13 |
m_logger = logging.getLogger(__name__)
|
14 |
# we can set the log level locally for funcs in this module
|
|
|
60 |
"#778899" # Light Slate Gray
|
61 |
]
|
62 |
|
63 |
+
whale2color = {k: v for k, v in zip(viewer.WHALE_CLASSES, _colors)}
|
64 |
|
65 |
def create_map(tile_name:str, location:Tuple[float], zoom_start: int = 7) -> folium.Map:
|
66 |
"""
|
src/{fix_tabrender.py β utils/fix_tabrender.py}
RENAMED
File without changes
|
src/utils/grid_maker.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import math
|
3 |
+
|
4 |
+
def gridder(files):
|
5 |
+
cols = st.columns(3)
|
6 |
+
with cols[0]:
|
7 |
+
batch_size = st.select_slider("Batch size:",range(10,110,10), value=10)
|
8 |
+
with cols[1]:
|
9 |
+
row_size = st.select_slider("Row size:", range(1,6), value = 5)
|
10 |
+
num_batches = math.ceil(len(files)/batch_size)
|
11 |
+
with cols[2]:
|
12 |
+
page = st.selectbox("Page", range(1,num_batches+1))
|
13 |
+
return batch_size, row_size, page
|
src/utils/metadata_handler.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
|
3 |
+
def metadata2md() -> str:
|
4 |
+
"""Get metadata from cache and return as markdown-formatted key-value list
|
5 |
+
|
6 |
+
Returns:
|
7 |
+
str: Markdown-formatted key-value list of metadata
|
8 |
+
|
9 |
+
"""
|
10 |
+
markdown_str = "\n"
|
11 |
+
keys_to_print = ["latitude","longitude","author_email","date","time"]
|
12 |
+
for key, value in st.session_state.public_observation.items():
|
13 |
+
if key in keys_to_print:
|
14 |
+
markdown_str += f"- **{key}**: {value}\n"
|
15 |
+
return markdown_str
|
16 |
+
|
src/{st_logs.py β utils/st_logs.py}
RENAMED
File without changes
|
src/whale_viewer.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
from typing import List
|
2 |
-
|
3 |
from PIL import Image
|
4 |
import pandas as pd
|
5 |
import os
|
@@ -134,7 +134,7 @@ def display_whale(whale_classes:List[str], i:int, viewcontainer=None):
|
|
134 |
TODO: how to find the object type of viewcontainer.? they are just "deltagenerators" but
|
135 |
we want the result of the generator.. In any case, it works ok with either call signature.
|
136 |
"""
|
137 |
-
|
138 |
if viewcontainer is None:
|
139 |
viewcontainer = st
|
140 |
|
@@ -148,11 +148,10 @@ def display_whale(whale_classes:List[str], i:int, viewcontainer=None):
|
|
148 |
|
149 |
|
150 |
viewcontainer.markdown(
|
151 |
-
"
|
152 |
)
|
153 |
current_dir = os.getcwd()
|
154 |
image_path = os.path.join(current_dir, "src/images/references/")
|
155 |
image = Image.open(image_path + df_whale_img_ref.loc[whale_classes[i], "WHALE_IMAGES"])
|
156 |
|
157 |
-
viewcontainer.image(image, caption=df_whale_img_ref.loc[whale_classes[i], "WHALE_REFERENCES"])
|
158 |
-
# link st.markdown(f"[{df.loc[whale_classes[i], 'WHALE_REFERENCES']}]({df.loc[whale_classes[i], 'WHALE_REFERENCES']})")
|
|
|
1 |
from typing import List
|
2 |
+
import streamlit as st
|
3 |
from PIL import Image
|
4 |
import pandas as pd
|
5 |
import os
|
|
|
134 |
TODO: how to find the object type of viewcontainer.? they are just "deltagenerators" but
|
135 |
we want the result of the generator.. In any case, it works ok with either call signature.
|
136 |
"""
|
137 |
+
|
138 |
if viewcontainer is None:
|
139 |
viewcontainer = st
|
140 |
|
|
|
148 |
|
149 |
|
150 |
viewcontainer.markdown(
|
151 |
+
":whale: #" + str(i + 1) + ": " + format_whale_name(whale_classes[i])
|
152 |
)
|
153 |
current_dir = os.getcwd()
|
154 |
image_path = os.path.join(current_dir, "src/images/references/")
|
155 |
image = Image.open(image_path + df_whale_img_ref.loc[whale_classes[i], "WHALE_IMAGES"])
|
156 |
|
157 |
+
viewcontainer.image(image, caption=df_whale_img_ref.loc[whale_classes[i], "WHALE_REFERENCES"], use_column_width=True)
|
|