Spaces:
Running
Running
change design
Browse files- interfaces/cap_minor.py +10 -20
interfaces/cap_minor.py
CHANGED
@@ -32,24 +32,15 @@ domains = {
|
|
32 |
"local government agenda": "localgovernment"
|
33 |
}
|
34 |
|
35 |
-
def convert_minor_to_major(
|
36 |
-
|
37 |
-
for i in results:
|
38 |
-
prob = probs[i]
|
39 |
-
major_code = str(CAP_MIN_NUM_DICT[i])[:-2]
|
40 |
-
|
41 |
-
if major_code == "99":
|
42 |
-
major_code = "999"
|
43 |
-
|
44 |
-
label = CAP_LABEL_NAMES[int(major_code)]
|
45 |
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
results_as_text[key] = probs[i]
|
51 |
|
52 |
-
return
|
53 |
|
54 |
|
55 |
def check_huggingface_path(checkpoint_path: str):
|
@@ -79,10 +70,9 @@ def predict(text, model_id, tokenizer_id):
|
|
79 |
logits = model(**inputs).logits
|
80 |
|
81 |
probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().flatten()
|
82 |
-
|
83 |
-
output_pred_major = convert_minor_to_major(np.argsort(probs)[::-1], probs)
|
84 |
output_info = f'<p style="text-align: center; display: block">Prediction was made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.</p>'
|
85 |
-
return
|
86 |
|
87 |
def predict_cap(text, language, domain):
|
88 |
domain = domains[domain]
|
@@ -101,4 +91,4 @@ demo = gr.Interface(
|
|
101 |
inputs=[gr.Textbox(lines=6, label="Input"),
|
102 |
gr.Dropdown(languages, label="Language"),
|
103 |
gr.Dropdown(domains.keys(), label="Domain")],
|
104 |
-
outputs=[gr.Label(num_top_classes=5, label="Output
|
|
|
32 |
"local government agenda": "localgovernment"
|
33 |
}
|
34 |
|
35 |
+
def convert_minor_to_major(minor_topic):
|
36 |
+
major_code = str(minor_topic)[:-2]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
|
38 |
+
if major_code == "99":
|
39 |
+
major_code = "999"
|
40 |
+
|
41 |
+
label = CAP_LABEL_NAMES[int(major_code)]
|
|
|
42 |
|
43 |
+
return label
|
44 |
|
45 |
|
46 |
def check_huggingface_path(checkpoint_path: str):
|
|
|
70 |
logits = model(**inputs).logits
|
71 |
|
72 |
probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().flatten()
|
73 |
+
output_pred = {f"[{CAP_MIN_NUM_DICT[i]}] {convert_minor_to_major(CAP_MIN_NUM_DICT[i])} - {CAP_MIN_LABEL_NAMES[CAP_MIN_NUM_DICT[i]]}": probs[i] for i in np.argsort(probs)[::-1]}
|
|
|
74 |
output_info = f'<p style="text-align: center; display: block">Prediction was made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.</p>'
|
75 |
+
return output_pred, output_info
|
76 |
|
77 |
def predict_cap(text, language, domain):
|
78 |
domain = domains[domain]
|
|
|
91 |
inputs=[gr.Textbox(lines=6, label="Input"),
|
92 |
gr.Dropdown(languages, label="Language"),
|
93 |
gr.Dropdown(domains.keys(), label="Domain")],
|
94 |
+
outputs=[gr.Label(num_top_classes=5, label="Output"), gr.Markdown()])
|