Spaces:
Sleeping
Sleeping
rmm
commited on
Commit
·
d4ec4a0
1
Parent(s):
4854d2c
feat: separate functions for ML inference, manual validation, display
Browse files
src/classifier/classifier_image.py
CHANGED
@@ -20,7 +20,144 @@ def add_header_text() -> None:
|
|
20 |
Once inference is complete, the top three predictions are shown.
|
21 |
You can override the prediction by selecting a species from the dropdown.*""")
|
22 |
|
23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
"""Cetacean classifier using the saving-willy model from Saving Willy Hugging Face space.
|
25 |
For each image in the session state, classify the image and display the top 3 predictions.
|
26 |
Args:
|
|
|
20 |
Once inference is complete, the top three predictions are shown.
|
21 |
You can override the prediction by selecting a species from the dropdown.*""")
|
22 |
|
23 |
+
# func to just run classification, store results.
|
24 |
+
def cetacean_just_classify(cetacean_classifier):
|
25 |
+
|
26 |
+
images = st.session_state.images
|
27 |
+
observations = st.session_state.observations
|
28 |
+
hashes = st.session_state.image_hashes
|
29 |
+
|
30 |
+
for hash in hashes:
|
31 |
+
image = images[hash]
|
32 |
+
observation = observations[hash].to_dict()
|
33 |
+
# run classifier model on `image`, and persistently store the output
|
34 |
+
out = cetacean_classifier(image) # get top 3 matches
|
35 |
+
st.session_state.whale_prediction1[hash] = out['predictions'][0]
|
36 |
+
st.session_state.classify_whale_done[hash] = True
|
37 |
+
st.session_state.observations[hash].set_top_predictions(out['predictions'][:])
|
38 |
+
|
39 |
+
msg = f"[D]2 classify_whale_done for {hash}: {st.session_state.classify_whale_done[hash]}, whale_prediction1: {st.session_state.whale_prediction1[hash]}"
|
40 |
+
g_logger.info(msg)
|
41 |
+
|
42 |
+
# TODO: what is the difference between public and regular; and why is this not array-ready?
|
43 |
+
st.session_state.public_observation = observation
|
44 |
+
st.write(f"*[D] Observation {hash} classified as {st.session_state.whale_prediction1[hash]}*")
|
45 |
+
|
46 |
+
|
47 |
+
# func to show results and allow review
|
48 |
+
def cetacean_show_results_and_review():
|
49 |
+
images = st.session_state.images
|
50 |
+
observations = st.session_state.observations
|
51 |
+
hashes = st.session_state.image_hashes
|
52 |
+
batch_size, row_size, page = gridder(hashes)
|
53 |
+
|
54 |
+
grid = st.columns(row_size)
|
55 |
+
col = 0
|
56 |
+
o = 1
|
57 |
+
|
58 |
+
for hash in hashes:
|
59 |
+
image = images[hash]
|
60 |
+
observation = observations[hash].to_dict()
|
61 |
+
|
62 |
+
with grid[col]:
|
63 |
+
st.image(image, use_column_width=True)
|
64 |
+
|
65 |
+
# dropdown for selecting/overriding the species prediction
|
66 |
+
if not st.session_state.classify_whale_done[hash]:
|
67 |
+
selected_class = st.sidebar.selectbox("Species", viewer.WHALE_CLASSES,
|
68 |
+
index=None, placeholder="Species not yet identified...",
|
69 |
+
disabled=True)
|
70 |
+
else:
|
71 |
+
pred1 = st.session_state.whale_prediction1[hash]
|
72 |
+
# get index of pred1 from WHALE_CLASSES, none if not present
|
73 |
+
print(f"[D] pred1: {pred1}")
|
74 |
+
ix = viewer.WHALE_CLASSES.index(pred1) if pred1 in viewer.WHALE_CLASSES else None
|
75 |
+
selected_class = st.selectbox(f"Species for observation {str(o)}", viewer.WHALE_CLASSES, index=ix)
|
76 |
+
|
77 |
+
observation['predicted_class'] = selected_class
|
78 |
+
if selected_class != st.session_state.whale_prediction1[hash]:
|
79 |
+
observation['class_overriden'] = selected_class # TODO: this should be boolean!
|
80 |
+
|
81 |
+
st.session_state.public_observation = observation
|
82 |
+
st.button(f"Upload observation {str(o)} to THE INTERNET!", on_click=push_observations)
|
83 |
+
# TODO: the metadata only fills properly if `validate` was clicked.
|
84 |
+
st.markdown(metadata2md())
|
85 |
+
|
86 |
+
msg = f"[D] full observation after inference: {observation}"
|
87 |
+
g_logger.debug(msg)
|
88 |
+
print(msg)
|
89 |
+
# TODO: add a link to more info on the model, next to the button.
|
90 |
+
|
91 |
+
whale_classes = observations[hash].top_predictions
|
92 |
+
# render images for the top 3 (that is what the model api returns)
|
93 |
+
n = len(whale_classes)
|
94 |
+
st.markdown(f"Top {n} Predictions for observation {str(o)}")
|
95 |
+
for i in range(n):
|
96 |
+
viewer.display_whale(whale_classes, i)
|
97 |
+
o += 1
|
98 |
+
col = (col + 1) % row_size
|
99 |
+
|
100 |
+
|
101 |
+
# func to just present results
|
102 |
+
def cetacean_show_results():
|
103 |
+
images = st.session_state.images
|
104 |
+
observations = st.session_state.observations
|
105 |
+
hashes = st.session_state.image_hashes
|
106 |
+
batch_size, row_size, page = gridder(hashes)
|
107 |
+
|
108 |
+
|
109 |
+
grid = st.columns(row_size)
|
110 |
+
col = 0
|
111 |
+
o = 1
|
112 |
+
|
113 |
+
for hash in hashes:
|
114 |
+
image = images[hash]
|
115 |
+
observation = observations[hash].to_dict()
|
116 |
+
|
117 |
+
with grid[col]:
|
118 |
+
st.image(image, use_column_width=True)
|
119 |
+
|
120 |
+
# # dropdown for selecting/overriding the species prediction
|
121 |
+
# if not st.session_state.classify_whale_done[hash]:
|
122 |
+
# selected_class = st.sidebar.selectbox("Species", viewer.WHALE_CLASSES,
|
123 |
+
# index=None, placeholder="Species not yet identified...",
|
124 |
+
# disabled=True)
|
125 |
+
# else:
|
126 |
+
# pred1 = st.session_state.whale_prediction1[hash]
|
127 |
+
# # get index of pred1 from WHALE_CLASSES, none if not present
|
128 |
+
# print(f"[D] pred1: {pred1}")
|
129 |
+
# ix = viewer.WHALE_CLASSES.index(pred1) if pred1 in viewer.WHALE_CLASSES else None
|
130 |
+
# selected_class = st.selectbox(f"Species for observation {str(o)}", viewer.WHALE_CLASSES, index=ix)
|
131 |
+
|
132 |
+
# observation['predicted_class'] = selected_class
|
133 |
+
# if selected_class != st.session_state.whale_prediction1[hash]:
|
134 |
+
# observation['class_overriden'] = selected_class # TODO: this should be boolean!
|
135 |
+
|
136 |
+
# st.session_state.public_observation = observation
|
137 |
+
st.button(f"Upload observation {str(o)} to THE INTERNET!", on_click=push_observations)
|
138 |
+
# TODO: the metadata only fills properly if `validate` was clicked.
|
139 |
+
st.markdown(metadata2md())
|
140 |
+
st.markdown(f"- **hash**: {hash}")
|
141 |
+
|
142 |
+
msg = f"[D] full observation after inference: {observation}"
|
143 |
+
g_logger.debug(msg)
|
144 |
+
print(msg)
|
145 |
+
# TODO: add a link to more info on the model, next to the button.
|
146 |
+
|
147 |
+
whale_classes = observations[hash].top_predictions
|
148 |
+
# render images for the top 3 (that is what the model api returns)
|
149 |
+
n = len(whale_classes)
|
150 |
+
st.markdown(f"Top {n} Predictions for observation {str(o)}")
|
151 |
+
for i in range(n):
|
152 |
+
viewer.display_whale(whale_classes, i)
|
153 |
+
o += 1
|
154 |
+
col = (col + 1) % row_size
|
155 |
+
|
156 |
+
|
157 |
+
|
158 |
+
|
159 |
+
# func to do all in one
|
160 |
+
def cetacean_classify_show_and_review(cetacean_classifier):
|
161 |
"""Cetacean classifier using the saving-willy model from Saving Willy Hugging Face space.
|
162 |
For each image in the session state, classify the image and display the top 3 predictions.
|
163 |
Args:
|