oracat committed on
Commit
847199e
·
1 Parent(s): 012c2a6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -3
app.py CHANGED
@@ -18,9 +18,20 @@ def process(text):
18
  """
19
  Translate incoming text to tokens and classify it
20
  """
21
- pipe = pipeline("text-classification", model=model, tokenizer=tokenizer)
22
  result = pipe(text)[0]
23
- return result["label"]
 
 
 
 
 
 
 
 
 
 
 
24
 
25
 
26
  tokenizer, model = prepare_model()
@@ -113,4 +124,13 @@ text = "\n".join([title, abstract])
113
  ## Output
114
 
115
  if len(text.strip()) > 0:
116
- st.markdown(f"This paper is likely to be from the category **{process(text)}**.")
 
 
 
 
 
 
 
 
 
 
18
  """
19
  Translate incoming text to tokens and classify it
20
  """
21
+ pipe = pipeline("text-classification", model=model, tokenizer=tokenizer, top_k=3)
22
  result = pipe(text)[0]
23
+
24
+ result = sorted(result, key=lambda x: -x["score"])
25
+
26
+ cum_score = 0
27
+ for i, item in enumerate(result):
28
+ cum_score += item["score"]
29
+ if cum_score >= 0.95:
30
+ break
31
+
32
+ result = result[: (i + 1)]
33
+
34
+ return result
35
 
36
 
37
  tokenizer, model = prepare_model()
 
124
  ## Output
125
 
126
  if len(text.strip()) > 0:
127
+ results = process(text)
128
+ if len(results) == 0:
129
+ out_text = ""
130
+ else:
131
+ out_text = f"This paper is likely to be from the category **{results[0]['label']}** *(score {results[0]['score']:.2f})*."
132
+ if len(results) > 1:
133
+ out_text += "\n(Alternative categories are " + " and ".join(
134
+ [f"{item['label']} *(score {item['score']:.2f})*"]
135
+ )
136
+ st.markdown(out_text)