File size: 7,084 Bytes
c77734f
97b056e
 
 
 
7365e02
 
 
 
97b056e
69508a2
 
 
7365e02
97b056e
 
 
7365e02
 
 
847199e
97b056e
847199e
 
 
 
bb5547f
847199e
 
 
 
bb5547f
2b991f8
bb5547f
 
 
 
847199e
 
 
 
97b056e
 
7365e02
97b056e
2ecd2f9
 
69508a2
2ecd2f9
 
 
 
69508a2
 
2ecd2f9
69508a2
 
2ecd2f9
69508a2
 
2ecd2f9
 
 
 
54aab7d
97b056e
3336c4c
 
2ecd2f9
 
 
69508a2
c51be4c
2ecd2f9
c51be4c
2ecd2f9
c51be4c
 
2ecd2f9
 
 
69508a2
c51be4c
2ecd2f9
d954b92
2ecd2f9
d954b92
 
2ecd2f9
 
 
69508a2
ab43a0e
 
 
 
 
 
 
 
 
 
2ecd2f9
 
 
 
 
 
 
 
69508a2
582efe5
 
 
bb5547f
582efe5
 
9314f82
2b991f8
582efe5
 
2b991f8
9314f82
582efe5
 
 
 
ab43a0e
2ecd2f9
780f571
2ecd2f9
2c1f5f8
dae769e
 
ab43a0e
2ecd2f9
 
 
 
 
 
 
69508a2
 
 
97b056e
 
 
2ecd2f9
 
 
847199e
 
 
 
 
 
d954b92
f02578a
 
 
 
847199e
670de19
847199e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
import streamlit as st


@st.cache_resource
def prepare_model(model_name="oracat/bert-paper-classifier"):
    """
    Load the tokenizer and the sequence-classification model.

    The function is cached with ``st.cache_resource`` so the model is loaded
    once and the same objects are reused on every rerun. ``st.cache_data``
    (used previously) is meant for serializable data: it pickles the return
    value and returns a fresh unpickled copy on each call, which is very slow
    for a transformer model.

    Parameters
    ----------
    model_name : str
        Checkpoint identifier passed to ``from_pretrained`` for both the
        tokenizer and the model.

    Returns
    -------
    tuple
        ``(tokenizer, model)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    return (tokenizer, model)


def _trim_categories(categories, mass_threshold=0.95, max_ratio=10):
    """
    Keep the leading, most relevant entries of a score-sorted category list.

    ``categories`` must be sorted by descending ``"score"``. Walking from the
    best entry:

    * the entry that brings the cumulative score to ``mass_threshold`` or more
      is kept, and the walk stops;
    * otherwise, an entry scoring more than ``max_ratio`` times lower than the
      previous one is dropped, and the walk stops.

    Returns a new list (possibly empty if ``categories`` is empty).
    """
    kept = []
    cum_score = 0
    prev_score = 0
    for i, item in enumerate(categories):
        score = item["score"]
        cum_score += score
        if cum_score >= mass_threshold:
            # This category completes the probability mass: keep it and stop.
            kept.append(item)
            break
        # Heuristic to drop less relevant categories. Written as a
        # multiplication (equivalent to prev_score / score > max_ratio for
        # positive scores) so a zero score cannot raise ZeroDivisionError.
        if i > 0 and prev_score > max_ratio * score:
            break
        kept.append(item)
        prev_score = score
    return kept


def process(text):
    """
    Translate incoming text to tokens and classify it.

    Builds a text-classification pipeline from the module-level ``model`` and
    ``tokenizer`` asking for the top 3 labels, sorts them by descending score
    and trims them with ``_trim_categories``.

    Returns a list of ``{"label": ..., "score": ...}`` dicts, best first.
    """
    pipe = pipeline("text-classification", model=model, tokenizer=tokenizer, top_k=3)
    # truncation=True: without it, texts longer than the model's maximum
    # sequence length (long abstracts) make the forward pass fail.
    result = pipe(text, truncation=True)[0]

    result = sorted(result, key=lambda x: -x["score"])

    return _trim_categories(result)


tokenizer, model = prepare_model()


# State management
#
# The app state is the paper title and abstract (plus an "output" entry).
# Keeping them in st.session_state lets the demo buttons pre-fill the input
# fields, which are bound to the same keys.

for _state_key in ("title", "abstract", "output"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = ""


# Simple streamlit interface

st.markdown("### Biomedical paper classifier")

# NOTE(review): this relative src is resolved by the browser against the app
# URL; Streamlit does not serve arbitrary project files, so confirm that
# header.png is actually reachable (e.g. via static file serving).
st.markdown("<img height=100px src='./header.png'>", unsafe_allow_html=True)


## Demo buttons and their callbacks


def demo_immunology_callback():
    """
    Pre-fill the title and abstract inputs with an immunology demo paper.

    Source: https://www.biorxiv.org/content/10.1101/2022.12.01.518788v1
    """
    demo_paper = {
        "title": "Using TCR and BCR sequencing to unravel the role of T and B cells in abdominal aortic aneurysm",
        "abstract": "Recent evidence suggests that AAA displays characteristics of an autoimmune disease and it gained increasing prominence that specific antigen-driven T cells in the aortic tissue may contribute to the initial immune response. We found no clonal expansion of TCRs or BCRs in elastase-induced AAA in mice.",
    }
    for field, value in demo_paper.items():
        st.session_state[field] = value


def demo_virology_callback():
    """
    Pre-fill the title and abstract inputs with a virology demo paper.

    Source: https://doi.org/10.4269/ajtmh.20-0849
    """
    demo_paper = {
        "title": "The Origin of COVID-19 and Why It Matters",
        "abstract": "The COVID-19 pandemic is among the deadliest infectious diseases to have emerged in recent history. As with all past pandemics, the specific mechanism of its emergence in humans remains unknown. Nevertheless, a large body of virologic, epidemiologic, veterinary, and ecologic data establishes that the new virus, SARS-CoV-2, evolved directly or indirectly from a β-coronavirus in the sarbecovirus (SARS-like virus) group that naturally infect bats and pangolins in Asia and Southeast Asia. Scientists have warned for decades that such sarbecoviruses are poised to emerge again and again, identified risk factors, and argued for enhanced pandemic prevention and control efforts. Unfortunately, few such preventive actions were taken resulting in the latest coronavirus emergence detected in late 2019 which quickly spread pandemically. The risk of similar coronavirus outbreaks in the future remains high. In addition to controlling the COVID-19 pandemic, we must undertake vigorous scientific, public health, and societal actions, including significantly increased funding for basic and applied research addressing disease emergence, to prevent this tragic history from repeating itself.",
    }
    for field, value in demo_paper.items():
        st.session_state[field] = value


def demo_microbiology_callback():
    """
    Pre-fill the title and abstract inputs with a microbiology demo paper.

    Source: https://doi.org/10.1016/j.cell.2023.01.002
    """
    demo_paper = {
        "title": "Bacterial droplet-based single-cell RNA-seq reveals antibiotic-associated heterogeneous cellular states",
        "abstract": "We introduce BacDrop, a highly scalable technology for bacterial single-cell RNA sequencing that has overcome many challenges hindering the development of scRNA-seq in bacteria. BacDrop can be applied to thousands to millions of cells from both gram-negative and gram-positive species. It features universal ribosomal RNA depletion and combinatorial barcodes that enable multiplexing and massively parallel sequencing. We applied BacDrop to study Klebsiella pneumoniae clinical isolates and to elucidate their heterogeneous responses to antibiotic stress. In an unperturbed population presumed to be homogeneous, we found within-population heterogeneity largely driven by the expression of mobile genetic elements that promote the evolution of antibiotic resistance. Under antibiotic perturbation, BacDrop revealed transcriptionally distinct subpopulations associated with different phenotypic outcomes including antibiotic persistence. BacDrop thus can capture cellular states that cannot be detected by bulk RNA-seq, which will unlock new microbiological insights into bacterial responses to perturbations and larger bacterial communities such as the microbiome.",
    }
    for field, value in demo_paper.items():
        st.session_state[field] = value


def clear_callback():
    """
    Reset the "title", "abstract" and "output" session-state entries to empty
    strings, which empties the input fields bound to those keys.
    """
    for state_key in ("title", "abstract", "output"):
        st.session_state[state_key] = ""


# Custom colours for the button in the 4th child of the row (the "Clear
# fields" button placed in the fourth column below).
# NOTE(review): `.css-ocqkz7` looks like an auto-generated Streamlit class
# name for the column row; such names tend to change between Streamlit
# releases, so confirm the selector still matches.
_clear_button_css = """<style>
.css-ocqkz7 > div:nth-child(4) button {
    background-color: #C0C0C0;
}
.css-ocqkz7 > div:nth-child(4) button:hover {
    color: #8B0000;
    background-color: #FAA0A0;
}
.css-ocqkz7 > div:nth-child(4) button:active {
    background-color: #FF6961;
    color: #FFFFFF;
}
</style>"""
st.markdown(_clear_button_css, unsafe_allow_html=True)

# Four equal-width columns: three demo buttons and the clear button.
col1, col2, col3, col4 = st.columns(4)
col1.button("Demo: immunology", on_click=demo_immunology_callback)
col2.button("Demo: microbiology", on_click=demo_microbiology_callback)
col3.button("Demo: virology", on_click=demo_virology_callback)
col4.button("Clear fields", on_click=clear_callback)

## Input fields

placeholder = st.empty()

# Both widgets are bound to session-state keys so the demo/clear callbacks
# can change their contents.
title = st.text_input("Enter the title:", key="title")
abstract = st.text_area(
    "... and maybe the abstract of the paper you want to classify:", key="abstract"
)

# The classifier sees title and abstract as one newline-separated text.
text = "\n".join((title, abstract))

## Output

if text.strip():
    results = process(text)
    if results:
        top, *runners_up = results
        out_text = f"This paper is likely to be from the category **{top['label']}** *(score {top['score']:.2f})*."
        if runners_up:
            others_text = " and ".join(
                f"{item['label']} *(score {item['score']:.2f})*" for item in runners_up
            )
            out_text += f"\n(Other fitting categories are {others_text}.)"
    else:
        out_text = ""
    st.markdown(out_text)