vladyur commited on
Commit
96207b3
·
1 Parent(s): 15a4b18

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -0
app.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+ import transformers
4
+ import torch
5
+ import tokenizers
6
+ import streamlit as st
7
+
8
+ NUM_LABELS = 15
9
+
10
+ labels_names = {
11
+ 0: 'Astrophysics',
12
+ 1: 'Condensed Matter',
13
+ 2: 'Computer Science',
14
+ 3: 'Economics',
15
+ 4: 'Electrical Engineering and Systems Science',
16
+ 5: 'General Relativity and Quantum Cosmology',
17
+ 6: 'High Energy Physics',
18
+ 7: 'Mathematics',
19
+ 8: 'Nonlinear Sciences',
20
+ 9: 'Nuclear Theory',
21
+ 10: 'General Physics',
22
+ 11: 'Quantitative Biology',
23
+ 12: 'Quantitative Finance',
24
+ 13: 'Quantum Physics',
25
+ 14: 'Statistics',
26
+ }
27
+
28
+ @st.cache(hash_funcs={tokenizers.Tokenizer: lambda _: None}, suppress_st_warning=True)
29
+ def get_model(model_name, model_path):
30
+ tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
31
+ model = transformers.AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=NUM_LABELS)
32
+ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
33
+ model.eval()
34
+ return model, tokenizer
35
+
36
+ @st.cache(hash_funcs={tokenizers.Tokenizer: lambda _: None}, suppress_st_warning=True)
37
+ def predict(text, tokenizer, model, temperature = 1):
38
+ tokens = tokenizer.encode(text)
39
+ with torch.no_grad():
40
+ logits = model.cpu()(torch.as_tensor([tokens]))[0]
41
+ probs = torch.softmax(logits[-1, :] / temperature, dim=-1).data.cpu().numpy()
42
+
43
+ indexes_descending = np.argsort(probs)[::-1]
44
+ percents = 0
45
+ preds = []
46
+ pred_probs = []
47
+ for index in indexes_descending:
48
+ preds.append(labels_names[index])
49
+ pred_prob = 100 * probs[index]
50
+ pred_probs.append(f"{pred_prob:.1f}%")
51
+
52
+ percents += pred_prob
53
+ if percents >= 95:
54
+ break
55
+
56
+ result = pd.DataFrame({'Probability': pred_probs})
57
+ result.index = preds
58
+
59
+ return result
60
+
61
+ model, tokenizer = get_model('distilbert-base-cased', 'distilbert-checkpoint-10983.bin')
62
+
63
+ st.title("Yandex School of Data Analysis. ML course, laboratory work 2: classifier of categories of scientific papers")
64
+ st.markdown("<img width=500px src='https://m.media-amazon.com/images/I/71XOMSKx8NL._AC_SL1500_.jpg'>", unsafe_allow_html=True)
65
+ st.markdown("Enter the title of the article and its abstract (although, if you really don't want to, you can do with just the title)")
66
+
67
+ title = st.text_area(label='Title of the article', height=200)
68
+ abstract = st.text_area(label='Abstract of the article', height=400)
69
+ button = st.button('Go')
70
+
71
+ if button:
72
+ try:
73
+ text = ' [ABSTRACT] '.join(title, abstract)
74
+ result = predict(text, tokenizer, model)
75
+ if len(text) > 10:
76
+ st.subheader('I think, this paper related to')
77
+ st.write(result)
78
+ else:
79
+ st.error("Enter some more info please")
80
+ except Exception:
81
+ st.error("Ooooops, something went wrong. Try again please and report to me, tg: @vladyur")