Spaces:
Running
Running
kovacsvi
commited on
Commit
·
83612bb
1
Parent(s):
7d03fc5
normalize probs (top5)
Browse files
interfaces/cap_minor_media.py
CHANGED
@@ -39,8 +39,11 @@ for code in CAP_MIN_CODES:
|
|
39 |
major_to_minor_map[major_id].append(code)
|
40 |
major_to_minor_map = dict(major_to_minor_map)
|
41 |
|
42 |
-
def normalize_probs(probs: dict):
|
43 |
-
|
|
|
|
|
|
|
44 |
exp_values = np.exp(values)
|
45 |
sum_exp = np.sum(exp_values)
|
46 |
return {k: float(v) for k, v in zip(probs.keys(), exp_values / sum_exp)}
|
@@ -95,7 +98,7 @@ def predict(text, major_model_id, minor_model_id, tokenizer_id, HF_TOKEN=None):
|
|
95 |
i: float(major_probs_np[i])
|
96 |
for i in np.argsort(major_probs_np)[::-1]
|
97 |
}
|
98 |
-
filtered_probs = normalize_probs(filtered_probs)
|
99 |
|
100 |
output_pred = {
|
101 |
f"[{major_index_to_id[k]}] {CAP_MEDIA_LABEL_NAMES[major_index_to_id[k]]}": v
|
|
|
39 |
major_to_minor_map[major_id].append(code)
|
40 |
major_to_minor_map = dict(major_to_minor_map)
|
41 |
|
42 |
+
def normalize_probs(probs: dict, n: int):
|
43 |
+
probs = list(probs.values())
|
44 |
+
if len(probs) > n:
|
45 |
+
probs = probs.sort(reverse=True)[:5]
|
46 |
+
values = np.array(probs)
|
47 |
exp_values = np.exp(values)
|
48 |
sum_exp = np.sum(exp_values)
|
49 |
return {k: float(v) for k, v in zip(probs.keys(), exp_values / sum_exp)}
|
|
|
98 |
i: float(major_probs_np[i])
|
99 |
for i in np.argsort(major_probs_np)[::-1]
|
100 |
}
|
101 |
+
filtered_probs = normalize_probs(filtered_probs, n=5)
|
102 |
|
103 |
output_pred = {
|
104 |
f"[{major_index_to_id[k]}] {CAP_MEDIA_LABEL_NAMES[major_index_to_id[k]]}": v
|