File size: 3,164 Bytes
889109a
b339b00
 
 
2ffe758
1f7ef74
b339b00
 
239f32e
889109a
b339b00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35bd6d3
2ffe758
 
 
 
 
 
 
 
b339b00
 
 
2ffe758
b339b00
 
 
 
 
 
 
 
 
 
2ffe758
35bd6d3
 
 
239f32e
 
 
 
35bd6d3
 
239f32e
35bd6d3
 
 
 
239f32e
 
2ffe758
 
35bd6d3
2ffe758
b7fbb48
1f7ef74
 
 
b7fbb48
1f7ef74
 
 
b7fbb48
 
 
 
 
1f7ef74
2ffe758
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import streamlit as st 
import torch
from torch import nn
from transformers import BertModel, AutoTokenizer, AutoModel, pipeline
from time import time
import matplotlib.pyplot as plt
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = 'cpu'
from PIL import Image

# Lookup tables for encoding / decoding labels, both keyed by category code
# (e.g. 'cs.CL'):
#   labels         -> integer class index
#   labels_decoder -> human-readable category title
# Both are derived from one ordered listing; an entry's position is its index.
_CATEGORY_TITLES = (
    ('cs.NE', 'Neural and Evolutionary Computing'),
    ('cs.CL', 'Computation and Language'),
    ('cs.AI', 'Artificial Intelligence'),
    ('stat.ML', 'Machine Learning (stat)'),
    ('cs.CV', 'Computer Vision'),
    ('cs.LG', 'Machine Learning'),
)
labels = {code: index for index, (code, _title) in enumerate(_CATEGORY_TITLES)}
labels_decoder = dict(_CATEGORY_TITLES)
 
# Name of the pretrained checkpoint the tokenizer is loaded from.
model_name = 'bert-base-uncased'
# Module-level tokenizer; created each time this script is executed.
tokenizer = AutoTokenizer.from_pretrained(model_name)

class BertClassifier(nn.Module):
    """BERT encoder followed by dropout, a linear layer and a ReLU.

    Args:
        n_classes: Number of scores produced per input sequence.
        dropout: Dropout probability applied to the pooled encoder output.
        model_name: Checkpoint name passed to ``BertModel.from_pretrained``.
    """

    def __init__(self, n_classes, dropout=0.5, model_name='bert-base-uncased'):
        super().__init__()
        self.bert = BertModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout)
        # Size the head from the loaded encoder instead of hard-coding 768, so
        # a ``model_name`` with a different hidden width also works.
        self.linear = nn.Linear(self.bert.config.hidden_size, n_classes)
        self.relu = nn.ReLU()

    def forward(self, input_id, mask):
        """Return non-negative per-class scores for a batch.

        Args:
            input_id: Token ids, forwarded to the encoder as ``input_ids``.
            mask: Attention mask, forwarded as ``attention_mask``.

        Returns:
            The ReLU of the linear head applied to the pooled encoder output.
        """
        # With return_dict=False the encoder returns a tuple; only its second
        # element (the pooled output) is used.
        _, pooled_output = self.bert(
            input_ids=input_id, attention_mask=mask, return_dict=False
        )
        dropout_output = self.dropout(pooled_output)
        linear_output = self.linear(dropout_output)
        # ReLU clamps negative scores to zero, so every output is >= 0.
        return self.relu(linear_output)

# ``st.cache`` is deprecated in current Streamlit releases; ``st.cache_resource``
# is its replacement for unserialisable objects such as models. Fall back to
# the legacy decorator only where the newer API does not exist.
if hasattr(st, "cache_resource"):
    _cache_model = st.cache_resource
else:
    # allow_output_mutation=True stops the legacy cache from hashing the whole
    # returned model on every rerun to check for mutation.
    _cache_model = st.cache(allow_output_mutation=True, suppress_st_warning=True)


@_cache_model
def build_model():
    """Create the classifier, load its trained weights and return it.

    Progress messages are written to the Streamlit page. The result is cached
    so the model is not rebuilt on every script rerun.

    Returns:
        A ``BertClassifier`` in eval mode, placed on the module-level ``device``.
    """
    model = BertClassifier(n_classes=len(labels))
    st.markdown("Model created")
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state_dict = torch.load('model_weights_1.pt', map_location=device)
    model.load_state_dict(state_dict)
    # Keep the model on the same device the inputs are moved to.
    model.to(device)
    model.eval()
    st.markdown("Model weights loaded")
    return model
    
def inference(txt, mode=None):
    """Score ``txt`` against every topic label using the module-level model.

    Args:
        txt: Text to classify. It is lower-cased and newline characters are
            removed before tokenization.
        mode: Unused; kept so existing callers that pass it keep working.

    Returns:
        A list of ``(label, percent)`` tuples, one per entry in ``labels``,
        ordered by class index. The percentages sum to 100, or are all ``0.0``
        when the model produces no positive score.
    """
    encoded = tokenizer(
        txt.lower().replace('\n', ''),
        padding='max_length', max_length=512, truncation=True,
        return_tensors="pt",
    )

    # The tokenizer output is already batched, so the mask is passed as-is and
    # has the same shape as the input ids (no extra unsqueeze).
    input_ids = encoded['input_ids'].to(device)
    attention_mask = encoded['attention_mask'].to(device)

    # Inference only: do not record operations for autograd.
    with torch.no_grad():
        scores = model(input_ids, attention_mask)
    scores = scores.cpu().numpy().reshape(-1)

    # The classifier ends in a ReLU, so every score can be zero; dividing by
    # a zero sum would turn all percentages into NaN.
    total = scores.sum()
    if total > 0:
        percents = (scores / total * 100).tolist()
    else:
        percents = [0.0] * len(scores)

    # Pair scores with labels by class index rather than by dict order.
    ordered_labels = sorted(labels, key=labels.get)
    return list(zip(ordered_labels, percents))
    
model = build_model()

# Markdown needs a space after "###" to render the text as a heading.
st.markdown("### Predict topic by abstract.")
image = Image.open('dilbert_big_data.jpg')
st.image(image)

text = st.text_area("ENTER TEXT HERE")

# The script body runs even while the text area is still empty; skip the
# forward pass until there is something to classify.
if text.strip():
    start_time = time()

    # Highest score first; ties keep their original (class index) order.
    res = inference(text, mode=None)
    res.sort(key=lambda item: item[1], reverse=True)

    st.markdown("INFERENCE RESULT")
    for lbl, score in res:
        # Hide labels that received less than 1%.
        if score >= 1:
            st.markdown(f"[        {lbl:<7}] {labels_decoder[lbl]:<35}  {score:.1f}%")

    # Plot the top labels until their cumulative share reaches 95%.
    res_plot = []
    total = 0
    for item in res:
        if total >= 95:
            break
        res_plot.append(item)
        total += item[1]

    fig, ax = plt.subplots(figsize=(10, len(res_plot) + 1))
    for lbl, score in res_plot:
        ax.barh(lbl, score)
    st.pyplot(fig)
    # pyplot keeps figures alive until they are closed; close this one so
    # repeated reruns do not accumulate figures in memory.
    plt.close(fig)

    st.markdown(f"cycle time = {time() - start_time:.2f} s.")