saving-willy-dev / src /classifier /classifier_image.py
rmm
chore: moved all session state init to relevant modules
94698a8
raw
history blame
12.1 kB
import streamlit as st
import logging
# Module-level logger accessor for this file. LOG_LEVEL is applied only to this
# module's logger; no handlers or formatters are configured here, so whether the
# DEBUG/INFO records below are actually emitted presumably depends on logging
# configuration done elsewhere in the app (TODO confirm).
LOG_LEVEL = logging.DEBUG
g_logger = logging.getLogger(__name__)
g_logger.setLevel(LOG_LEVEL)
import whale_viewer as viewer
from hf_push_observations import push_observations
from utils.grid_maker import gridder
from utils.metadata_handler import metadata2md
from input.input_observation import InputObservation
def init_classifier_session_states() -> None:
    '''
    Ensure the classifier's per-image session-state dicts exist.

    For each of ``classify_whale_done`` and ``whale_prediction1`` a fresh empty
    dict is created only when the key is missing from ``st.session_state``;
    an already-present value is left untouched. Both dicts are later indexed
    by image hash in this module.
    '''
    for state_name in ("classify_whale_done", "whale_prediction1"):
        if state_name not in st.session_state:
            # dict-style assignment is equivalent to attribute assignment on st.session_state
            st.session_state[state_name] = {}
def add_classifier_header() -> None:
    """
    Add brief explainer text about cetacean classification to the tab.

    Renders a single italic markdown paragraph; fixes the former typos
    "classifer" and "cetean" in the user-facing text.
    """
    # One markdown paragraph: the single "\n" line endings are soft breaks, so the
    # surrounding *...* emphasis spans all three sentences.
    st.markdown(
        "*Run classifier to identify the species of cetacean on the uploaded image.\n"
        "Once inference is complete, the top three predictions are shown.\n"
        "You can override the prediction by selecting a species from the dropdown.*"
    )
# func to just run classification, store results.
def cetacean_just_classify(cetacean_classifier):
    """
    Infer cetacean species for all observations in the session state.

    - runs the classifier on every image listed in st.session_state.image_hashes
      and stores the results in the session state:
        * st.session_state.whale_prediction1[image_hash] <- first prediction
        * st.session_state.classify_whale_done[image_hash] <- True
    - the prediction list (top 3 per the model) is stored in the observation
      object via set_top_predictions(); the observation is retained in
      st.session_state.observations
    - to display results use cetacean_show_results() or
      cetacean_show_results_and_review()

    Args:
        cetacean_classifier ([type]): saving-willy model from Saving Willy Hugging
            Face space. Called as ``cetacean_classifier(image)``; its result must
            support ``out['predictions']`` as an indexable sequence (presumably
            ordered best-first — TODO confirm).
    """
    images = st.session_state.images
    for image_hash in st.session_state.image_hashes:
        image = images[image_hash]
        # run classifier model on `image`, and persistently store the output
        out = cetacean_classifier(image)  # get top 3 matches
        predictions = out['predictions']
        # NOTE(review): an empty prediction list raises IndexError here — confirm
        # the model always returns at least one class.
        top_prediction = predictions[0]

        st.session_state.whale_prediction1[image_hash] = top_prediction
        st.session_state.classify_whale_done[image_hash] = True
        # `[:]` hands the observation a shallow copy (for a list), so it does not
        # share the list object returned by the classifier.
        st.session_state.observations[image_hash].set_top_predictions(predictions[:])

        # lazy %-formatting: the message is only built if INFO is enabled
        g_logger.info(
            "[D]2 classify_whale_done for %s: %s, whale_prediction1: %s",
            image_hash,
            st.session_state.classify_whale_done[image_hash],
            top_prediction,
        )
        if st.session_state.MODE_DEV_STATEFUL:
            st.write(f"*[D] Observation {image_hash} classified as {top_prediction}*")
# func to show results and allow review
def cetacean_show_results_and_review() -> None:
    """
    Present classification results and allow user to review and override the prediction.

    - for each observation in the session state, displays the image, summarised
      metadata, and the top predictions (as many as were stored on the observation).
    - allows user to override the prediction by selecting a species from the dropdown.
    - the selected species is stored in the observation object, which is retained in
      st.session_state.observations; the observation's dict form is also written to
      st.session_state.public_observations[image_hash].
    - images that have not been classified yet get a disabled placeholder dropdown
      (previously this case raised KeyError).
    """
    images = st.session_state.images
    observations = st.session_state.observations
    hashes = st.session_state.image_hashes
    # gridder also returns a batch size and a page, which are not used here
    _batch_size, row_size, _page = gridder(hashes)

    grid = st.columns(row_size)
    col = 0
    for o, image_hash in enumerate(hashes, start=1):
        image = images[image_hash]
        _observation: InputObservation = observations[image_hash]

        with grid[col]:
            st.image(image, use_column_width=True)

            # dropdown for selecting/overriding the species prediction.
            # classify_whale_done only gains an entry once an image has been
            # classified, so use .get(): plain indexing raised KeyError for
            # unclassified images and made this placeholder branch unreachable.
            if not st.session_state.classify_whale_done.get(image_hash, False):
                # Per-observation label, inside the observation's column: the former
                # st.sidebar.selectbox("Species", ...) had identical label/arguments
                # for every unclassified image, which Streamlit rejects as a
                # duplicate widget once two or more are rendered.
                selected_class = st.selectbox(
                    f"Species for observation {o}", viewer.WHALE_CLASSES,
                    index=None, placeholder="Species not yet identified...",
                    disabled=True)
            else:
                pred1 = st.session_state.whale_prediction1[image_hash]
                print(f"[D] {o:3} pred1: {pred1:30} | {image_hash}")
                # preselect pred1; no preselection if it is not one of WHALE_CLASSES
                ix = viewer.WHALE_CLASSES.index(pred1) if pred1 in viewer.WHALE_CLASSES else None
                selected_class = st.selectbox(
                    f"Species for observation {o}", viewer.WHALE_CLASSES, index=ix)

            # override tracking (vs. the top prediction) is handled inside
            # InputObservation.set_selected_class
            _observation.set_selected_class(selected_class)

            # store the elements of the observation that will be transmitted (not image)
            observation = _observation.to_dict()
            st.session_state.public_observations[image_hash] = observation

            # TODO: the metadata only fills properly if `validate` was clicked.
            st.markdown(metadata2md(image_hash, debug=True))

            msg = f"[D] full observation after inference: {observation}"
            g_logger.debug(msg)
            print(msg)

            # TODO: add a link to more info on the model, next to the button.
            whale_classes = _observation.top_predictions
            # render images for the top predictions (the model api returns 3)
            n = len(whale_classes)
            st.markdown(f"**Top {n} Predictions for observation {o}**")
            for i in range(n):
                viewer.display_whale(whale_classes, i)

        # fill the grid row-wise, wrapping back to the first column
        col = (col + 1) % row_size
# func to just present results
def cetacean_show_results():
    """
    Present classification results that may be pushed to the online dataset.

    Read-only counterpart of the review view: for each observation in the session
    state, displays the image, the summarised metadata and the top predictions,
    without any widget for overriding the species.

    NOTE(review): the selected species (manually chosen, or the accepted top
    prediction) is not rendered directly here; presumably it is part of the
    metadata2md summary — confirm.
    """
    images = st.session_state.images
    observations = st.session_state.observations
    hashes = st.session_state.image_hashes
    # gridder also returns a batch size and a page, which are not used here
    _batch_size, row_size, _page = gridder(hashes)

    grid = st.columns(row_size)
    col = 0
    for o, image_hash in enumerate(hashes, start=1):
        image = images[image_hash]
        _observation = observations[image_hash]
        observation = _observation.to_dict()

        with grid[col]:
            st.image(image, use_column_width=True)
            st.markdown(metadata2md(image_hash, debug=True))

            msg = f"[D] full observation after inference: {observation}"
            g_logger.debug(msg)
            print(msg)

            # TODO: add a link to more info on the model, next to the button.
            whale_classes = _observation.top_predictions
            # render images for the top predictions (the model api returns 3)
            n = len(whale_classes)
            st.markdown(f"**Top {n} Predictions for observation {o}**")
            for i in range(n):
                viewer.display_whale(whale_classes, i)

        # fill the grid row-wise, wrapping back to the first column
        col = (col + 1) % row_size
# func to do all in one
def cetacean_classify_show_and_review(cetacean_classifier):
    """Deprecated all-in-one classify / display / review entry point.

    Always raises. Use the individual steps instead:
    cetacean_just_classify(cetacean_classifier), then
    cetacean_show_results_and_review() (or cetacean_show_results()).

    Args:
        cetacean_classifier ([type]): saving-willy model from Saving Willy Hugging
            Face space. Unused; kept so existing call sites keep the same signature.

    Raises:
        DeprecationWarning: always.
    """
    # Raised as an exception (not warnings.warn), so any remaining caller stops here
    # rather than running the old code path. The former body after this statement
    # was unreachable and has been removed; it is available in version control.
    raise DeprecationWarning("This function is deprecated. Use individual steps instead")