File size: 7,885 Bytes
55d18b1
 
 
 
 
 
 
 
 
 
 
 
 
d7725f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55d18b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import streamlit as st
import logging

# Module-wide logger shared by every function below; its level is pinned to
# DEBUG so the "[D]" diagnostic messages are not filtered out at this logger.
LOG_LEVEL = logging.DEBUG
g_logger = logging.getLogger(name=__name__)
g_logger.setLevel(level=LOG_LEVEL)

import whale_viewer as viewer
from hf_push_observations import push_observations
from utils.grid_maker import gridder
from utils.metadata_handler import metadata2md

# TODO: divide this into two functions, one for the classification and one for the display.
# They are currently somewhat interleaved, and it is not totally clear how to separate them;
# perhaps there are more stages than I realised:
#   ML started, ML completed, manual review completed, data uploaded.

# For now, the division implemented is between ML classification and display + manual review.

def cetacean_classify_list(cetacean_classifier):
    """Run the cetacean classifier on every uploaded image and store the results.

    For each file in ``st.session_state.files``, the image stored under the
    file's name in ``st.session_state.images`` is passed to
    ``cetacean_classifier``. The results are written to session state, keyed
    by file name:

    - ``whale_prediction1[key]``: the first entry of ``out['predictions']``
    - ``classify_whale_done[key]``: set to True
    - ``observations[key]``: gets all predictions via ``set_top_predictions``
    - ``public_observation[key]``: the observation's ``to_dict()``, taken
      *after* the predictions were attached

    Balloons are shown once all files have been processed.

    Args:
        cetacean_classifier: callable taking an image and returning a mapping
            with a ``'predictions'`` list. Presumably ordered best-first (the
            first entry is treated as the top prediction), top 3 per the model
            API -- TODO confirm.

    Returns:
        bool: always True for now; there is no failure detection yet.
    """
    files = st.session_state.files
    images = st.session_state.images
    observations = st.session_state.observations

    for file in files:
        key = file.name
        image = images[key]

        # run classifier model on `image`, and persistently store the output
        out = cetacean_classifier(image)  # get top 3 matches
        predictions = out['predictions']
        st.session_state.whale_prediction1[key] = predictions[0]
        st.session_state.classify_whale_done[key] = True  # TODO 25.01 unclear what this is for;
        msg = f"[D]2 classify_whale_done: {st.session_state.classify_whale_done[key]}, whale_prediction1: {st.session_state.whale_prediction1[key]}"
        g_logger.info(msg)

        observations[key].set_top_predictions(predictions[:])

        # Serialise only now: previously the dict was snapshotted before the
        # predictions were attached, so the stored/logged "after inference"
        # observation did not reflect the inference.
        observation = observations[key].to_dict()
        st.session_state.public_observation[key] = observation
        msg = f"[D] full observation after inference: {observation}"
        g_logger.debug(msg)
        print(msg)

    # TODO: add some mech to test if it was successful.
    st.balloons()
    return True

def cetacean_show_classifications():
    """Render each uploaded image in a grid for manual review of its classification.

    For every file in ``st.session_state.files`` this shows the image, a
    species selectbox (pre-selected with the stored top prediction when the
    file is marked as classified, otherwise disabled and empty), an upload
    button, the metadata markdown, and the observation's top predictions.

    Reads the per-file ``classify_whale_done`` / ``whale_prediction1`` entries
    (keyed by file name) that ``cetacean_classify_list`` writes. Files with no
    entry are treated as not yet classified instead of raising KeyError.

    Returns:
        bool: always True.
    """
    st.write("TOP TEXT")
    st.write("Reviewing the classifications :construction:")
    files = st.session_state.files
    images = st.session_state.images
    observations = st.session_state.observations

    def _upload(obs):
        # Button callbacks run at the start of the *next* rerun, when
        # `public_observation` still holds whichever file was rendered last.
        # Bind this button's own observation so it uploads the right one.
        st.session_state.public_observation = obs
        push_observations()

    _, row_size, _ = gridder(files)

    grid = st.columns(row_size)
    col = 0

    for file in files:
        key = file.name
        image = images[key]
        # .get(): a file that has not been classified may have no entry yet;
        # indexing would raise before the "not done" branch could be reached.
        is_done = st.session_state.classify_whale_done.get(key, False)
        pred1 = st.session_state.whale_prediction1.get(key)

        with grid[col]:
            st.image(image, use_column_width=True)
            observation = observations[key].to_dict()
            # fetch the classification results
            msg = f"[D]2b classify_whale_done ({file}): {is_done}, whale_prediction1: {pred1}"
            g_logger.info(msg)

            # dropdown for selecting/overriding the species prediction
            # TODO: the "it's done" flag seems to get reset when we re-load the tab. Not quite right.
            if not is_done:
                selected_class = st.selectbox("Species", viewer.WHALE_CLASSES,
                                              index=None, placeholder="Species not yet identified...",
                                              disabled=True, key=f"cldd_{key}")
            else:
                # get index of pred1 from WHALE_CLASSES, none if not present
                print(f"[D] pred1: {pred1}")
                ix = viewer.WHALE_CLASSES.index(pred1) if pred1 in viewer.WHALE_CLASSES else None
                selected_class = st.selectbox(f"Species for {file.name}", viewer.WHALE_CLASSES, index=ix)

            observation['predicted_class'] = selected_class
            if selected_class != pred1:
                observation['class_overriden'] = selected_class

            # NOTE(review): this replaces the per-file dict written by
            # cetacean_classify_list with a single observation; kept for
            # compatibility with readers of `public_observation` elsewhere.
            st.session_state.public_observation = observation
            st.button(f"Upload observation for {file.name} to THE INTERNET!",
                      on_click=_upload, args=(observation,))
            # TODO: the metadata only fills properly if `validate` was clicked.
            st.markdown(metadata2md())

            msg = f"[D] full observation after inference: {observation}"
            g_logger.debug(msg)
            print(msg)
            # TODO: add a link to more info on the model, next to the button.
            # Guard against top_predictions being unset (presumably None) for
            # files that have not been classified yet.
            whale_classes = observations[key].top_predictions or []
            # render images for the top N (the model api returns 3)
            n = len(whale_classes)
            st.markdown(f"Top {n} Predictions for {file.name}")
            for i in range(n):
                viewer.display_whale(whale_classes, i)
        col = (col + 1) % row_size
    return True


def cetacean_classify_and_review(cetacean_classifier):
    """Classify each uploaded image and immediately render it for review (combined flow).

    For every file in ``st.session_state.files``, this runs
    ``cetacean_classifier`` on the image stored under the file's name. It then
    shows the image with a species selectbox pre-selected with the first
    prediction, an upload button, the metadata markdown, and the returned
    predictions.

    NOTE(review): unlike cetacean_classify_list / cetacean_show_classifications,
    this stores ``whale_prediction1`` and ``classify_whale_done`` in session
    state as single values, overwritten per file, rather than as per-file dicts.
    Running it will replace those dicts. Confirm which layout callers expect
    before mixing the two flows.

    Args:
        cetacean_classifier: callable taking an image and returning a mapping
            with a ``'predictions'`` list (first entry used as the top
            prediction).
    """
    files = st.session_state.files
    images = st.session_state.images
    observations = st.session_state.observations

    def _upload(obs):
        # Button callbacks run at the start of the *next* rerun, when
        # `public_observation` still holds whichever file was rendered last.
        # Bind this button's own observation so it uploads the right one.
        st.session_state.public_observation = obs
        push_observations()

    _, row_size, _ = gridder(files)

    grid = st.columns(row_size)
    col = 0

    for file in files:
        image = images[file.name]

        with grid[col]:
            st.image(image, use_column_width=True)
            observation = observations[file.name].to_dict()
            # run classifier model on `image`, and persistently store the output
            out = cetacean_classifier(image)  # get top 3 matches
            whale_classes = out['predictions'][:]
            pred1 = whale_classes[0]
            st.session_state.whale_prediction1 = pred1
            st.session_state.classify_whale_done = True
            msg = f"[D]2 classify_whale_done: {st.session_state.classify_whale_done}, whale_prediction1: {st.session_state.whale_prediction1}"
            g_logger.info(msg)

            # dropdown for selecting/overriding the species prediction.
            # The "done" flag is set just above, so the former disabled
            # "not yet identified" branch could never run and has been removed.
            # get index of pred1 from WHALE_CLASSES, none if not present
            print(f"[D] pred1: {pred1}")
            ix = viewer.WHALE_CLASSES.index(pred1) if pred1 in viewer.WHALE_CLASSES else None
            selected_class = st.selectbox(f"Species for {file.name}", viewer.WHALE_CLASSES, index=ix)

            observation['predicted_class'] = selected_class
            if selected_class != pred1:
                observation['class_overriden'] = selected_class

            st.session_state.public_observation = observation
            st.button(f"Upload observation for {file.name} to THE INTERNET!",
                      on_click=_upload, args=(observation,))
            # TODO: the metadata only fills properly if `validate` was clicked.
            st.markdown(metadata2md())

            msg = f"[D] full observation after inference: {observation}"
            g_logger.debug(msg)
            print(msg)
            # TODO: add a link to more info on the model, next to the button.

            # render images for every returned prediction (the model api returns 3)
            st.markdown(f"Top {len(whale_classes)} Predictions for {file.name}")
            for i in range(len(whale_classes)):
                viewer.display_whale(whale_classes, i)
        col = (col + 1) % row_size