Spaces:
Running
Running
File size: 7,885 Bytes
55d18b1 d7725f6 55d18b1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
import streamlit as st
import logging
# Module-wide logger, named after this module and pinned to LOG_LEVEL.
LOG_LEVEL = logging.DEBUG
g_logger = logging.getLogger(name=__name__)
g_logger.setLevel(level=LOG_LEVEL)
import whale_viewer as viewer
from hf_push_observations import push_observations
from utils.grid_maker import gridder
from utils.metadata_handler import metadata2md
# Need to divide this into two functions: one for the classification and one for the display.
# They are currently somewhat interleaved, and it is not totally clear how to separate them.
# Perhaps there are more stages than I realised:
#   ML started, ML completed, manual review completed, data uploaded.
# For now, implement the division between ML classification and display + manual review.
def cetacean_classify_list(cetacean_classifier):
    """Run the classifier on every uploaded image and store the results in session state.

    For each file in ``st.session_state.files`` (looked up by ``file.name`` in
    ``st.session_state.images`` and ``st.session_state.observations``) this:

    - calls ``cetacean_classifier(image)`` and stores the first entry of
      ``out['predictions']`` in ``st.session_state.whale_prediction1[name]``;
    - sets ``st.session_state.classify_whale_done[name] = True``;
    - attaches a copy of the full prediction list to the observation via
      ``set_top_predictions``;
    - stores the observation's ``to_dict()`` snapshot, taken after the
      predictions were attached, in ``st.session_state.public_observation[name]``.

    Finally it shows ``st.balloons()``.

    Args:
        cetacean_classifier: callable taking an image and returning a mapping
            with a ``'predictions'`` sequence (presumably the top 3 matches,
            best first -- TODO confirm against the model API).

    Returns:
        bool: always ``True`` for now; there is no real success check yet.
    """
    files = st.session_state.files
    images = st.session_state.images
    observations = st.session_state.observations

    for file in files:
        key = file.name
        image = images[key]

        # run classifier model on `image`, and persistently store the output
        out = cetacean_classifier(image)  # get top 3 matches
        predictions = out['predictions']
        st.session_state.whale_prediction1[key] = predictions[0]
        st.session_state.classify_whale_done[key] = True  # TODO 25.01 unclear what this is for;
        msg = f"[D]2 classify_whale_done: {st.session_state.classify_whale_done[key]}, whale_prediction1: {st.session_state.whale_prediction1[key]}"
        g_logger.info(msg)

        # Copy the list so the observation does not alias the classifier's output.
        observations[key].set_top_predictions(predictions[:])
        # Snapshot only AFTER the predictions are attached, so the stored and
        # logged dict reflects the post-inference observation. (Snapshotting
        # before set_top_predictions left it with the pre-inference state.)
        observation = observations[key].to_dict()
        st.session_state.public_observation[key] = observation
        msg = f"[D] full observation after inference: {observation}"
        g_logger.debug(msg)
        print(msg)

    # TODO: add some mechanism to check whether classification actually succeeded.
    st.balloons()
    return True
def cetacean_show_classifications():
    """Render a review grid of the classified images and let the user confirm or override species.

    For each file in ``st.session_state.files`` this draws, in a grid of
    ``row_size`` columns (from ``gridder(files)``):

    - the image;
    - a species selectbox, preselected with the stored top prediction, or
      disabled while the file has no ``classify_whale_done`` flag set;
    - an upload button that pushes *this* file's observation;
    - the metadata markdown and the stored top predictions.

    The classifier is not run here; results are read from
    ``st.session_state.classify_whale_done`` / ``whale_prediction1`` and the
    observations' ``top_predictions``. Writes
    ``st.session_state.public_observation``.

    Returns:
        bool: always ``True``.
    """
    def _upload(obs):
        # Button callback. Streamlit runs it at the start of the NEXT script
        # run, when public_observation still holds whatever the previous run
        # assigned last (the last file in the grid). Re-assign this button's
        # own observation first, so each button pushes its own file.
        # NOTE(review): assumes push_observations reads
        # st.session_state.public_observation -- confirm in hf_push_observations.
        st.session_state.public_observation = obs
        push_observations()

    st.write("TOP TEXT")
    st.write("Reviewing the classifications :construction:")

    files = st.session_state.files
    images = st.session_state.images
    observations = st.session_state.observations
    classify_done = st.session_state.classify_whale_done
    top_prediction = st.session_state.whale_prediction1

    _batch_size, row_size, _page = gridder(files)
    grid = st.columns(row_size)
    col = 0
    for file in files:
        key = file.name
        image = images[key]
        with grid[col]:
            st.image(image, use_column_width=True)
            observation = observations[key].to_dict()

            # Fetch the stored classification results (no model run here).
            # .get(): the flag can be missing for a file (see the TODO below);
            # indexing directly raised KeyError in that case.
            is_done = classify_done.get(key, False)
            pred1 = top_prediction.get(key)
            msg = f"[D]2b classify_whale_done ({file}): {is_done}, whale_prediction1: {pred1}"
            g_logger.info(msg)

            # dropdown for selecting/overriding the species prediction
            # TODO: the "it's done" flag seems to get reset when we re-load the tab. Not quite right.
            # TODO: ask LV why it was in the sidebar, and not in the grid
            if not is_done:
                selected_class = st.selectbox("Species", viewer.WHALE_CLASSES,
                                              index=None, placeholder="Species not yet identified...",
                                              disabled=True, key=f"cldd_{key}")
            else:
                # preselect pred1 if it is one of WHALE_CLASSES, otherwise nothing
                print(f"[D] pred1: {pred1}")
                ix = viewer.WHALE_CLASSES.index(pred1) if pred1 in viewer.WHALE_CLASSES else None
                selected_class = st.selectbox(f"Species for {file.name}", viewer.WHALE_CLASSES, index=ix)

            observation['predicted_class'] = selected_class
            if selected_class != pred1:
                observation['class_overriden'] = selected_class

            # NOTE(review): metadata2md() below presumably renders this value -- confirm.
            st.session_state.public_observation = observation
            st.button(f"Upload observation for {file.name} to THE INTERNET!",
                      on_click=_upload, args=(observation,))
            # TODO: the metadata only fills properly if `validate` was clicked.
            st.markdown(metadata2md())

            msg = f"[D] full observation after inference: {observation}"
            g_logger.debug(msg)
            print(msg)

            # TODO: add a link to more info on the model, next to the button.
            whale_classes = observations[key].top_predictions
            # render images for the top predictions (the model api returns 3)
            n = len(whale_classes)
            st.markdown(f"Top {n} Predictions for {file.name}")
            for i in range(n):
                viewer.display_whale(whale_classes, i)
        col = (col + 1) % row_size
    return True
def cetacean_classify_and_review(cetacean_classifier):
    """Classify each uploaded image and immediately render it for review (combined flow).

    For each file in ``st.session_state.files`` this runs
    ``cetacean_classifier`` on the image, then draws in a grid column:
    the image, a species selectbox preselected with the top prediction, an
    upload button that pushes *this* file's observation, the metadata
    markdown, and the returned predictions.

    NOTE(review): ``cetacean_classify_list`` / ``cetacean_show_classifications``
    treat ``whale_prediction1`` and ``classify_whale_done`` as mappings
    keyed by file name. This function instead overwrites them with scalars
    (ending with the last file's values), so calling it before those
    functions would likely make their ``[key]`` accesses fail. Left as-is
    because other callers may rely on the scalar form -- confirm.

    Args:
        cetacean_classifier: callable taking an image and returning a mapping
            with a ``'predictions'`` sequence (presumably the top 3 matches,
            best first -- TODO confirm against the model API).
    """
    def _upload(obs):
        # Button callback. Streamlit runs it at the start of the NEXT script
        # run, when public_observation still holds the last file's observation
        # from the previous run. Re-assign this button's own observation first.
        # NOTE(review): assumes push_observations reads
        # st.session_state.public_observation -- confirm.
        st.session_state.public_observation = obs
        push_observations()

    files = st.session_state.files
    images = st.session_state.images
    observations = st.session_state.observations
    _batch_size, row_size, _page = gridder(files)
    grid = st.columns(row_size)
    col = 0
    for file in files:
        image = images[file.name]
        with grid[col]:
            st.image(image, use_column_width=True)
            observation = observations[file.name].to_dict()
            # run classifier model on `image`, and persistently store the output
            out = cetacean_classifier(image)  # get top 3 matches
            st.session_state.whale_prediction1 = out['predictions'][0]
            st.session_state.classify_whale_done = True
            msg = f"[D]2 classify_whale_done: {st.session_state.classify_whale_done}, whale_prediction1: {st.session_state.whale_prediction1}"
            g_logger.info(msg)

            # Dropdown for selecting/overriding the species prediction. No
            # "not yet identified" disabled variant is needed: the done flag
            # was set to True just above, so that branch could never run.
            pred1 = st.session_state.whale_prediction1
            # get index of pred1 from WHALE_CLASSES, none if not present
            print(f"[D] pred1: {pred1}")
            ix = viewer.WHALE_CLASSES.index(pred1) if pred1 in viewer.WHALE_CLASSES else None
            selected_class = st.selectbox(f"Species for {file.name}", viewer.WHALE_CLASSES, index=ix)
            observation['predicted_class'] = selected_class
            if selected_class != pred1:
                observation['class_overriden'] = selected_class

            st.session_state.public_observation = observation
            st.button(f"Upload observation for {file.name} to THE INTERNET!",
                      on_click=_upload, args=(observation,))
            # TODO: the metadata only fills properly if `validate` was clicked.
            st.markdown(metadata2md())
            msg = f"[D] full observation after inference: {observation}"
            g_logger.debug(msg)
            print(msg)

            # TODO: add a link to more info on the model, next to the button.
            whale_classes = out['predictions'][:]
            # render images for the returned predictions (the model api returns 3);
            # the heading uses the actual count so it matches what is rendered.
            n = len(whale_classes)
            st.markdown(f"Top {n} Predictions for {file.name}")
            for i in range(n):
                viewer.display_whale(whale_classes, i)
        col = (col + 1) % row_size