rmm commited on
Commit
d4ec4a0
·
1 Parent(s): 4854d2c

feat: separate functions for ML inference, manual validation, display

Browse files
Files changed (1) hide show
  1. src/classifier/classifier_image.py +138 -1
src/classifier/classifier_image.py CHANGED
@@ -20,7 +20,144 @@ def add_header_text() -> None:
20
  Once inference is complete, the top three predictions are shown.
21
  You can override the prediction by selecting a species from the dropdown.*""")
22
 
23
- def cetacean_classify(cetacean_classifier):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  """Cetacean classifier using the saving-willy model from Saving Willy Hugging Face space.
25
  For each image in the session state, classify the image and display the top 3 predictions.
26
  Args:
 
20
  Once inference is complete, the top three predictions are shown.
21
  You can override the prediction by selecting a species from the dropdown.*""")
22
 
23
+ # func to just run classification, store results.
24
+ def cetacean_just_classify(cetacean_classifier):
25
+
26
+ images = st.session_state.images
27
+ observations = st.session_state.observations
28
+ hashes = st.session_state.image_hashes
29
+
30
+ for hash in hashes:
31
+ image = images[hash]
32
+ observation = observations[hash].to_dict()
33
+ # run classifier model on `image`, and persistently store the output
34
+ out = cetacean_classifier(image) # get top 3 matches
35
+ st.session_state.whale_prediction1[hash] = out['predictions'][0]
36
+ st.session_state.classify_whale_done[hash] = True
37
+ st.session_state.observations[hash].set_top_predictions(out['predictions'][:])
38
+
39
+ msg = f"[D]2 classify_whale_done for {hash}: {st.session_state.classify_whale_done[hash]}, whale_prediction1: {st.session_state.whale_prediction1[hash]}"
40
+ g_logger.info(msg)
41
+
42
+ # TODO: what is the difference between public and regular; and why is this not array-ready?
43
+ st.session_state.public_observation = observation
44
+ st.write(f"*[D] Observation {hash} classified as {st.session_state.whale_prediction1[hash]}*")
45
+
46
+
47
+ # func to show results and allow review
48
+ def cetacean_show_results_and_review():
49
+ images = st.session_state.images
50
+ observations = st.session_state.observations
51
+ hashes = st.session_state.image_hashes
52
+ batch_size, row_size, page = gridder(hashes)
53
+
54
+ grid = st.columns(row_size)
55
+ col = 0
56
+ o = 1
57
+
58
+ for hash in hashes:
59
+ image = images[hash]
60
+ observation = observations[hash].to_dict()
61
+
62
+ with grid[col]:
63
+ st.image(image, use_column_width=True)
64
+
65
+ # dropdown for selecting/overriding the species prediction
66
+ if not st.session_state.classify_whale_done[hash]:
67
+ selected_class = st.sidebar.selectbox("Species", viewer.WHALE_CLASSES,
68
+ index=None, placeholder="Species not yet identified...",
69
+ disabled=True)
70
+ else:
71
+ pred1 = st.session_state.whale_prediction1[hash]
72
+ # get index of pred1 from WHALE_CLASSES, none if not present
73
+ print(f"[D] pred1: {pred1}")
74
+ ix = viewer.WHALE_CLASSES.index(pred1) if pred1 in viewer.WHALE_CLASSES else None
75
+ selected_class = st.selectbox(f"Species for observation {str(o)}", viewer.WHALE_CLASSES, index=ix)
76
+
77
+ observation['predicted_class'] = selected_class
78
+ if selected_class != st.session_state.whale_prediction1[hash]:
79
+ observation['class_overriden'] = selected_class # TODO: this should be boolean!
80
+
81
+ st.session_state.public_observation = observation
82
+ st.button(f"Upload observation {str(o)} to THE INTERNET!", on_click=push_observations)
83
+ # TODO: the metadata only fills properly if `validate` was clicked.
84
+ st.markdown(metadata2md())
85
+
86
+ msg = f"[D] full observation after inference: {observation}"
87
+ g_logger.debug(msg)
88
+ print(msg)
89
+ # TODO: add a link to more info on the model, next to the button.
90
+
91
+ whale_classes = observations[hash].top_predictions
92
+ # render images for the top 3 (that is what the model api returns)
93
+ n = len(whale_classes)
94
+ st.markdown(f"Top {n} Predictions for observation {str(o)}")
95
+ for i in range(n):
96
+ viewer.display_whale(whale_classes, i)
97
+ o += 1
98
+ col = (col + 1) % row_size
99
+
100
+
101
+ # func to just present results
102
+ def cetacean_show_results():
103
+ images = st.session_state.images
104
+ observations = st.session_state.observations
105
+ hashes = st.session_state.image_hashes
106
+ batch_size, row_size, page = gridder(hashes)
107
+
108
+
109
+ grid = st.columns(row_size)
110
+ col = 0
111
+ o = 1
112
+
113
+ for hash in hashes:
114
+ image = images[hash]
115
+ observation = observations[hash].to_dict()
116
+
117
+ with grid[col]:
118
+ st.image(image, use_column_width=True)
119
+
120
+ # # dropdown for selecting/overriding the species prediction
121
+ # if not st.session_state.classify_whale_done[hash]:
122
+ # selected_class = st.sidebar.selectbox("Species", viewer.WHALE_CLASSES,
123
+ # index=None, placeholder="Species not yet identified...",
124
+ # disabled=True)
125
+ # else:
126
+ # pred1 = st.session_state.whale_prediction1[hash]
127
+ # # get index of pred1 from WHALE_CLASSES, none if not present
128
+ # print(f"[D] pred1: {pred1}")
129
+ # ix = viewer.WHALE_CLASSES.index(pred1) if pred1 in viewer.WHALE_CLASSES else None
130
+ # selected_class = st.selectbox(f"Species for observation {str(o)}", viewer.WHALE_CLASSES, index=ix)
131
+
132
+ # observation['predicted_class'] = selected_class
133
+ # if selected_class != st.session_state.whale_prediction1[hash]:
134
+ # observation['class_overriden'] = selected_class # TODO: this should be boolean!
135
+
136
+ # st.session_state.public_observation = observation
137
+ st.button(f"Upload observation {str(o)} to THE INTERNET!", on_click=push_observations)
138
+ # TODO: the metadata only fills properly if `validate` was clicked.
139
+ st.markdown(metadata2md())
140
+ st.markdown(f"- **hash**: {hash}")
141
+
142
+ msg = f"[D] full observation after inference: {observation}"
143
+ g_logger.debug(msg)
144
+ print(msg)
145
+ # TODO: add a link to more info on the model, next to the button.
146
+
147
+ whale_classes = observations[hash].top_predictions
148
+ # render images for the top 3 (that is what the model api returns)
149
+ n = len(whale_classes)
150
+ st.markdown(f"Top {n} Predictions for observation {str(o)}")
151
+ for i in range(n):
152
+ viewer.display_whale(whale_classes, i)
153
+ o += 1
154
+ col = (col + 1) % row_size
155
+
156
+
157
+
158
+
159
+ # func to do all in one
160
+ def cetacean_classify_show_and_review(cetacean_classifier):
161
  """Cetacean classifier using the saving-willy model from Saving Willy Hugging Face space.
162
  For each image in the session state, classify the image and display the top 3 predictions.
163
  Args: