"""Streamlit helpers for classifying cetacean images and reviewing the results.

The functions here read and write ``st.session_state``: images, observations
and image hashes are read, and per-hash classification status / top prediction
are stored by the classification step.
"""
import logging

import streamlit as st

import whale_viewer as viewer
from hf_push_observations import push_observations
from input.input_observation import InputObservation
from utils.grid_maker import gridder
from utils.metadata_handler import metadata2md

# get a global var for logger accessor in this module
LOG_LEVEL = logging.DEBUG
g_logger = logging.getLogger(__name__)
g_logger.setLevel(LOG_LEVEL)


def init_classifier_session_states() -> None:
    """Initialise the session state variables used in classification.

    Creates ``classify_whale_done`` and ``whale_prediction1`` as empty dicts
    when they are not yet present; existing values are left untouched.
    """
    if "classify_whale_done" not in st.session_state:
        st.session_state.classify_whale_done = {}

    if "whale_prediction1" not in st.session_state:
        st.session_state.whale_prediction1 = {}


def add_classifier_header() -> None:
    """Add brief explainer text about cetacean classification to the tab."""
    st.markdown("""
                *Run classifier to identify the species of cetacean on the uploaded image.
                Once inference is complete, the top three predictions are shown.
                You can override the prediction by selecting a species from the dropdown.*""")


def cetacean_just_classify(cetacean_classifier) -> None:
    """Infer cetacean species for all observations in the session state.

    - this function runs the classifier, and stores results in the session
      state.
    - the predictions are stored in the observation object, which is retained
      in st.session_state.observations
    - to display results use cetacean_show_results() or
      cetacean_show_results_and_review()

    Args:
        cetacean_classifier: callable invoked as ``cetacean_classifier(image)``;
            its result is indexed with ``['predictions']`` (saving-willy model
            from the Saving Willy Hugging Face space).
    """
    images = st.session_state.images

    for img_hash in st.session_state.image_hashes:
        # run classifier model on the image, and persistently store the output
        out = cetacean_classifier(images[img_hash])
        predictions = out['predictions']

        st.session_state.whale_prediction1[img_hash] = predictions[0]
        st.session_state.classify_whale_done[img_hash] = True
        # hand the observation its own copy of the prediction list
        st.session_state.observations[img_hash].set_top_predictions(predictions[:])

        msg = f"[D]2 classify_whale_done for {img_hash}: {st.session_state.classify_whale_done[img_hash]}, whale_prediction1: {st.session_state.whale_prediction1[img_hash]}"
        g_logger.info(msg)

        if st.session_state.MODE_DEV_STATEFUL:
            st.write(f"*[D] Observation {img_hash} classified as {st.session_state.whale_prediction1[img_hash]}*")


def _render_top_predictions(top_predictions, obs_number: int) -> None:
    """Render the heading and one viewer entry per stored prediction."""
    n = len(top_predictions)
    st.markdown(f"**Top {n} Predictions for observation {str(obs_number)}**")
    # viewer.display_whale takes the full list plus an index, so iterate by index
    for i in range(n):
        viewer.display_whale(top_predictions, i)


def cetacean_show_results_and_review() -> None:
    """Present classification results and let the user override the prediction.

    - for each observation in the session state, displays the image,
      summarised metadata, and the stored top predictions.
    - allows user to override the prediction by selecting a species from the
      dropdown.
    - the selected species is stored in the observation object, which is
      retained in st.session_state.observations; its dict form is written to
      st.session_state.public_observations under the image hash.
    """
    images = st.session_state.images
    observations = st.session_state.observations
    hashes = st.session_state.image_hashes
    _batch_size, row_size, _page = gridder(hashes)

    grid = st.columns(row_size)

    for o, img_hash in enumerate(hashes, start=1):
        _observation: InputObservation = observations[img_hash]

        # fill the grid left to right, wrapping after row_size entries
        with grid[(o - 1) % row_size]:
            st.image(images[img_hash], use_column_width=True)

            # dropdown for selecting/overriding the species prediction.
            # .get(): a hash that has not been classified yet has no entry in
            # classify_whale_done, which must not raise KeyError here.
            if not st.session_state.classify_whale_done.get(img_hash, False):
                # explicit key: the label is identical for every pending
                # observation, so widgets would otherwise share an ID.
                st.sidebar.selectbox("Species", viewer.WHALE_CLASSES,
                                     index=None, placeholder="Species not yet identified...",
                                     disabled=True, key=f"species_pending_{img_hash}")
            else:
                pred1 = st.session_state.whale_prediction1[img_hash]
                # get index of pred1 from WHALE_CLASSES, none if not present
                print(f"[D] {o:3} pred1: {pred1:30} | {img_hash}")
                ix = viewer.WHALE_CLASSES.index(pred1) if pred1 in viewer.WHALE_CLASSES else None
                selected_class = st.selectbox(f"Species for observation {str(o)}", viewer.WHALE_CLASSES, index=ix)
                # NOTE(review): assumed to apply only once a prediction exists;
                # whether it was overridden is tracked inside InputObservation.
                _observation.set_selected_class(selected_class)

            # store the elements of the observation that will be transmitted (not image)
            observation = _observation.to_dict()
            st.session_state.public_observations[img_hash] = observation

            # TODO: the metadata only fills properly if `validate` was clicked.
            st.markdown(metadata2md(img_hash, debug=True))

            msg = f"[D] full observation after inference: {observation}"
            g_logger.debug(msg)
            print(msg)
            # TODO: add a link to more info on the model, next to the button.

            _render_top_predictions(_observation.top_predictions, o)


def cetacean_show_results() -> None:
    """Present classification results that may be pushed to the online dataset.

    - for each observation in the session state, displays the image,
      summarised metadata, and the stored top predictions. No widgets for
      changing the selected species are shown here.
    """
    images = st.session_state.images
    observations = st.session_state.observations
    hashes = st.session_state.image_hashes
    _batch_size, row_size, _page = gridder(hashes)

    grid = st.columns(row_size)

    for o, img_hash in enumerate(hashes, start=1):
        observation = observations[img_hash].to_dict()

        # fill the grid left to right, wrapping after row_size entries
        with grid[(o - 1) % row_size]:
            st.image(images[img_hash], use_column_width=True)

            # NOTE(review): treated as live code, matching the docstring's
            # "summarised metadata" — confirm against the original layout.
            st.markdown(metadata2md(img_hash, debug=True))

            msg = f"[D] full observation after inference: {observation}"
            g_logger.debug(msg)
            print(msg)
            # TODO: add a link to more info on the model, next to the button.

            _render_top_predictions(observations[img_hash].top_predictions, o)


def cetacean_classify_show_and_review(cetacean_classifier):
    """Deprecated all-in-one classify / display / review step.

    Always raises. Use cetacean_just_classify() followed by
    cetacean_show_results_and_review() or cetacean_show_results() instead.

    Args:
        cetacean_classifier: unused; kept so existing call sites still match
            the signature.

    Raises:
        DeprecationWarning: on every call.
    """
    raise DeprecationWarning("This function is deprecated. Use individual steps instead")