kovacsvi committed on
Commit
83612bb
·
1 Parent(s): 7d03fc5

normalize probs (top5)

Browse files
Files changed (1) hide show
  1. interfaces/cap_minor_media.py +6 -3
interfaces/cap_minor_media.py CHANGED
@@ -39,8 +39,11 @@ for code in CAP_MIN_CODES:
39
  major_to_minor_map[major_id].append(code)
40
  major_to_minor_map = dict(major_to_minor_map)
41
 
42
- def normalize_probs(probs: dict):
43
- values = np.array(list(probs.values()))
 
 
 
44
  exp_values = np.exp(values)
45
  sum_exp = np.sum(exp_values)
46
  return {k: float(v) for k, v in zip(probs.keys(), exp_values / sum_exp)}
@@ -95,7 +98,7 @@ def predict(text, major_model_id, minor_model_id, tokenizer_id, HF_TOKEN=None):
95
  i: float(major_probs_np[i])
96
  for i in np.argsort(major_probs_np)[::-1]
97
  }
98
- filtered_probs = normalize_probs(filtered_probs)
99
 
100
  output_pred = {
101
  f"[{major_index_to_id[k]}] {CAP_MEDIA_LABEL_NAMES[major_index_to_id[k]]}": v
 
39
  major_to_minor_map[major_id].append(code)
40
  major_to_minor_map = dict(major_to_minor_map)
41
 
42
def normalize_probs(probs: dict, n: int) -> dict:
    """Keep at most the ``n`` highest-scoring entries and softmax-normalise them.

    Args:
        probs: Mapping from label key to a numeric score.
        n: Maximum number of entries to keep. Values <= 0 keep nothing.

    Returns:
        A new dict with the retained keys mapped to ``exp(score)`` divided by
        the sum of ``exp(score)`` over the retained entries, so the values sum
        to 1. If ``probs`` has more than ``n`` entries the result is ordered by
        descending score; otherwise the input order is preserved. An empty
        selection returns ``{}``.
    """
    items = list(probs.items())
    if len(items) > n:
        # ``list.sort`` works in place and returns None, so use ``sorted`` and
        # sort (key, value) pairs to keep each key attached to its score.
        items = sorted(items, key=lambda kv: kv[1], reverse=True)[:max(n, 0)]
    if not items:
        return {}

    keys = [key for key, _ in items]
    values = np.array([value for _, value in items], dtype=float)
    exp_values = np.exp(values)
    sum_exp = np.sum(exp_values)
    return {k: float(v) for k, v in zip(keys, exp_values / sum_exp)}
 
98
  i: float(major_probs_np[i])
99
  for i in np.argsort(major_probs_np)[::-1]
100
  }
101
+ filtered_probs = normalize_probs(filtered_probs, n=5)
102
 
103
  output_pred = {
104
  f"[{major_index_to_id[k]}] {CAP_MEDIA_LABEL_NAMES[major_index_to_id[k]]}": v