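# Streamlit demo: arXiv article topic inference from a title or abstract.
# A fine-tuned bert-base-uncased classifier predicts one of six subject
# categories (cs.NE, cs.CL, cs.AI, stat.ML, cs.CV, cs.LG).
# Test application for the ML-2 class, YSDA-2022.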
import streamlit as st
import torch
from torch import nn
from transformers import BertModel, AutoTokenizer
from time import time
import matplotlib.pyplot as plt
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = 'cpu'
from PIL import Image
# dicts for encoding / decoding labels
labels = {'cs.NE': 0, 'cs.CL': 1, 'cs.AI': 2, 'stat.ML': 3, 'cs.CV': 4, 'cs.LG': 5}
labels_decoder = {'cs.NE': 'Neural and Evolutionary Computing', 'cs.CL': 'Computation and Language',
                  'cs.AI': 'Artificial Intelligence', 'stat.ML': 'Machine Learning (stat)',
                  'cs.CV': 'Computer Vision', 'cs.LG': 'Machine Learning'}
model_name = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_name)
class BertClassifier(nn.Module):
    # BERT encoder followed by dropout and a single linear classification head
    def __init__(self, n_classes, dropout=0.5, model_name='bert-base-uncased'):
        super(BertClassifier, self).__init__()
        self.bert = BertModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(768, n_classes)
        self.relu = nn.ReLU()

    def forward(self, input_id, mask):
        # pooled_output is the pooled [CLS] representation, shape (batch, 768)
        _, pooled_output = self.bert(input_ids=input_id, attention_mask=mask, return_dict=False)
        dropout_output = self.dropout(pooled_output)
        linear_output = self.linear(dropout_output)
        final_layer = self.relu(linear_output)
        return final_layer
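# Shape walkthrough for a single input (assuming bert-base-uncased, hidden size 768):
#   input_id, mask: (1, 512)  ->  pooled_output: (1, 768)  ->  logits after ReLU: (1, n_classes)
# Illustrative only (names below are for demonstration, not part of the app):
#   clf = BertClassifier(n_classes=len(labels))
#   enc = tokenizer("sample abstract", padding='max_length', max_length=512,
#                   truncation=True, return_tensors="pt")
#   scores = clf(enc['input_ids'], enc['attention_mask'])   # shape (1, 6)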
@st.cache(suppress_st_warning=True, allow_output_mutation=True)  # allow_output_mutation: skip re-hashing the returned torch model on every rerun
def build_model():
    model = BertClassifier(n_classes=len(labels))
    st.markdown("Model created")
    # load fine-tuned weights on CPU
    model.load_state_dict(torch.load('model_weights_1.pt', map_location=torch.device('cpu')))
    model.eval()
    st.markdown("Model weights loaded")
    return model
def inference(txt):
    # infer class scores for the text topic with the loaded fine-tuned model
    t2 = tokenizer(txt.lower().replace('\n', ''),
                   padding='max_length', max_length=512, truncation=True,
                   return_tensors="pt")
    inp2 = t2['input_ids'].to(device)          # shape (1, 512)
    mask2 = t2['attention_mask'].to(device)    # shape (1, 512); no extra unsqueeze needed
    out = model(inp2, mask2)
    out = out.cpu().detach().numpy().reshape(-1)
    # normalize the non-negative scores so they sum to 100 (percentages)
    out = out / out.sum() * 100
    res = [(l, o) for l, o in zip(list(labels.keys()), out.tolist())]
    return res
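# Illustrative call (scores below are made-up values, shown only to document the return format):
#   inference("We train a convolutional network for image segmentation ...")
#   -> [('cs.NE', 1.2), ('cs.CL', 0.8), ('cs.AI', 7.5),
#       ('stat.ML', 10.1), ('cs.CV', 62.3), ('cs.LG', 18.1)]   # percentages, sum ~100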
def infer_and_display_result(txt):
    start_time = time()
    res = inference(txt)
    res.sort(key=lambda x: -x[1])   # descending by score

    st.markdown("### Inference results:")
    for lbl, score in res:
        if score >= 1:
            st.write(f"[ {lbl:<7}] {labels_decoder[lbl]:<35} {score:.1f}%")

    res_plot = []   # keep only the classes covering the top ~95% of the score mass
    total = 0
    for r in res:
        if total < 95:
            res_plot.append(r)
            total += r[1]
        else:
            break
    res_plot.sort(key=lambda x: x[1])   # ascending, so the largest bar is drawn at the top
    fig, ax = plt.subplots(figsize=(10, len(res_plot)))
    for r in res_plot:
        ax.barh(r[0], r[1])
    st.pyplot(fig)

    st.markdown(f"cycle time = {time() - start_time:.2f} s.")
# ======================================
model = build_model()
st.title('Big-data cloud application for scientific article topic inference using in-memory computing and stuff.')
image = Image.open('dilbert_big_data.jpg')
st.image(image)
st.write('test application for ML-2 class, YSDA-2022' )
# st.markdown("<img width=200px src='https://i.pinimg.com/736x/11/33/19/113319f0ffe91f4bb0f468914b9916da.jpg'>", unsafe_allow_html=True)
# st.markdown("###Predict topic by abstract.")
text = st.text_area("ENTER ARTICLE TITLE OR ABSTRACT HERE")
action = st.button('click here to infer topic')
if action:
    infer_and_display_result(text)

# the second button needs its own label/key, otherwise Streamlit raises a DuplicateWidgetID error
action2 = st.button('click here to show the text in uppercase', key='uppercase_button')
if action2:
    st.write(text.upper())
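# To run the app locally (assuming this file is saved as app.py):
#   streamlit run app.py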