kovacsvi commited on
Commit
7d912ec
·
1 Parent(s): a1eb2dd

normalize major topic predictions

Browse files
Files changed (1) hide show
  1. interfaces/cap_minor_media.py +3 -2
interfaces/cap_minor_media.py CHANGED
@@ -98,9 +98,10 @@ def predict(text, major_model_id, minor_model_id, tokenizer_id, HF_TOKEN=None):
98
  top_major_id = major_index_to_id[top_major_index]
99
 
100
  # Default: show major topic predictions
 
101
  output_pred = {
102
- f"[{major_index_to_id[i]}] {CAP_MEDIA_LABEL_NAMES[major_index_to_id[i]]}": float(major_probs_np[i])
103
- for i in np.argsort(major_probs_np)[::-1]
104
  }
105
 
106
  # If eligible for minor prediction
 
98
  top_major_id = major_index_to_id[top_major_index]
99
 
100
  # Default: show major topic predictions
101
+ filtered_probs = normalize_probs(np.argsort(major_probs_np)[::-1])
102
  output_pred = {
103
+ f"[{major_index_to_id[k]}] {CAP_MEDIA_LABEL_NAMES[filtered_probs[k]]}": v
104
+ for k, v in sorted(filtered_probs.items(), key=lambda item: item[1], reverse=True)
105
  }
106
 
107
  # If eligible for minor prediction