Spaces:
Sleeping
Upload app.py
app.py
ADDED
@@ -0,0 +1,157 @@
import streamlit as st
import time

# model part

import json
import torch
from torch import nn
from transformers import AutoTokenizer, AutoModelForSequenceClassification

with open('categories_with_names.json', 'r') as f:
    cat_with_names = json.load(f)
with open('categories_from_model.json', 'r') as f:
    categories_from_model = json.load(f)

@st.cache_resource
def load_models_and_tokenizer():
    tokenizer = AutoTokenizer.from_pretrained("oracat/bert-paper-classifier-arxiv")

    model_titles = AutoModelForSequenceClassification.from_pretrained(
        "powerful_model_titles/checkpoint-13472", num_labels=len(categories_from_model), problem_type="multi_label_classification"
    )
    model_titles.eval()
    model_abstracts = AutoModelForSequenceClassification.from_pretrained(
        "powerful_model_abstracts/checkpoint-13472", num_labels=len(categories_from_model), problem_type="multi_label_classification"
    )
    model_abstracts.eval()

    return model_titles, model_abstracts, tokenizer

model_titles, model_abstracts, tokenizer = load_models_and_tokenizer()

+
def categorize_text(title: str | None = None, abstract: str | None = None, progress_bar = None):
|
34 |
+
if title is None and abstract is None:
|
35 |
+
raise ValueError('title is None and abstract is None')
|
36 |
+
|
37 |
+
models_to_run = 2 if (title is not None and abstract is not None) else 1
|
38 |
+
|
39 |
+
proba_title = None
|
40 |
+
if title is not None:
|
41 |
+
progresses = (10, 30) if models_to_run == 2 else (20, 60)
|
42 |
+
my_bar.progress(progresses[0], text='computing titles')
|
43 |
+
input_tok = tokenizer(title, return_tensors='pt')
|
44 |
+
with torch.no_grad():
|
45 |
+
logits = model_titles(**input_tok)['logits']
|
46 |
+
proba_title = torch.sigmoid(logits)[0]
|
47 |
+
my_bar.progress(progresses[1], text='computed titles')
|
48 |
+
|
49 |
+
proba_abstract = None
|
50 |
+
if abstract is not None:
|
51 |
+
progresses = (40, 70) if models_to_run == 2 else (20, 60)
|
52 |
+
my_bar.progress(progresses[0], text='computing abstracts')
|
53 |
+
input_tok = tokenizer(abstract, return_tensors='pt')
|
54 |
+
with torch.no_grad():
|
55 |
+
logits = model_abstracts(**input_tok)['logits']
|
56 |
+
proba_abstract = torch.sigmoid(logits)[0]
|
57 |
+
my_bar.progress(progresses[0], text='computed abstracts')
|
58 |
+
|
59 |
+
if title is None:
|
60 |
+
proba = proba_abstract
|
61 |
+
elif abstract is None:
|
62 |
+
proba = proba_title
|
63 |
+
else:
|
64 |
+
proba = proba_title * 0.1 + proba_abstract * 0.9
|
65 |
+
|
66 |
+
progresses = (80, 90) if models_to_run == 2 else (70, 90)
|
67 |
+
|
68 |
+
my_bar.progress(progresses[0], text='computed proba')
|
69 |
+
|
70 |
+
sorted_proba, indices = torch.sort(proba, descending=True)
|
71 |
+
my_bar.progress(progresses[1], text='sorted proba')
|
72 |
+
to_take = 1
|
73 |
+
while sorted_proba[:to_take].sum() < 0.95 and to_take < len(categories_from_model):
|
74 |
+
to_take += 1
|
75 |
+
output = [(cat_with_names[categories_from_model[index]], proba[index].item())
|
76 |
+
for index in indices[:to_take]]
|
77 |
+
my_bar.progress(100, text='generated output')
|
78 |
+
return output
|
79 |
+
|
# front part

st.markdown("<h1 style='text-align: center;'>Classify your paper!</h1>", unsafe_allow_html=True)


# defaults for the submitted texts, the input-widget keys and the last "classify by" selection
if "title" not in st.session_state:
    st.session_state.title = ""
if "abstract" not in st.session_state:
    st.session_state.abstract = ""
if "title_input_key" not in st.session_state:
    st.session_state.title_input_key = ""
if "abstract_input_key" not in st.session_state:
    st.session_state.abstract_input_key = ""
if "model_type" not in st.session_state:
    st.session_state.model_type = []

def input_error():
    # return an error message for invalid input, or an empty string if the input is fine
    if not st.session_state.model_type:
        return 'You have to select title or abstract'
    if 'Title' in st.session_state.model_type and not st.session_state.title:
        return 'Title is empty'
    if 'Abstract' in st.session_state.model_type and not st.session_state.abstract:
        return 'Abstract is empty'
    return ''


def clear_input():
    # store the submitted texts, then empty the input widgets if the input is valid
    st.session_state.title = title
    st.session_state.abstract = abstract
    if not input_error():
        if "Title" in st.session_state.model_type:
            st.session_state.title_input_key = ""
        if "Abstract" in st.session_state.model_type:
            st.session_state.abstract_input_key = ""


title = st.text_input(r"$\textsf{\Large Title}$", key="title_input_key")

abstract = st.text_input(r"$\textsf{\Large Abstract}$", key="abstract_input_key")


model_type = st.multiselect(
    r"$\textsf{\large Classify by:}$",
    ['Title', 'Abstract'],
)

st.session_state.model_type = model_type

if st.button('Submit', on_click=clear_input):
    if input_error():
        st.error(input_error())
    else:
        send_time = time.localtime(time.time())
        #st.success(f"Submitted {(' and '.join(st.session_state.model_type)).lower()} on {time.strftime('%d.%m %H:%M:%S', send_time)}")
        model_input = dict()
        if 'Title' in st.session_state.model_type:
            model_input['title'] = st.session_state.title
        if 'Abstract' in st.session_state.model_type:
            model_input['abstract'] = st.session_state.abstract
        #st.success(f'{model_input=}')
        my_bar = st.progress(0, text='starting model')
        model_result = categorize_text(**model_input, progress_bar=my_bar)
        st.markdown("<h1 style='text-align: center;'>Classification completed!</h1>", unsafe_allow_html=True)
        # show the top category large, the rest smaller, and collect categories below 10% into one line
        small_categories = []
        cat, proba = model_result[0]
        st.write(r"$\textsf{\Large " + f'{cat}: {round(100*proba)}' + r"\%}$")
        for cat, proba in model_result[1:]:
            if proba < 0.1:
                small_categories.append(f'{cat}: {round(100*proba, 1)}' + r"\%")
            else:
                st.write(r"$\textsf{\large " + f'{cat}: {round(100*proba)}' + r"\%}$")
        if small_categories:
            st.write(', '.join(small_categories))
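
For reference, app.py expects two small JSON label files next to it: categories_from_model.json (the list of arXiv category codes in the order of the models' output logits) and categories_with_names.json (a mapping from category code to a human-readable name). Neither file is part of this commit, so the sketch below only illustrates the assumed structure, with made-up example values:

# illustrative only: the real category codes and names are not included in this commit
import json

with open("categories_from_model.json", "w") as f:
    json.dump(["cs.LG", "cs.CL", "cs.CV"], f)
with open("categories_with_names.json", "w") as f:
    json.dump(
        {"cs.LG": "Machine Learning", "cs.CL": "Computation and Language", "cs.CV": "Computer Vision and Pattern Recognition"},
        f,
    )

With these files and the two fine-tuned checkpoints available under powerful_model_titles/ and powerful_model_abstracts/, the app can also be tried locally with streamlit run app.py, assuming streamlit, torch and transformers are installed.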