# Spaces: Sleeping  (status text left over from the page capture; not part of the program)
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline | |
import streamlit as st | |
def prepare_model():
    """
    Prepare the tokenizer and the model for classification.

    Streamlit re-executes the whole script on every widget interaction, so
    the loading step is wrapped in ``st.cache_resource`` when the installed
    Streamlit provides it. Later reruns then reuse the already loaded
    objects instead of calling ``from_pretrained`` again.

    Returns:
        tuple: ``(tokenizer, model)`` loaded from the
        ``oracat/bert-paper-classifier`` checkpoint.
    """

    def _load(checkpoint):
        # Both objects come from the same checkpoint name.
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
        return (tokenizer, model)

    # Older Streamlit releases have no ``cache_resource``; in that case fall
    # back to the plain (uncached) load so the app keeps working as before.
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is not None:
        # show_spinner=False keeps the page output identical to the
        # uncached version (no extra "Running ..." element).
        _load = cache_resource(show_spinner=False)(_load)

    return _load("oracat/bert-paper-classifier")
def process(text):
    """
    Translate incoming text to tokens and classify it.

    Builds a ``text-classification`` pipeline from the module-level
    ``model`` and ``tokenizer`` and runs it on ``text``.

    Args:
        text: The string to classify (the caller passes the title and the
            abstract joined by a newline).

    Returns:
        The ``"label"`` value of the first prediction returned by the
        pipeline.
    """
    pipe = pipeline("text-classification", model=model, tokenizer=tokenizer)
    # truncation=True asks the tokenizer to clip over-long input to the
    # model's maximum sequence length. Without it, a very long abstract is
    # expected to make the forward pass fail instead of yielding a label.
    result = pipe(text, truncation=True)[0]
    return result["label"]
# Load the objects that process() reads at module level
tokenizer, model = prepare_model()

# State management
#
# The app state is the title, the abstract and the output.
# It is kept in the session state so that the demo buttons
# can pre-fill the input fields.
for state_key in ("title", "abstract", "output"):
    # Only initialise keys that are missing, so existing values survive
    if state_key not in st.session_state:
        st.session_state[state_key] = ""

# Simple streamlit interface
st.markdown("### Hello, paper classifier!")

## Demo buttons and their callbacks
def demo_immunology_callback():
    """
    Pre-fill the title and abstract fields with the immunology demo paper
    (https://www.biorxiv.org/content/10.1101/2022.12.01.518788v1).
    """
    demo_fields = {
        "title": "Using TCR and BCR sequencing to unravel the role of T and B cells in abdominal aortic aneurysm",
        "abstract": "Recent evidence suggests that AAA displays characteristics of an autoimmune disease and it gained increasing prominence that specific antigen-driven T cells in the aortic tissue may contribute to the initial immune response. We found no clonal expansion of TCRs or BCRs in elastase-induced AAA in mice.",
    }
    # Written in the same order as before: title first, then abstract
    for field_name, field_value in demo_fields.items():
        st.session_state[field_name] = field_value
def demo_virology_callback():
    """
    Pre-fill the title and abstract fields with the virology demo paper
    (https://doi.org/10.1016/j.cell.2020.08.001).
    """
    demo_fields = {
        "title": "Severe COVID-19 Is Marked by a Dysregulated Myeloid Cell Compartment",
        "abstract": "Coronavirus disease 2019 (COVID-19) is a mild to moderate respiratory tract infection, however, a subset of patients progress to severe disease and respiratory failure. The mechanism of protective immunity in mild forms and the pathogenesis of severe COVID-19 associated with increased neutrophil counts and dysregulated immune responses remain unclear. In a dual-center, two-cohort study, we combined single-cell RNA-sequencing and single-cell proteomics of whole-blood and peripheral-blood mononuclear cells to determine changes in immune cell composition and activation in mild versus severe COVID-19 (242 samples from 109 individuals) over time. HLA-DRhiCD11chi inflammatory monocytes with an interferon-stimulated gene signature were elevated in mild COVID-19. Severe COVID-19 was marked by occurrence of neutrophil precursors, as evidence of emergency myelopoiesis, dysfunctional mature neutrophils, and HLA-DRlo monocytes. Our study provides detailed insights into the systemic immune response to SARS-CoV-2 infection and reveals profound alterations in the myeloid cell compartment associated with severe COVID-19.",
    }
    # Written in the same order as before: title first, then abstract
    for field_name, field_value in demo_fields.items():
        st.session_state[field_name] = field_value
def clear_callback():
    """
    Reset the title, abstract and output entries of the session state
    to empty strings.
    """
    for field_name in ("title", "abstract", "output"):
        st.session_state[field_name] = ""
# Three equally weighted columns, one button in each
col1, col2, col3 = st.columns([1, 1, 1])
button_specs = (
    (col1, "Demo: immunology", demo_immunology_callback),
    (col2, "Demo: virology", demo_virology_callback),
    (col3, "Clear fields", clear_callback),
)
for column, button_label, button_callback in button_specs:
    with column:
        st.button(button_label, on_click=button_callback)

## Input fields
placeholder = st.empty()

# The widgets are keyed to the session state entries that the callbacks write
title = st.text_input("Enter the title:", key="title")
abstract = st.text_area(
    "... and maybe the abstract of the paper you want to classify:",
    key="abstract",
)

# Title and abstract are classified together as one newline-separated text
text = f"{title}\n{abstract}"

## Output
# Classify only when there is something besides whitespace
if text.strip():
    predicted_class = process(text)
    st.markdown(
        f"<h4>Predicted class: {predicted_class}</h4>",
        unsafe_allow_html=True,
    )