Spaces:
Sleeping
Sleeping
import streamlit as st | |
from transformers import pipeline | |
# Caching: Streamlit re-executes this script on every widget interaction, so
# the pipeline is kept in st.cache_resource and built once instead of per rerun.
@st.cache_resource
def load_model():
    """Build the arXiv category text-classification pipeline.

    Returns:
        A transformers "text-classification" pipeline for the
        "voroninip/bert-paper-classifier-arxiv" checkpoint, created with
        top_k=None.
    """
    return pipeline(
        "text-classification",
        model="voroninip/bert-paper-classifier-arxiv",
        top_k=None,
    )


model = load_model()
def top_pct(preds, threshold=.95):
    """Keep the highest-scoring predictions that together reach ``threshold``.

    Args:
        preds: Iterable of dicts, each with a numeric ``"score"`` entry.
        threshold: Cumulative score at which to stop adding predictions.

    Returns:
        A new list sorted by descending score, cut after the first item that
        brings the running total to at least ``threshold``. If the total never
        reaches ``threshold`` every prediction is returned; an empty input
        yields an empty list.
    """
    ranked = sorted(preds, key=lambda pred: pred["score"], reverse=True)
    cum_score = 0
    for count, item in enumerate(ranked, start=1):
        cum_score += item["score"]
        if cum_score >= threshold:
            return ranked[:count]
    # Threshold never reached (or nothing to rank): return everything we have.
    return ranked
def format_predictions(preds) -> str:
    """Render predictions as a numbered, newline-terminated list for display.

    Each entry of ``preds`` must provide ``"label"`` and ``"score"``; the score
    is shown with two decimal places. Numbering starts at 1.
    """
    return "".join(
        f"{rank}. {pred['label']} (score {pred['score']:.2f})\n"
        for rank, pred in enumerate(preds, start=1)
    )
# Static page header, rendered top to bottom: background gradient, arXiv
# logo, page heading and the introductory paragraph with the taxonomy link.
page_css = """
<style>
.stApp {
background: linear-gradient(to bottom right, #020420, #080f6b);
}
</style>
"""

logo_html = """
<div style='text-align: center;'>
<img src='https://info.arxiv.org/brand/images/brand-logo-primary.jpg' alt='Centered Image' width='300'/>
</div>
"""

heading_html = """
<h2 style='text-align: center; color: lime; font-family: Arial;'>
🚀 arXiv paper categories predictor
</h2>
"""

intro_html = """
<p style='
color: white;
font-size: 20px;
font-family: "Courier New", monospace;
'>
Paste Title and Abstract of the paper and get most likely categories of the paper in the
<a href="https://arxiv.org/category_taxonomy" target="_blank" style="color: cyan; text-decoration: none;">
arXiv taxonomy
</a>
</p>
"""

# Raw HTML/CSS is only rendered when unsafe_allow_html is enabled.
for html_snippet in (page_css, logo_html, heading_html, intro_html):
    st.markdown(html_snippet, unsafe_allow_html=True)
# User inputs; both start empty.
# NOTE(review): abstracts are usually multi-line, so st.text_area may be a
# better widget for the second field. Kept as text_input to leave the UI as is.
title = st.text_input("Title", value="")
abstract = st.text_input("Abstract", value="")

st.markdown("""
<p style='
color: white;
font-size: 20px;
font-family: "Courier New", monospace;
'>
Most likely categories of the paper:
</p>
""", unsafe_allow_html=True)

# Run the classifier only when there is real text: whitespace-only input
# carries no signal and would just produce arbitrary categories.
if title.strip() or abstract.strip():
    query = title + '\n' + abstract
    # truncation=True is forwarded to the tokenizer so over-long text is cut
    # to the model's maximum input length instead of failing inside the model.
    predictions = model(query, truncation=True)[0]
    result = format_predictions(top_pct(predictions))
    st.write(result)