Spaces:
Sleeping
Sleeping
File size: 12,100 Bytes
0e8c927 c915f7c 0e8c927 94698a8 1021b6c 4854d2c 60a7864 d4ec4a0 60a7864 d4ec4a0 c915f7c d4ec4a0 5823912 d4ec4a0 60a7864 d4ec4a0 c915f7c d4ec4a0 5823912 d4ec4a0 c915f7c d4ec4a0 c915f7c 0e02e00 c915f7c 5823912 d4ec4a0 c915f7c d4ec4a0 5823912 d4ec4a0 60a7864 d4ec4a0 5823912 c915f7c d4ec4a0 5823912 d4ec4a0 7b28238 0e02e00 0e8c927 1c0e2a5 0e8c927 1c0e2a5 0e8c927 1c0e2a5 0e8c927 7a5f0ca 0e8c927 7a5f0ca 0e8c927 7a5f0ca 0e8c927 1c0e2a5 0e8c927 7a5f0ca 0e8c927 1c0e2a5 0e8c927 1c0e2a5 0e8c927 1c0e2a5 c915f7c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 |
import streamlit as st
import logging
# Module-level logger shared by every function in this file. The logger itself
# is set to DEBUG so per-observation diagnostics are not filtered at the
# logger (attached handlers may still apply their own level).
LOG_LEVEL = logging.DEBUG
g_logger = logging.getLogger(name=__name__)
g_logger.setLevel(level=LOG_LEVEL)
import whale_viewer as viewer
from hf_push_observations import push_observations
from utils.grid_maker import gridder
from utils.metadata_handler import metadata2md
from input.input_observation import InputObservation
def init_classifier_session_states() -> None:
    """
    Create the classification-related session state entries if they are absent.

    Both entries are dicts that the classification code keys by image hash:
    - ``classify_whale_done``: whether the classifier has run for that image
    - ``whale_prediction1``: the top prediction for that image

    Entries that already exist are left untouched.
    """
    for state_key in ("classify_whale_done", "whale_prediction1"):
        if state_key not in st.session_state:
            # each key gets its own fresh dict
            st.session_state[state_key] = {}
def add_classifier_header() -> None:
    """
    Add brief explainer text about cetacean classification to the tab.

    Renders a short italicised markdown paragraph describing what the
    classifier does and how the user can override its prediction.
    """
    # Text is kept at column 0 inside the literal: leading indentation in
    # markdown would turn the paragraph into a code block.
    st.markdown("""
*Run the classifier to identify the species of cetacean in the uploaded image.
Once inference is complete, the top three predictions are shown.
You can override the prediction by selecting a species from the dropdown.*""")
# func to just run classification, store results.
def cetacean_just_classify(cetacean_classifier):
    """
    Run the cetacean classifier on every uploaded image and store the results.

    Nothing is rendered here apart from an optional dev-mode note; use
    cetacean_show_results() or cetacean_show_results_and_review() to display
    the results.

    For each hash in ``st.session_state.image_hashes`` this:
    - stores the first prediction in ``st.session_state.whale_prediction1``
    - marks the hash as done in ``st.session_state.classify_whale_done``
    - passes a copy of the full prediction list to the observation's
      ``set_top_predictions`` (observations live in
      ``st.session_state.observations``)

    Args:
        cetacean_classifier ([type]): saving-willy model from the Saving Willy
            Hugging Face space. It is called with an image and its result is
            indexed with ``'predictions'`` (presumably best match first —
            confirm against the model).
    """
    image_store = st.session_state.images
    for img_hash in st.session_state.image_hashes:
        # run the model on the image and keep its output in session state
        result = cetacean_classifier(image_store[img_hash])  # top 3 matches
        predictions = result['predictions']
        st.session_state.whale_prediction1[img_hash] = predictions[0]
        st.session_state.classify_whale_done[img_hash] = True
        st.session_state.observations[img_hash].set_top_predictions(predictions[:])

        g_logger.info(
            f"[D]2 classify_whale_done for {img_hash}: "
            f"{st.session_state.classify_whale_done[img_hash]}, "
            f"whale_prediction1: {st.session_state.whale_prediction1[img_hash]}"
        )
        if st.session_state.MODE_DEV_STATEFUL:
            st.write(
                f"*[D] Observation {img_hash} classified as "
                f"{st.session_state.whale_prediction1[img_hash]}*"
            )
# func to show results and allow review
def cetacean_show_results_and_review() -> None:
    """
    Present classification results and let the user review/override them.

    For each observation in the session state, displays the image, the
    summarised metadata and the top predictions, plus a dropdown for
    overriding the predicted species. The chosen species is written to the
    observation object (retained in ``st.session_state.observations``), and
    the observation's dict form is stored in
    ``st.session_state.public_observations``.

    Observations that have not been classified yet, including hashes missing
    from ``classify_whale_done``, get a disabled placeholder dropdown inside
    their own grid column. Its label is unique per observation, so several
    unclassified images do not produce clashing widget IDs.
    """
    images = st.session_state.images
    observations = st.session_state.observations
    hashes = st.session_state.image_hashes
    classified = st.session_state.classify_whale_done

    # gridder returns (batch_size, row_size, page); only row_size is used here
    _batch_size, row_size, _page = gridder(hashes)
    grid = st.columns(row_size)
    col = 0
    for o, img_hash in enumerate(hashes, start=1):
        image = images[img_hash]
        _observation: InputObservation = observations[img_hash]
        with grid[col]:
            st.image(image, use_column_width=True)

            # dropdown for selecting/overriding the species prediction
            if not classified.get(img_hash, False):
                selected_class = st.selectbox(
                    f"Species for observation {o}", viewer.WHALE_CLASSES,
                    index=None, placeholder="Species not yet identified...",
                    disabled=True)
            else:
                pred1 = st.session_state.whale_prediction1[img_hash]
                print(f"[D] {o:3} pred1: {pred1:30} | {img_hash}")
                # preselect pred1; None (no preselection) if it is not a known class
                ix = (viewer.WHALE_CLASSES.index(pred1)
                      if pred1 in viewer.WHALE_CLASSES else None)
                selected_class = st.selectbox(
                    f"Species for observation {o}", viewer.WHALE_CLASSES, index=ix)

            # Override tracking is handled inside InputObservation itself
            # (per the earlier in-code note); just record the selection.
            _observation.set_selected_class(selected_class)

            # store the elements of the observation that will be transmitted (not image)
            observation = _observation.to_dict()
            st.session_state.public_observations[img_hash] = observation

            # TODO: the metadata only fills properly if `validate` was clicked.
            st.markdown(metadata2md(img_hash, debug=True))

            msg = f"[D] full observation after inference: {observation}"
            g_logger.debug(msg)
            print(msg)

            # TODO: add a link to more info on the model, next to the button.
            # Guard against an observation with no predictions yet; assumes
            # top_predictions is a list or None — TODO confirm in InputObservation.
            whale_classes = _observation.top_predictions or []
            # render images for the top predictions (the model api returns 3)
            n = len(whale_classes)
            st.markdown(f"**Top {n} Predictions for observation {o}**")
            for i in range(n):
                viewer.display_whale(whale_classes, i)
        col = (col + 1) % row_size
# func to just present results
def cetacean_show_results():
    """
    Present classification results that may be pushed to the online dataset.

    Read-only view: for every observation in the session state, shows the
    image, its summarised metadata and the model's top predictions, laid out
    across a grid of columns. Unlike cetacean_show_results_and_review(), no
    species dropdown is rendered here.
    """
    images = st.session_state.images
    observations = st.session_state.observations
    hashes = st.session_state.image_hashes

    # gridder returns (batch_size, row_size, page); only the row size matters
    _, row_size, _ = gridder(hashes)
    grid = st.columns(row_size)

    for idx, img_hash in enumerate(hashes):
        obs_number = idx + 1
        image = images[img_hash]
        obs_dict = observations[img_hash].to_dict()
        # fill the columns left to right, wrapping after row_size observations
        with grid[idx % row_size]:
            st.image(image, use_column_width=True)
            st.markdown(metadata2md(img_hash, debug=True))

            log_line = f"[D] full observation after inference: {obs_dict}"
            g_logger.debug(log_line)
            print(log_line)

            # TODO: add a link to more info on the model, next to the button.
            top = observations[img_hash].top_predictions
            # one rendered entry per prediction (the model api returns three)
            st.markdown(f"**Top {len(top)} Predictions for observation {obs_number}**")
            for rank in range(len(top)):
                viewer.display_whale(top, rank)
# func to do all in one
def cetacean_classify_show_and_review(cetacean_classifier):
    """Deprecated all-in-one classify/show/review step; always raises.

    This function used to classify each image in the session state with the
    saving-willy model and display the top 3 predictions in a single pass. It
    now raises immediately. Call the individual steps in this module instead,
    e.g. ``cetacean_just_classify(cetacean_classifier)`` followed by
    ``cetacean_show_results_and_review()``.

    Args:
        cetacean_classifier ([type]): saving-willy model from Saving Willy
            Hugging Face space (unused; kept for interface compatibility).

    Raises:
        DeprecationWarning: always, so any remaining caller fails loudly.
    """
    # The legacy body that followed this raise was unreachable and has been
    # removed; its work is split across the individual step functions.
    raise DeprecationWarning("This function is deprecated. Use individual steps instead")
|