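# NOTE: the six lines below are page residue from the hosting site's file
# viewer, not part of the program; kept as comments so the file parses.
# vladyur's picture
# upgrade from distilbert to bert
# 9001366
# raw
# history blame
# 2.92 kB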
import numpy as np
import pandas as pd
import transformers
import torch
import tokenizers
import streamlit as st
# Maps the integer class id to its human-readable category name.
# The id of each category is its position in the tuple below, so the order
# of the entries must not be changed.
labels_names = dict(enumerate((
    'Astrophysics',
    'Condensed Matter',
    'Computer Science',
    'Economics',
    'Electrical Engineering and Systems Science',
    'General Relativity and Quantum Cosmology',
    'High Energy Physics',
    'Mathematics',
    'Nonlinear Sciences',
    'Nuclear Theory',
    'General Physics',
    'Quantitative Biology',
    'Quantitative Finance',
    'Quantum Physics',
    'Statistics',
)))

# Number of target classes (15); derived from the mapping so the two
# cannot drift apart.
NUM_LABELS = len(labels_names)
@st.cache(hash_funcs={tokenizers.Tokenizer: lambda _: None}, suppress_st_warning=True)
def get_model(model_name, model_path):
    """Load the tokenizer and the fine-tuned classifier for CPU inference.

    Args:
        model_name: Name passed to ``from_pretrained`` for both the tokenizer
            and the sequence-classification model.
        model_path: Path to a checkpoint file holding the state dict that is
            loaded on top of the pretrained model.

    Returns:
        A ``(model, tokenizer)`` tuple; the model is switched to eval mode.
    """
    text_tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

    # The classification head is sized to NUM_LABELS before the checkpoint
    # weights are loaded into it.
    classifier = transformers.AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=NUM_LABELS,
    )

    # NOTE(review): torch.load unpickles the file - only use trusted checkpoints.
    cpu_device = torch.device('cpu')
    checkpoint_weights = torch.load(model_path, map_location=cpu_device)
    classifier.load_state_dict(checkpoint_weights)

    # Inference only: put the module into evaluation mode.
    classifier.eval()
    return classifier, text_tokenizer
@st.cache(hash_funcs={tokenizers.Tokenizer: lambda _: None}, suppress_st_warning=True)
def predict(text, tokenizer, model, temperature=1):
    """Classify ``text`` and return the most probable categories.

    Args:
        text: Input string to classify.
        tokenizer: Tokenizer whose ``encode`` turns ``text`` into token ids.
        model: Sequence-classification model; it is moved to CPU and called
            on a single-item batch.
        temperature: Divisor applied to the logits before the softmax.

    Returns:
        A ``pandas.DataFrame`` indexed by category name with one
        ``'Probability'`` column of strings formatted like ``"87.3%"``.
        Rows are ordered from most to least probable and stop as soon as
        the cumulative probability reaches 95%.
    """
    # Truncate to the tokenizer's maximum length: without it, a long abstract
    # yields more tokens than the model accepts and the forward pass raises.
    tokens = tokenizer.encode(text, truncation=True)

    with torch.no_grad():
        logits = model.cpu()(torch.as_tensor([tokens]))[0]
    # The batch holds a single sequence, so the last row is the only row.
    probs = torch.softmax(logits[-1, :] / temperature, dim=-1).detach().cpu().numpy()

    indexes_descending = np.argsort(probs)[::-1]

    percents = 0
    preds = []
    pred_probs = []
    for index in indexes_descending:
        preds.append(labels_names[index])
        pred_prob = 100 * probs[index]
        pred_probs.append(f"{pred_prob:.1f}%")
        percents += pred_prob
        # Enough probability mass is covered; the remaining tail is omitted.
        if percents >= 95:
            break

    return pd.DataFrame({'Probability': pred_probs}, index=preds)
# Load the classifier and tokenizer once at script start.
model, tokenizer = get_model('bert-base-cased', 'bert-checkpoint-14644.bin')

# --- Page layout -----------------------------------------------------------
st.title("Yandex School of Data Analysis. ML course")
st.title("Laboratory work 2: classifier of categories of scientific papers")
st.markdown("<img width=200px src='https://m.media-amazon.com/images/I/71XOMSKx8NL._AC_SL1500_.jpg'>", unsafe_allow_html=True)
st.markdown("\n")
st.markdown("Enter the title of the article and its abstract (although, if you really don't want to, you can do with just the title)")

title = st.text_area(label='Title of the article', height=100)
abstract = st.text_area(label='Abstract of the article', height=200)

button = st.button('Go')
if button:
    # Validate what the user actually typed. The joined text always contains
    # the 12-character ' [ABSTRACT] ' separator, so checking its length
    # could never reject an empty submission.
    entered_chars = len(title.strip()) + len(abstract.strip())
    if entered_chars > 10:
        try:
            text = ' [ABSTRACT] '.join([title, abstract])
            result = predict(text, tokenizer, model)
            st.subheader('Bumblebee thinks, this paper related to')
            st.write(result)
        except Exception:
            # Top-level UI boundary: show a friendly message instead of a traceback.
            st.error("Ooooops, something went wrong. Try again please and report to me, tg: @vladyur")
    else:
        # Too little input: skip the model call entirely.
        st.error("Enter some more info please")