File size: 10,290 Bytes
0e8c927
 
 
 
 
 
 
 
 
 
 
 
 
4854d2c
 
 
 
 
 
 
 
 
d4ec4a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0e02e00
d4ec4a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0e02e00
d4ec4a0
 
0e02e00
d4ec4a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0e02e00
d4ec4a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b28238
 
 
 
 
0e02e00
0e8c927
 
1c0e2a5
 
0e8c927
 
 
1c0e2a5
 
 
0e8c927
 
 
1c0e2a5
0e8c927
 
7a5f0ca
 
 
0e8c927
 
 
7a5f0ca
0e8c927
 
 
 
7a5f0ca
0e8c927
 
 
1c0e2a5
0e8c927
 
7a5f0ca
0e8c927
 
 
1c0e2a5
0e8c927
 
 
 
 
 
 
 
 
 
1c0e2a5
0e8c927
 
1c0e2a5
0e8c927
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
import streamlit as st
import logging

# Module-wide logger. Its own level is pinned to DEBUG so that debug records
# emitted here are not dropped at the logger level (handlers may still filter).
LOG_LEVEL = logging.DEBUG
g_logger = logging.getLogger(name=__name__)
g_logger.setLevel(level=LOG_LEVEL)

import whale_viewer as viewer
from hf_push_observations import push_observations
from utils.grid_maker import gridder
from utils.metadata_handler import metadata2md

def add_header_text() -> None:
    """
    Add brief explainer text about cetacean classification to the tab.

    Renders one italicised Markdown paragraph via ``st.markdown`` explaining
    that the classifier identifies the species in the uploaded image, that the
    top three predictions are shown, and that the prediction can be overridden
    from the species dropdown.
    """
    st.markdown("""
                *Run classifier to identify the species of cetacean on the uploaded image.
                Once inference is complete, the top three predictions are shown.
                You can override the prediction by selecting a species from the dropdown.*""")

def cetacean_just_classify(cetacean_classifier):
    """Run the classifier on every uploaded image and store the results.

    For each hash in ``st.session_state.image_hashes``, the matching image is
    passed to ``cetacean_classifier`` and its ``'predictions'`` list is written
    back into session state:

    - ``whale_prediction1[hash]``: the first (top) prediction;
    - ``classify_whale_done[hash]``: set to ``True``;
    - ``observations[hash]``: receives a copy of all predictions via
      ``set_top_predictions``;
    - ``public_observations[hash]``: a ``to_dict()`` snapshot of the
      observation, taken after the predictions were attached.

    A short debug line per image is also written to the page.

    Args:
        cetacean_classifier: callable taking an image and returning a mapping
            with a ``'predictions'`` list. The list is assumed to be non-empty
            (top 3 matches per the model API) -- TODO confirm.
    """
    images = st.session_state.images
    observations = st.session_state.observations

    for image_hash in st.session_state.image_hashes:
        image = images[image_hash]
        # run classifier model on `image`, and persistently store the output
        predictions = cetacean_classifier(image)['predictions']  # top 3 matches
        st.session_state.whale_prediction1[image_hash] = predictions[0]
        st.session_state.classify_whale_done[image_hash] = True
        observations[image_hash].set_top_predictions(predictions[:])

        g_logger.info(
            "[D]2 classify_whale_done for %s: %s, whale_prediction1: %s",
            image_hash,
            st.session_state.classify_whale_done[image_hash],
            st.session_state.whale_prediction1[image_hash],
        )

        # Snapshot only *after* set_top_predictions, so the public copy is not
        # taken from the pre-classification state of the observation.
        # TODO: what is the difference between public and regular; and why is this not array-ready?
        st.session_state.public_observations[image_hash] = observations[image_hash].to_dict()
        st.write(f"*[D] Observation {image_hash} classified as {st.session_state.whale_prediction1[image_hash]}*")
       
        
def cetacean_show_results_and_review():
    """Render classification results in a grid and let the user review them.

    For each hash in ``st.session_state.image_hashes``, one grid column shows:

    - the image;
    - a species dropdown, pre-selected with the top prediction once
      classification is done and disabled before that;
    - an upload button wired to ``push_observations``;
    - the observation metadata;
    - the observation's top predictions.

    The reviewed observation (a ``to_dict()`` copy) is stored in
    ``st.session_state.public_observations[hash]``. Its ``'predicted_class'``
    is set to the dropdown value. ``'class_overriden'`` is set only when the
    image was classified and the user picked a class different from the
    model's top prediction.
    """
    images = st.session_state.images
    observations = st.session_state.observations
    hashes = st.session_state.image_hashes
    # NOTE: batch_size/page are not used here; every hash is rendered.
    _batch_size, row_size, _page = gridder(hashes)

    grid = st.columns(row_size)

    for o, image_hash in enumerate(hashes, start=1):
        observation = observations[image_hash].to_dict()

        # fill the columns left to right, wrapping after row_size images
        with grid[(o - 1) % row_size]:
            st.image(images[image_hash], use_column_width=True)

            # Dropdown for selecting/overriding the species prediction.
            # Both variants live in this column with a per-observation label.
            # A fixed "Species" sidebar widget would be created once per
            # unclassified image, and identical widgets collide as soon as
            # there is more than one.
            label = f"Species for observation {o}"
            classified = st.session_state.classify_whale_done[image_hash]
            if not classified:
                selected_class = st.selectbox(label, viewer.WHALE_CLASSES,
                                              index=None, placeholder="Species not yet identified...",
                                              disabled=True)
            else:
                pred1 = st.session_state.whale_prediction1[image_hash]
                print(f"[D] pred1: {pred1}")
                # preselect pred1 when it is a known class, otherwise leave empty
                ix = viewer.WHALE_CLASSES.index(pred1) if pred1 in viewer.WHALE_CLASSES else None
                selected_class = st.selectbox(label, viewer.WHALE_CLASSES, index=ix)

            observation['predicted_class'] = selected_class
            # Only a classified image has a model prediction to override. The
            # short-circuit also avoids looking up a prediction that may not exist.
            if classified and selected_class != st.session_state.whale_prediction1[image_hash]:
                observation['class_overriden'] = selected_class  # TODO: this should be boolean!

            st.session_state.public_observations[image_hash] = observation
            st.button(f"Upload observation {o} to THE INTERNET!", on_click=push_observations)
            # TODO: the metadata only fills properly if `validate` was clicked.
            st.markdown(metadata2md(image_hash))

            msg = f"[D] full observation after inference: {observation}"
            g_logger.debug(msg)
            print(msg)
            # TODO: add a link to more info on the model, next to the button.

            whale_classes = observations[image_hash].top_predictions
            # render images for the top 3 (that is what the model api returns)
            n = len(whale_classes)
            st.markdown(f"Top {n} Predictions for observation {o}")
            for i in range(n):
                viewer.display_whale(whale_classes, i)


def cetacean_show_results():
    """Render classification results in a grid, without a species dropdown.

    For each hash in ``st.session_state.image_hashes``, one grid column shows:

    - the image;
    - an upload button wired to ``push_observations``;
    - the observation metadata and its hash;
    - the top predictions stored on the observation (``top_predictions``).

    No prediction override is offered and ``public_observations`` is not
    written by this function.
    """
    images = st.session_state.images
    observations = st.session_state.observations
    hashes = st.session_state.image_hashes
    # NOTE: batch_size/page are not used here; every hash is rendered.
    _batch_size, row_size, _page = gridder(hashes)

    grid = st.columns(row_size)

    for o, image_hash in enumerate(hashes, start=1):
        observation = observations[image_hash].to_dict()

        # fill the columns left to right, wrapping after row_size images
        with grid[(o - 1) % row_size]:
            st.image(images[image_hash], use_column_width=True)

            st.button(f"Upload observation {o} to THE INTERNET!", on_click=push_observations)
            # TODO: the metadata only fills properly if `validate` was clicked.
            st.markdown(metadata2md(image_hash))
            st.markdown(f"- **hash**: {image_hash}")

            msg = f"[D] full observation after inference: {observation}"
            g_logger.debug(msg)
            print(msg)
            # TODO: add a link to more info on the model, next to the button.

            whale_classes = observations[image_hash].top_predictions
            # render images for the top 3 (that is what the model api returns)
            n = len(whale_classes)
            st.markdown(f"Top {n} Predictions for observation {o}")
            for i in range(n):
                viewer.display_whale(whale_classes, i)




def cetacean_classify_show_and_review(cetacean_classifier):
    """Deprecated all-in-one classify, display and review step.

    Always raises; use the individual steps instead: ``cetacean_just_classify``
    to classify, then ``cetacean_show_results_and_review`` (or
    ``cetacean_show_results``) to display.

    Args:
        cetacean_classifier: unused; kept so existing call sites still match
            the signature.

    Raises:
        DeprecationWarning: unconditionally.
    """
    # The previous body followed this unconditional `raise` and could never
    # execute, so it has been dropped.
    raise DeprecationWarning("This function is deprecated. Use individual steps instead")