MariaUDmitrieva committed on
Commit
36ce20c
·
verified ·
1 Parent(s): 1e009cb

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -157
app.py DELETED
@@ -1,157 +0,0 @@
1
- import streamlit as st
2
- import time
3
-
4
- # model part
5
-
6
- import json
7
- import torch
8
- from torch import nn
9
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
10
-
11
# Lookup tables shipped next to the app.
# `categories_from_model` is indexed by model output position and its length
# is used as the number of labels; `cat_with_names` is keyed by the entries
# of `categories_from_model`.
# JSON is read as UTF-8 explicitly so the result does not depend on the
# host's locale default encoding.
with open('categories_with_names.json', 'r', encoding='utf-8') as f:
    cat_with_names = json.load(f)
with open('categories_from_model.json', 'r', encoding='utf-8') as f:
    categories_from_model = json.load(f)
15
-
16
@st.cache_resource
def load_models_and_tokenizer():
    """Load the tokenizer and the two fine-tuned classifiers.

    Both classifiers are configured for multi-label classification with one
    label per entry of ``categories_from_model`` and are switched to eval
    mode before being returned.

    Returns:
        A ``(model_titles, model_abstracts, tokenizer)`` tuple.
    """
    shared_tokenizer = AutoTokenizer.from_pretrained("oracat/bert-paper-classifier-arxiv")

    # Identical head configuration for both checkpoints.
    classifier_kwargs = {
        "num_labels": len(categories_from_model),
        "problem_type": "multi_label_classification",
    }

    classifiers = []
    for checkpoint_dir in (
        "powerful_model_titles/checkpoint-13472",
        "powerful_model_abstracts/checkpoint-13472",
    ):
        classifier = AutoModelForSequenceClassification.from_pretrained(
            checkpoint_dir, **classifier_kwargs
        )
        classifier.eval()
        classifiers.append(classifier)

    titles_classifier, abstracts_classifier = classifiers
    return titles_classifier, abstracts_classifier, shared_tokenizer


model_titles, model_abstracts, tokenizer = load_models_and_tokenizer()
32
-
33
def categorize_text(title: str | None = None, abstract: str | None = None, progress_bar=None):
    """Score a paper against every category and return the leading ones.

    Args:
        title: Paper title, or None to skip the title model.
        abstract: Paper abstract, or None to skip the abstract model.
        progress_bar: Optional object exposing ``progress(percent, text=...)``
            (for example the value returned by ``st.progress``). When None,
            no progress is reported.

    Returns:
        A list of ``(category_name, probability)`` tuples in descending
        probability order. It is the shortest prefix whose probabilities sum
        to at least 0.95, or every category if that sum is never reached.

    Raises:
        ValueError: If both ``title`` and ``abstract`` are None.
    """
    if title is None and abstract is None:
        raise ValueError('title is None and abstract is None')

    def report(percent, text):
        # Report through the caller-supplied bar only; the function must not
        # depend on a module-level progress widget existing.
        if progress_bar is not None:
            progress_bar.progress(percent, text=text)

    def predict_proba(model, text):
        # truncation=True asks the tokenizer to cap the input at its
        # configured maximum length rather than passing an over-long
        # sequence to the model.
        encoded = tokenizer(text, return_tensors='pt', truncation=True)
        with torch.no_grad():
            logits = model(**encoded)['logits']
        # Per-label sigmoid; [0] selects the single input row.
        return torch.sigmoid(logits)[0]

    both_models = title is not None and abstract is not None

    proba_title = None
    if title is not None:
        start, done = (10, 30) if both_models else (20, 60)
        report(start, 'computing titles')
        proba_title = predict_proba(model_titles, title)
        report(done, 'computed titles')

    proba_abstract = None
    if abstract is not None:
        start, done = (40, 70) if both_models else (20, 60)
        report(start, 'computing abstracts')
        proba_abstract = predict_proba(model_abstracts, abstract)
        # Report the completion value, not the start value again.
        report(done, 'computed abstracts')

    if title is None:
        proba = proba_abstract
    elif abstract is None:
        proba = proba_title
    else:
        # When both are available the abstract model dominates the blend.
        proba = proba_title * 0.1 + proba_abstract * 0.9

    report(80 if both_models else 70, 'computed proba')

    sorted_proba, indices = torch.sort(proba, descending=True)
    report(90, 'sorted proba')

    # Grow the prefix until it covers 0.95 of summed probability or runs out
    # of categories.
    to_take = 1
    while sorted_proba[:to_take].sum() < 0.95 and to_take < len(categories_from_model):
        to_take += 1

    output = [
        (cat_with_names[categories_from_model[index]], proba[index].item())
        for index in indices[:to_take]
    ]
    report(100, 'generated output')
    return output
79
-
80
- # front part
81
-
82
st.markdown("<h1 style='text-align: center;'>Classify your paper!</h1>", unsafe_allow_html=True)

# Give every session-state entry the app relies on an initial value, without
# overwriting anything that is already stored.
for state_key, initial_value in (
    ("title", ""),
    ("abstract", ""),
    ("title_input_key", ""),
    ("abstract_input_key", ""),
    ("model_type", []),
):
    if state_key not in st.session_state:
        st.session_state[state_key] = initial_value
95
-
96
def input_error():
    """Validate the stored selection and inputs.

    Reads only ``st.session_state`` (``model_type``, ``title``, ``abstract``)
    so the result does not depend on module-level variables being defined.

    Returns:
        An error message describing the first problem found, or ``''`` when
        the input is valid.
    """
    selected = st.session_state.model_type
    if not selected:
        return 'you have to select title or abstract'
    if 'Title' in selected and not st.session_state.title:
        return 'Title is empty'
    if 'Abstract' in selected and not st.session_state.abstract:
        return 'Abstract is empty'
    return ''
104
-
105
-
106
def clear_input():
    """Submit-button callback: store the typed text, then reset the inputs.

    Copies the current text-input values into ``st.session_state.title`` and
    ``st.session_state.abstract``. If ``input_error()`` reports no problem,
    the text inputs selected in ``st.session_state.model_type`` are emptied.
    """
    # Read the widgets through their session-state keys rather than through
    # module-level variables assigned during a script run.
    # NOTE(review): str.title() capitalises every word of the text (e.g.
    # "BERT" becomes "Bert"); kept as in the original, confirm it is intended.
    st.session_state.title = st.session_state.title_input_key.title()
    st.session_state.abstract = st.session_state.abstract_input_key.title()

    if input_error():
        # Leave the inputs untouched so the user can correct them.
        return

    if "Title" in st.session_state.model_type:
        st.session_state.title_input_key = ""
    if "Abstract" in st.session_state.model_type:
        st.session_state.abstract_input_key = ""
114
-
115
-
116
title = st.text_input(r"$\textsf{\Large Title}$", key="title_input_key")

abstract = st.text_input(r"$\textsf{\Large Abstract}$", key="abstract_input_key")

model_type = st.multiselect(
    r"$\textsf{\large Classify by:}$",
    ['Title', 'Abstract'],
)

# Mirror the selection into session state, where the validation and callback
# code reads it.
st.session_state.model_type = model_type

if st.button('Submit', on_click=clear_input):
    # Validate once and reuse the message.
    error_message = input_error()
    if error_message:
        st.error(error_message)
    else:
        # Only the selected fields are passed, so categorize_text runs just
        # the corresponding model(s).
        model_input = {}
        if 'Title' in st.session_state.model_type:
            model_input['title'] = st.session_state.title
        if 'Abstract' in st.session_state.model_type:
            model_input['abstract'] = st.session_state.abstract

        my_bar = st.progress(0, text='starting model')
        model_result = categorize_text(**model_input, progress_bar=my_bar)
        st.markdown("<h1 style='text-align: center;'>Classification completed!</h1>", unsafe_allow_html=True)

        # The top category is always shown prominently.
        cat, proba = model_result[0]
        st.write(r"$\textsf{\Large " + f'{cat}: {round(100*proba)}' + r"\%}$")

        # Remaining categories: those below 10% are collected and shown
        # together on one line; the rest get a line each.
        small_categories = []
        for cat, proba in model_result[1:]:
            if proba < 0.1:
                small_categories.append(f'{cat}: {round(100*proba, 1)}' + r"\%")
            else:
                st.write(r"$\textsf{\large " + f'{cat}: {round(100*proba)}' + r"\%}$")
        if small_categories:
            st.write(', '.join(small_categories))
153
-
154
-
155
-
156
-
157
-