poltextlab committed
Commit c259974 · verified · 1 Parent(s): 6257e98

change design

Files changed (1):
  1. interfaces/cap_minor.py  +10 -20
interfaces/cap_minor.py CHANGED
@@ -32,24 +32,15 @@ domains = {
     "local government agenda": "localgovernment"
 }
 
-def convert_minor_to_major(results, probs):
-    results_as_text = dict()
-    for i in results:
-        prob = probs[i]
-        major_code = str(CAP_MIN_NUM_DICT[i])[:-2]
-
-        if major_code == "99":
-            major_code = "999"
-
-        label = CAP_LABEL_NAMES[int(major_code)]
+def convert_minor_to_major(minor_topic):
+    major_code = str(minor_topic)[:-2]
 
-        key = f"[{major_code}] {label}"
-        if key in results_as_text:
-            results_as_text[key] += probs[i]
-        else:
-            results_as_text[key] = probs[i]
+    if major_code == "99":
+        major_code = "999"
+
+    label = CAP_LABEL_NAMES[int(major_code)]
 
-    return results_as_text
+    return label
 
 
 def check_huggingface_path(checkpoint_path: str):
@@ -79,10 +70,9 @@ def predict(text, model_id, tokenizer_id):
     logits = model(**inputs).logits
 
     probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().flatten()
-    output_pred_minor = {f"[{CAP_MIN_NUM_DICT[i]}] {CAP_MIN_LABEL_NAMES[CAP_MIN_NUM_DICT[i]]}": probs[i] for i in np.argsort(probs)[::-1]}
-    output_pred_major = convert_minor_to_major(np.argsort(probs)[::-1], probs)
+    output_pred = {f"[{CAP_MIN_NUM_DICT[i]}] {convert_minor_to_major(CAP_MIN_NUM_DICT[i])} - {CAP_MIN_LABEL_NAMES[CAP_MIN_NUM_DICT[i]]}": probs[i] for i in np.argsort(probs)[::-1]}
     output_info = f'<p style="text-align: center; display: block">Prediction was made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.</p>'
-    return output_pred_minor, output_pred_major, output_info
+    return output_pred, output_info
 
 def predict_cap(text, language, domain):
     domain = domains[domain]
@@ -101,4 +91,4 @@ demo = gr.Interface(
     inputs=[gr.Textbox(lines=6, label="Input"),
             gr.Dropdown(languages, label="Language"),
             gr.Dropdown(domains.keys(), label="Domain")],
-    outputs=[gr.Label(num_top_classes=5, label="Output minor"), gr.Label(num_top_classes=5, label="Output major"), gr.Markdown()])
+    outputs=[gr.Label(num_top_classes=5, label="Output"), gr.Markdown()])
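For context, below is a minimal self-contained sketch of the logic this commit introduces: a CAP minor-topic code is mapped to its major topic by dropping its last two digits (the 99xx codes collapse to 999), and predict() now returns a single dictionary whose keys combine the minor code, the major label and the minor label, rendered by one gr.Label component. The lookup tables used here (CAP_LABEL_NAMES, CAP_MIN_NUM_DICT, CAP_MIN_LABEL_NAMES) are tiny illustrative stand-ins, not the actual tables defined in the module.

import numpy as np

# Toy stand-ins for the module's lookup tables; values are illustrative only.
CAP_LABEL_NAMES = {2: "Civil Rights", 20: "Government Operations", 999: "Other"}  # major code -> major label
CAP_MIN_NUM_DICT = {0: 201, 1: 2002, 2: 9999}                                     # model output index -> minor code
CAP_MIN_LABEL_NAMES = {201: "Minority Discrimination",                            # minor code -> minor label
                       2002: "Bureaucratic Oversight",
                       9999: "Other"}

def convert_minor_to_major(minor_topic):
    # A minor code is its major code followed by two digits (e.g. 2002 -> 20);
    # the 99xx "other" codes are collapsed into the 999 bucket.
    major_code = str(minor_topic)[:-2]
    if major_code == "99":
        major_code = "999"
    return CAP_LABEL_NAMES[int(major_code)]

# A softmax vector as predict() would produce, one probability per class index.
probs = np.array([0.2, 0.7, 0.1])

# The new single output: "[minor code] Major label - Minor label" -> probability,
# sorted by descending probability; a single gr.Label(num_top_classes=5) renders it.
output_pred = {
    f"[{CAP_MIN_NUM_DICT[i]}] {convert_minor_to_major(CAP_MIN_NUM_DICT[i])} - {CAP_MIN_LABEL_NAMES[CAP_MIN_NUM_DICT[i]]}": float(probs[i])
    for i in np.argsort(probs)[::-1]
}
print(output_pred)
# {'[2002] Government Operations - Bureaucratic Oversight': 0.7,
#  '[201] Civil Rights - Minority Discrimination': 0.2,
#  '[9999] Other - Other': 0.1}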