rmm commited on
Commit
7a5f0ca
·
1 Parent(s): 00bdefd

fix: classification_done and prediction1 now ok for image batches

Browse files
src/classifier/classifier_image.py CHANGED
@@ -33,25 +33,25 @@ def cetacean_classify(cetacean_classifier):
33
  observation = observations[hash].to_dict()
34
  # run classifier model on `image`, and persistently store the output
35
  out = cetacean_classifier(image) # get top 3 matches
36
- st.session_state.whale_prediction1 = out['predictions'][0]
37
- st.session_state.classify_whale_done = True
38
- msg = f"[D]2 classify_whale_done: {st.session_state.classify_whale_done}, whale_prediction1: {st.session_state.whale_prediction1}"
39
  g_logger.info(msg)
40
 
41
  # dropdown for selecting/overriding the species prediction
42
- if not st.session_state.classify_whale_done:
43
  selected_class = st.sidebar.selectbox("Species", viewer.WHALE_CLASSES,
44
  index=None, placeholder="Species not yet identified...",
45
  disabled=True)
46
  else:
47
- pred1 = st.session_state.whale_prediction1
48
  # get index of pred1 from WHALE_CLASSES, none if not present
49
  print(f"[D] pred1: {pred1}")
50
  ix = viewer.WHALE_CLASSES.index(pred1) if pred1 in viewer.WHALE_CLASSES else None
51
  selected_class = st.selectbox(f"Species for observation {str(o)}", viewer.WHALE_CLASSES, index=ix)
52
 
53
  observation['predicted_class'] = selected_class
54
- if selected_class != st.session_state.whale_prediction1:
55
  observation['class_overriden'] = selected_class
56
 
57
  st.session_state.public_observation = observation
 
33
  observation = observations[hash].to_dict()
34
  # run classifier model on `image`, and persistently store the output
35
  out = cetacean_classifier(image) # get top 3 matches
36
+ st.session_state.whale_prediction1[hash] = out['predictions'][0]
37
+ st.session_state.classify_whale_done[hash] = True
38
+ msg = f"[D]2 classify_whale_done for {hash}: {st.session_state.classify_whale_done[hash]}, whale_prediction1: {st.session_state.whale_prediction1[hash]}"
39
  g_logger.info(msg)
40
 
41
  # dropdown for selecting/overriding the species prediction
42
+ if not st.session_state.classify_whale_done[hash]:
43
  selected_class = st.sidebar.selectbox("Species", viewer.WHALE_CLASSES,
44
  index=None, placeholder="Species not yet identified...",
45
  disabled=True)
46
  else:
47
+ pred1 = st.session_state.whale_prediction1[hash]
48
  # get index of pred1 from WHALE_CLASSES, none if not present
49
  print(f"[D] pred1: {pred1}")
50
  ix = viewer.WHALE_CLASSES.index(pred1) if pred1 in viewer.WHALE_CLASSES else None
51
  selected_class = st.selectbox(f"Species for observation {str(o)}", viewer.WHALE_CLASSES, index=ix)
52
 
53
  observation['predicted_class'] = selected_class
54
+ if selected_class != st.session_state.whale_prediction1[hash]:
55
  observation['class_overriden'] = selected_class
56
 
57
  st.session_state.public_observation = observation
src/main.py CHANGED
@@ -67,10 +67,10 @@ if "public_observation" not in st.session_state:
67
  st.session_state.public_observation = {}
68
 
69
  if "classify_whale_done" not in st.session_state:
70
- st.session_state.classify_whale_done = False
71
 
72
  if "whale_prediction1" not in st.session_state:
73
- st.session_state.whale_prediction1 = None
74
 
75
  if "tab_log" not in st.session_state:
76
  st.session_state.tab_log = None
 
67
  st.session_state.public_observation = {}
68
 
69
  if "classify_whale_done" not in st.session_state:
70
+ st.session_state.classify_whale_done = {}
71
 
72
  if "whale_prediction1" not in st.session_state:
73
+ st.session_state.whale_prediction1 = {}
74
 
75
  if "tab_log" not in st.session_state:
76
  st.session_state.tab_log = None