MariaUDmitrieva commited on
Commit
a929695
·
verified ·
1 Parent(s): 5c327f6

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +157 -0
app.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import time
3
+
4
+ # model part
5
+
6
+ import json
7
+ import torch
8
+ from torch import nn
9
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
10
+
11
+ with open('categories_with_names.json', 'r') as f:
12
+ cat_with_names = json.load(f)
13
+ with open('categories_from_model.json', 'r') as f:
14
+ categories_from_model = json.load(f)
15
+
16
+ @st.cache_resource
17
+ def load_models_and_tokenizer():
18
+ tokenizer = AutoTokenizer.from_pretrained("oracat/bert-paper-classifier-arxiv")
19
+
20
+ model_titles = AutoModelForSequenceClassification.from_pretrained(
21
+ "powerful_model_titles/checkpoint-13472", num_labels=len(categories_from_model), problem_type="multi_label_classification"
22
+ )
23
+ model_titles.eval()
24
+ model_abstracts = AutoModelForSequenceClassification.from_pretrained(
25
+ "powerful_model_abstracts/checkpoint-13472", num_labels=len(categories_from_model), problem_type="multi_label_classification"
26
+ )
27
+ model_abstracts.eval()
28
+
29
+ return model_titles, model_abstracts, tokenizer
30
+
31
+ model_titles, model_abstracts, tokenizer = load_models_and_tokenizer()
32
+
33
+ def categorize_text(title: str | None = None, abstract: str | None = None, progress_bar = None):
34
+ if title is None and abstract is None:
35
+ raise ValueError('title is None and abstract is None')
36
+
37
+ models_to_run = 2 if (title is not None and abstract is not None) else 1
38
+
39
+ proba_title = None
40
+ if title is not None:
41
+ progresses = (10, 30) if models_to_run == 2 else (20, 60)
42
+ my_bar.progress(progresses[0], text='computing titles')
43
+ input_tok = tokenizer(title, return_tensors='pt')
44
+ with torch.no_grad():
45
+ logits = model_titles(**input_tok)['logits']
46
+ proba_title = torch.sigmoid(logits)[0]
47
+ my_bar.progress(progresses[1], text='computed titles')
48
+
49
+ proba_abstract = None
50
+ if abstract is not None:
51
+ progresses = (40, 70) if models_to_run == 2 else (20, 60)
52
+ my_bar.progress(progresses[0], text='computing abstracts')
53
+ input_tok = tokenizer(abstract, return_tensors='pt')
54
+ with torch.no_grad():
55
+ logits = model_abstracts(**input_tok)['logits']
56
+ proba_abstract = torch.sigmoid(logits)[0]
57
+ my_bar.progress(progresses[0], text='computed abstracts')
58
+
59
+ if title is None:
60
+ proba = proba_abstract
61
+ elif abstract is None:
62
+ proba = proba_title
63
+ else:
64
+ proba = proba_title * 0.1 + proba_abstract * 0.9
65
+
66
+ progresses = (80, 90) if models_to_run == 2 else (70, 90)
67
+
68
+ my_bar.progress(progresses[0], text='computed proba')
69
+
70
+ sorted_proba, indices = torch.sort(proba, descending=True)
71
+ my_bar.progress(progresses[1], text='sorted proba')
72
+ to_take = 1
73
+ while sorted_proba[:to_take].sum() < 0.95 and to_take < len(categories_from_model):
74
+ to_take += 1
75
+ output = [(cat_with_names[categories_from_model[index]], proba[index].item())
76
+ for index in indices[:to_take]]
77
+ my_bar.progress(100, text='generated output')
78
+ return output
79
+
80
+ # front part
81
+
82
+ st.markdown("<h1 style='text-align: center;'>Classify your paper!</h1>", unsafe_allow_html=True)
83
+
84
+
85
+ if "title" not in st.session_state:
86
+ st.session_state.title = ""
87
+ if "abstract" not in st.session_state:
88
+ st.session_state.abstract = ""
89
+ if "title_input_key" not in st.session_state:
90
+ st.session_state.title_input_key = ""
91
+ if "abstract_input_key" not in st.session_state:
92
+ st.session_state.abstract_input_key = ""
93
+ if "model_type" not in st.session_state:
94
+ st.session_state.model_type = []
95
+
96
+ def input_error():
97
+ if not st.session_state.model_type:
98
+ return 'you have to select title or abstract'
99
+ if 'Title' in model_type and not st.session_state.title:
100
+ return 'Title is empty'
101
+ if 'Abstract' in model_type and not st.session_state.abstract:
102
+ return 'Abstract is empty'
103
+ return ''
104
+
105
+
106
+ def clear_input():
107
+ st.session_state.title = title.title()
108
+ st.session_state.abstract = abstract.title()
109
+ if not input_error():
110
+ if "Title" in st.session_state.model_type:
111
+ st.session_state.title_input_key = ""
112
+ if "Abstract" in st.session_state.model_type:
113
+ st.session_state.abstract_input_key = ""
114
+
115
+
116
+ title = st.text_input(r"$\textsf{\Large Title}$", key="title_input_key")
117
+
118
+ abstract = st.text_input(r"$\textsf{\Large Abstract}$", key="abstract_input_key")
119
+
120
+
121
+ model_type = st.multiselect(
122
+ r"$\textsf{\large Classify by:}$",
123
+ ['Title', 'Abstract'],
124
+ )
125
+
126
+ st.session_state.model_type = model_type
127
+
128
+ if(st.button('Submit', on_click=clear_input)):
129
+ if input_error():
130
+ st.error(input_error())
131
+ else:
132
+ send_time = time.localtime(time.time())
133
+ #st.success(f"Submitted {(' and '.join(st.session_state.model_type)).lower()} on {time.strftime('%d.%m %H:%M:%S', send_time)}")
134
+ model_input = dict()
135
+ if 'Title' in st.session_state.model_type:
136
+ model_input['title'] = st.session_state.title
137
+ if 'Abstract' in st.session_state.model_type:
138
+ model_input['abstract'] = st.session_state.abstract
139
+ #st.success(f'{model_input=}')
140
+ my_bar = st.progress(0, text='starting model')
141
+ model_result = categorize_text(**model_input, progress_bar=my_bar)
142
+ st.markdown("<h1 style='text-align: center;'>Classification completed!</h1>", unsafe_allow_html=True)
143
+ small_categories = []
144
+ cat, proba = model_result[0]
145
+ st.write(r"$\textsf{\Large " + f'{cat}: {round(100*proba)}' + r"\%}$")
146
+ for cat, proba in model_result[1:]:
147
+ if proba < 0.1:
148
+ small_categories.append(f'{cat}: {round(100*proba, 1)}' + r"\%")
149
+ else:
150
+ st.write(r"$\textsf{\large " + f'{cat}: {round(100*proba)}' + r"\%}$")
151
+ if small_categories:
152
+ st.write(', '.join(small_categories))
153
+
154
+
155
+
156
+
157
+