from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
import streamlit as st
@st.cache_resource
def prepare_model():
    """
    Load the tokenizer and the sequence-classification model.

    ``st.cache_resource`` (rather than ``st.cache_data``) is used because the
    returned objects are global resources: ``st.cache_data`` pickles the
    return value and hands back a fresh copy on every call, which is wasteful
    for a model and is not what Streamlit recommends for ML models.

    Returns:
        tuple: ``(tokenizer, model)`` loaded from the
        ``oracat/bert-paper-classifier`` checkpoint.
    """
    checkpoint = "oracat/bert-paper-classifier"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
    return (tokenizer, model)
def process(text):
    """
    Classify ``text`` and return the most relevant categories.

    The module-level ``model`` and ``tokenizer`` are wrapped in a
    text-classification pipeline that returns the top 3 labels. The labels
    are then filtered by ``_select_categories``.

    Args:
        text: Text to classify (in this app: title and abstract joined by a
            newline).

    Returns:
        list[dict]: Entries with ``"label"`` and ``"score"`` keys, sorted by
        descending score. The best category is always kept when the pipeline
        returns any prediction.
    """
    pipe = pipeline("text-classification", model=model, tokenizer=tokenizer, top_k=3)
    # truncation=True: inputs longer than the model's maximum sequence length
    # (long abstracts) would otherwise fail inside the model.
    predictions = pipe(text, truncation=True)[0]
    return _select_categories(predictions)


def _select_categories(predictions, cum_threshold=0.95, max_ratio=10):
    """
    Keep the most likely categories from a list of pipeline predictions.

    Categories are visited in descending score order. Visiting stops when
    either condition holds:
      * the category's score is more than ``max_ratio`` times smaller than
        the last kept one (it is dropped);
      * the cumulative score of the kept categories reaches ``cum_threshold``
        (the category that crosses the threshold is kept).

    The ratio test runs before the category is added. This drops a
    negligible category even when it would push the total past the
    threshold.

    Args:
        predictions: Iterable of dicts with ``"label"`` and ``"score"`` keys.
        cum_threshold: Cumulative score at which selection stops.
        max_ratio: Maximum allowed ratio between consecutive kept scores.

    Returns:
        list[dict]: The kept predictions, highest score first (empty only if
        ``predictions`` is empty).
    """
    ranked = sorted(predictions, key=lambda item: item["score"], reverse=True)
    selected = []
    cum_score = 0.0
    for item in ranked:
        # Compare by multiplication so a score of exactly 0.0 cannot raise
        # ZeroDivisionError.
        if selected and selected[-1]["score"] > max_ratio * item["score"]:
            break
        selected.append(item)
        cum_score += item["score"]
        if cum_score >= cum_threshold:
            break
    return selected
tokenizer, model = prepare_model()

# State management
#
# The app's state is the title and the abstract (plus an "output" entry).
# Session state is used so that the demo buttons can pre-fill the input
# fields with example values.
for state_key in ("title", "abstract", "output"):
    if state_key not in st.session_state:
        st.session_state[state_key] = ""
# Simple streamlit interface
st.markdown("### Biomedical paper classifier")
# NOTE(review): this argument was a string literal split across two lines,
# which is a SyntaxError. Its intended (presumably HTML) content appears to
# be missing. A plain newline keeps the call valid; restore the intended
# markup if it is known, or drop the call.
st.markdown("\n", unsafe_allow_html=True)
## Demo buttons and their callbacks
def demo_immunology_callback():
    """
    Fill the title and abstract fields with an immunology example.

    Example paper: https://www.biorxiv.org/content/10.1101/2022.12.01.518788v1
    """
    example = {
        "title": "Using TCR and BCR sequencing to unravel the role of T and B cells in abdominal aortic aneurysm",
        "abstract": "Recent evidence suggests that AAA displays characteristics of an autoimmune disease and it gained increasing prominence that specific antigen-driven T cells in the aortic tissue may contribute to the initial immune response. We found no clonal expansion of TCRs or BCRs in elastase-induced AAA in mice.",
    }
    # Insertion order: title first, then abstract.
    for field, value in example.items():
        st.session_state[field] = value
def demo_virology_callback():
    """
    Fill the title and abstract fields with a virology example.

    Example paper: https://doi.org/10.4269/ajtmh.20-0849
    """
    title, abstract = (
        "The Origin of COVID-19 and Why It Matters",
        "The COVID-19 pandemic is among the deadliest infectious diseases to have emerged in recent history. As with all past pandemics, the specific mechanism of its emergence in humans remains unknown. Nevertheless, a large body of virologic, epidemiologic, veterinary, and ecologic data establishes that the new virus, SARS-CoV-2, evolved directly or indirectly from a β-coronavirus in the sarbecovirus (SARS-like virus) group that naturally infect bats and pangolins in Asia and Southeast Asia. Scientists have warned for decades that such sarbecoviruses are poised to emerge again and again, identified risk factors, and argued for enhanced pandemic prevention and control efforts. Unfortunately, few such preventive actions were taken resulting in the latest coronavirus emergence detected in late 2019 which quickly spread pandemically. The risk of similar coronavirus outbreaks in the future remains high. In addition to controlling the COVID-19 pandemic, we must undertake vigorous scientific, public health, and societal actions, including significantly increased funding for basic and applied research addressing disease emergence, to prevent this tragic history from repeating itself.",
    )
    st.session_state["title"] = title
    st.session_state["abstract"] = abstract
def demo_microbiology_callback():
    """
    Fill the title and abstract fields with a microbiology example.

    Example paper: https://doi.org/10.1016/j.cell.2023.01.002
    """
    assignments = (
        ("title", "Bacterial droplet-based single-cell RNA-seq reveals antibiotic-associated heterogeneous cellular states"),
        ("abstract", "We introduce BacDrop, a highly scalable technology for bacterial single-cell RNA sequencing that has overcome many challenges hindering the development of scRNA-seq in bacteria. BacDrop can be applied to thousands to millions of cells from both gram-negative and gram-positive species. It features universal ribosomal RNA depletion and combinatorial barcodes that enable multiplexing and massively parallel sequencing. We applied BacDrop to study Klebsiella pneumoniae clinical isolates and to elucidate their heterogeneous responses to antibiotic stress. In an unperturbed population presumed to be homogeneous, we found within-population heterogeneity largely driven by the expression of mobile genetic elements that promote the evolution of antibiotic resistance. Under antibiotic perturbation, BacDrop revealed transcriptionally distinct subpopulations associated with different phenotypic outcomes including antibiotic persistence. BacDrop thus can capture cellular states that cannot be detected by bulk RNA-seq, which will unlock new microbiological insights into bacterial responses to perturbations and larger bacterial communities such as the microbiome."),
    )
    for field, value in assignments:
        st.session_state[field] = value
def clear_callback():
    """
    Reset the title, abstract and output session-state entries to empty
    strings.
    """
    for field in ("title", "abstract", "output"):
        st.session_state[field] = ""
# NOTE(review): this renders an empty string; any HTML/CSS it was meant to
# inject is not present here.
st.markdown(
    "",
    unsafe_allow_html=True,
)

# One demo/clear button per equal-width column, left to right.
col1, col2, col3, col4 = st.columns([1, 1, 1, 1])
button_specs = (
    (col1, "Demo: immunology", demo_immunology_callback),
    (col2, "Demo: microbiology", demo_microbiology_callback),
    (col3, "Demo: virology", demo_virology_callback),
    (col4, "Clear fields", clear_callback),
)
for column, label, callback in button_specs:
    with column:
        st.button(label, on_click=callback)

## Input fields
placeholder = st.empty()  # NOTE(review): not referenced anywhere in this section.
# The widget keys bind the fields to st.session_state, so the callbacks above
# can pre-fill or clear them.
title = st.text_input("Enter the title:", key="title")
abstract = st.text_area(
    "... and maybe the abstract of the paper you want to classify:", key="abstract"
)
text = f"{title}\n{abstract}"

## Output
if text.strip():
    results = process(text)
    if not results:
        out_text = ""
    else:
        top, *rest = results
        out_text = (
            f"This paper is likely to be from the category **{top['label']}** "
            f"*(score {top['score']:.2f})*."
        )
        if rest:
            others = " and ".join(
                f"{item['label']} *(score {item['score']:.2f})*" for item in rest
            )
            out_text += f"\n(Other fitting categories are {others}.)"
    st.markdown(out_text)