Spaces:
Sleeping
Sleeping
File size: 10,178 Bytes
0e8c927 4854d2c d4ec4a0 7b28238 0e8c927 1c0e2a5 0e8c927 1c0e2a5 0e8c927 1c0e2a5 0e8c927 7a5f0ca 0e8c927 7a5f0ca 0e8c927 7a5f0ca 0e8c927 1c0e2a5 0e8c927 7a5f0ca 0e8c927 1c0e2a5 0e8c927 1c0e2a5 0e8c927 1c0e2a5 0e8c927 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 |
import streamlit as st
import logging
# Level applied to this module's logger (g_logger below).
LOG_LEVEL = logging.DEBUG
# Module-wide logger shared by every function in this file; records below
# LOG_LEVEL are discarded by this logger. Whether the rest are shown depends
# on the handlers configured elsewhere.
g_logger = logging.getLogger(name=__name__)
g_logger.setLevel(level=LOG_LEVEL)
import whale_viewer as viewer
from hf_push_observations import push_observations
from utils.grid_maker import gridder
from utils.metadata_handler import metadata2md
def add_header_text() -> None:
    """
    Add brief explainer text about cetacean classification to the tab.

    Renders one italicised markdown paragraph with ``st.markdown``.
    """
    st.markdown(
        "\n*Run the classifier to identify the species of cetacean in the uploaded image.\n"
        "Once inference is complete, the top three predictions are shown.\n"
        "You can override the prediction by selecting a species from the dropdown.*"
    )
def cetacean_just_classify(cetacean_classifier):
    """Run the classifier on every uploaded image and store the results.

    For each hash in ``st.session_state.image_hashes``:

    * the image ``st.session_state.images[hash]`` is passed to
      ``cetacean_classifier``;
    * the returned ``out['predictions']`` list is stored on the observation via
      ``set_top_predictions``;
    * its first entry is stored in ``st.session_state.whale_prediction1[hash]``;
    * ``st.session_state.classify_whale_done[hash]`` is set to True.

    ``st.session_state.public_observation`` is overwritten on every iteration.
    It ends up holding the ``to_dict()`` of the last observation processed.

    Args:
        cetacean_classifier: callable taking an image and returning a mapping
            whose ``'predictions'`` entry is a list. Entry 0 is treated as the
            top prediction, so the list must be non-empty.
    """
    images = st.session_state.images
    observations = st.session_state.observations

    for image_hash in st.session_state.image_hashes:
        # run classifier model on the image, and persistently store the output
        out = cetacean_classifier(images[image_hash])  # get top matches
        predictions = out['predictions']
        st.session_state.whale_prediction1[image_hash] = predictions[0]
        st.session_state.classify_whale_done[image_hash] = True
        observations[image_hash].set_top_predictions(predictions[:])

        msg = f"[D]2 classify_whale_done for {image_hash}: {st.session_state.classify_whale_done[image_hash]}, whale_prediction1: {st.session_state.whale_prediction1[image_hash]}"
        g_logger.info(msg)

        # Snapshot taken *after* set_top_predictions. Previously it was taken
        # before classification, so it could not reflect this run's output
        # (assuming to_dict exports the predictions — confirm).
        # TODO: what is the difference between public and regular; and why is this not array-ready?
        st.session_state.public_observation = observations[image_hash].to_dict()
        st.write(f"*[D] Observation {image_hash} classified as {st.session_state.whale_prediction1[image_hash]}*")
def cetacean_show_results_and_review():
    """Show every observation in a grid and let the user review its species.

    For each hash in ``st.session_state.image_hashes``, one grid tile shows:

    * the image;
    * a species dropdown (pre-selected with the top prediction once the image
      is classified, disabled otherwise);
    * an upload button;
    * the metadata summary;
    * the top predictions stored on the observation.

    The dropdown value is written to the observation dict as
    ``predicted_class``. For a classified image whose selection differs from
    the model's top prediction, it is also written as ``class_overriden``.
    """
    images = st.session_state.images
    observations = st.session_state.observations
    hashes = st.session_state.image_hashes
    _, row_size, _ = gridder(hashes)  # batch_size and page are not used here
    grid = st.columns(row_size)

    def _push_observation(obs):
        # Button callback: make the clicked tile's observation current, then
        # push. The loop below overwrites public_observation on every
        # iteration, so without this each button would push the last tile.
        # NOTE(review): assumes push_observations() reads
        # st.session_state.public_observation — confirm.
        st.session_state.public_observation = obs
        push_observations()

    for o, image_hash in enumerate(hashes, start=1):
        observation = observations[image_hash].to_dict()
        with grid[(o - 1) % row_size]:
            st.image(images[image_hash], use_column_width=True)

            # dropdown for selecting/overriding the species prediction
            classified = st.session_state.classify_whale_done[image_hash]
            if classified:
                pred1 = st.session_state.whale_prediction1[image_hash]
                print(f"[D] pred1: {pred1}")
                # index of pred1 in WHALE_CLASSES, None (no selection) if absent
                ix = viewer.WHALE_CLASSES.index(pred1) if pred1 in viewer.WHALE_CLASSES else None
                selected_class = st.selectbox(f"Species for observation {o}", viewer.WHALE_CLASSES, index=ix)
            else:
                # Disabled placeholder, rendered in this tile with a
                # per-observation label. Identical "Species" calls in the
                # sidebar for several unclassified images would collide on
                # Streamlit's auto-generated widget ID.
                selected_class = st.selectbox(f"Species for observation {o}", viewer.WHALE_CLASSES,
                                              index=None, placeholder="Species not yet identified...",
                                              disabled=True)
            observation['predicted_class'] = selected_class
            # Only a classified image has a model prediction that can be overridden.
            if classified and selected_class != pred1:
                observation['class_overriden'] = selected_class  # TODO: this should be boolean!
            st.session_state.public_observation = observation

            st.button(f"Upload observation {o} to THE INTERNET!",
                      on_click=_push_observation, args=(observation,))
            # TODO: the metadata only fills properly if `validate` was clicked.
            st.markdown(metadata2md())

            msg = f"[D] full observation after inference: {observation}"
            g_logger.debug(msg)
            print(msg)

            # TODO: add a link to more info on the model, next to the button.
            whale_classes = observations[image_hash].top_predictions
            # render images for the top predictions (the model api returns 3)
            n = len(whale_classes)
            st.markdown(f"Top {n} Predictions for observation {o}")
            for i in range(n):
                viewer.display_whale(whale_classes, i)
def cetacean_show_results():
    """Show every observation in a grid, without species-review widgets.

    For each hash in ``st.session_state.image_hashes``, one grid tile shows:

    * the image;
    * an upload button;
    * the metadata summary;
    * the image hash;
    * the top predictions stored on the observation.

    The per-image species dropdown lives in
    ``cetacean_show_results_and_review``.
    """
    images = st.session_state.images
    observations = st.session_state.observations
    hashes = st.session_state.image_hashes
    _, row_size, _ = gridder(hashes)  # batch_size and page are not used here
    grid = st.columns(row_size)

    def _push_observation(obs):
        # Button callback: make the clicked tile's observation current, then
        # push. Without this, each button pushed whatever public_observation
        # happened to hold, regardless of which tile was clicked.
        # NOTE(review): assumes push_observations() reads
        # st.session_state.public_observation — confirm.
        st.session_state.public_observation = obs
        push_observations()

    for o, image_hash in enumerate(hashes, start=1):
        observation = observations[image_hash].to_dict()
        with grid[(o - 1) % row_size]:
            st.image(images[image_hash], use_column_width=True)

            st.button(f"Upload observation {o} to THE INTERNET!",
                      on_click=_push_observation, args=(observation,))
            # TODO: the metadata only fills properly if `validate` was clicked.
            st.markdown(metadata2md())
            st.markdown(f"- **hash**: {image_hash}")

            msg = f"[D] full observation after inference: {observation}"
            g_logger.debug(msg)
            print(msg)

            # TODO: add a link to more info on the model, next to the button.
            whale_classes = observations[image_hash].top_predictions
            # render images for the top predictions (the model api returns 3)
            n = len(whale_classes)
            st.markdown(f"Top {n} Predictions for observation {o}")
            for i in range(n):
                viewer.display_whale(whale_classes, i)
def cetacean_classify_show_and_review(cetacean_classifier):
    """Classify every image, then show it with a review dropdown (all in one).

    For each hash in ``st.session_state.image_hashes``, the image is
    classified. The results are stored exactly as ``cetacean_just_classify``
    stores them:

    * ``whale_prediction1``;
    * ``classify_whale_done``;
    * the observation's ``set_top_predictions``.

    The grid tile then shows the image, a species dropdown pre-selected with
    the top prediction, an upload button, the metadata summary and the
    returned predictions.

    NOTE(review): Streamlit reruns the script on widget interaction, so this
    presumably re-runs the classifier for every image whenever a dropdown
    changes. Consider classifying once (``cetacean_just_classify``) and
    reviewing separately.

    Args:
        cetacean_classifier: saving-willy model from the Saving Willy Hugging
            Face space. It is called with an image and must return a mapping
            whose ``'predictions'`` entry is a non-empty list. Entry 0 is
            treated as the top prediction.
    """
    images = st.session_state.images
    observations = st.session_state.observations
    hashes = st.session_state.image_hashes
    _, row_size, _ = gridder(hashes)  # batch_size and page are not used here
    grid = st.columns(row_size)

    def _push_observation(obs):
        # Button callback: make the clicked tile's observation current, then
        # push. The loop overwrites public_observation on every iteration, so
        # without this each button would push the last tile.
        # NOTE(review): assumes push_observations() reads
        # st.session_state.public_observation — confirm.
        st.session_state.public_observation = obs
        push_observations()

    for o, image_hash in enumerate(hashes, start=1):
        image = images[image_hash]
        with grid[(o - 1) % row_size]:
            st.image(image, use_column_width=True)

            # run classifier model on `image`, and persistently store the output
            out = cetacean_classifier(image)  # get top matches
            whale_classes = out['predictions'][:]
            pred1 = whale_classes[0]
            st.session_state.whale_prediction1[image_hash] = pred1
            st.session_state.classify_whale_done[image_hash] = True
            # Keep the stored observation in sync (as cetacean_just_classify
            # does), so views reading top_predictions see this result.
            observations[image_hash].set_top_predictions(whale_classes[:])
            observation = observations[image_hash].to_dict()
            msg = f"[D]2 classify_whale_done for {image_hash}: {st.session_state.classify_whale_done[image_hash]}, whale_prediction1: {st.session_state.whale_prediction1[image_hash]}"
            g_logger.info(msg)

            # dropdown for selecting/overriding the species prediction.
            # Classification has just run, so a prediction is always available
            # (the former "not done" branch was unreachable).
            print(f"[D] pred1: {pred1}")
            # index of pred1 in WHALE_CLASSES, None (no selection) if absent
            ix = viewer.WHALE_CLASSES.index(pred1) if pred1 in viewer.WHALE_CLASSES else None
            selected_class = st.selectbox(f"Species for observation {o}", viewer.WHALE_CLASSES, index=ix)
            observation['predicted_class'] = selected_class
            if selected_class != pred1:
                observation['class_overriden'] = selected_class  # TODO: this should be boolean!
            st.session_state.public_observation = observation

            st.button(f"Upload observation {o} to THE INTERNET!",
                      on_click=_push_observation, args=(observation,))
            # TODO: the metadata only fills properly if `validate` was clicked.
            st.markdown(metadata2md())

            msg = f"[D] full observation after inference: {observation}"
            g_logger.debug(msg)
            print(msg)

            # TODO: add a link to more info on the model, next to the button.
            # render images for the top predictions (the model api returns 3)
            n = len(whale_classes)
            st.markdown(f"Top {n} Predictions for observation {o}")
            for i in range(n):
                viewer.display_whale(whale_classes, i)