File size: 4,363 Bytes
261ea5b
e669abf
b6e3578
41b32e2
 
 
 
 
 
b6e3578
55ecc4c
b6e3578
41b32e2
b6e3578
41b32e2
b6e3578
 
 
 
41b32e2
b6e3578
 
 
 
 
41b32e2
b6e3578
 
41b32e2
b6e3578
 
 
 
41b32e2
b6e3578
 
 
41b32e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b6e3578
41b32e2
b6e3578
 
 
 
41b32e2
 
 
 
b6e3578
 
 
 
 
41b32e2
b6e3578
 
41b32e2
 
 
 
 
 
 
 
 
 
 
 
 
b6e3578
41b32e2
b6e3578
 
41b32e2
b6e3578
 
 
 
41b32e2
b6e3578
 
41b32e2
b6e3578
 
41b32e2
b6e3578
 
 
 
 
 
 
 
 
41b32e2
b6e3578
 
 
41b32e2
 
b6e3578
7c8cdc0
41b32e2
 
 
 
 
b6e3578
e669abf
cde5ee9
41b32e2
 
 
b6e3578
41b32e2
 
 
 
 
b6e3578
 
 
 
 
 
 
41b32e2
b6e3578
 
e669abf
b6e3578
41b32e2
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import html

import numpy as np
import torch
import shap
from transformers import (
    pipeline,
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification
)
import gradio as gr

# 1) Device setup: run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# 2) ADR classifier: tokenizer plus sequence-classification model, moved to `device`.
model_name = "paragon-analytics/ADRv1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model = model.to(device)

# 3) HF text-classification pipeline returning a score for every label.
#    Pipelines take an integer device index: 0 is the first GPU, -1 is the CPU.
pred_pipeline = pipeline(
    task="text-classification",
    model=model,
    tokenizer=tokenizer,
    device=-1 if device.type != "cuda" else 0,
    return_all_scores=True,
)

# 4) Base predict_proba: List[str] → np.ndarray of shape (n_samples, n_classes)
def predict_proba(texts):
    """Return class probabilities for one or more texts.

    Args:
        texts: A single string, or a sequence of strings.

    Returns:
        ``np.ndarray`` of shape ``(n_samples, n_classes)``. Column ``j`` holds
        the score of label ``config.id2label[j]`` of the pipeline's model. An
        empty input yields shape ``(0, n_classes)``, so the result is always
        2-D.
    """
    if isinstance(texts, str):
        texts = [texts]
    config = pred_pipeline.model.config
    label_order = [config.id2label[i] for i in range(config.num_labels)]
    if len(texts) == 0:
        # np.array([]) would be 1-D; keep the documented 2-D shape instead.
        return np.empty((0, len(label_order)))
    results = pred_pipeline(texts)
    # results: List[List[{"label": ..., "score": ...}]]. Scores are looked up
    # by label name rather than list position, so the column order does not
    # depend on how the pipeline orders its per-sample output.
    score_maps = ({d["label"]: d["score"] for d in sample} for sample in results)
    return np.array([[scores[lbl] for lbl in label_order] for scores in score_maps])

# 5) SHAP-compatible wrapper: joins token lists back into strings
def predict_proba_shap(inputs):
    """SHAP model function: normalise *inputs* to strings and score them.

    Args:
        inputs: A single string, or a batch whose items are either strings
            (presumably the masked texts produced by ``shap.maskers.Text``)
            or token sequences (list/tuple) that are space-joined into text.

    Returns:
        Whatever ``predict_proba`` returns for the normalised batch.
    """
    # A bare string would otherwise be iterated character by character,
    # scoring each character as its own sample.
    if isinstance(inputs, str):
        inputs = [inputs]
    texts = [
        " ".join(x) if isinstance(x, (list, tuple)) else x
        for x in inputs
    ]
    return predict_proba(texts)

# 6) Instantiate SHAP explainer with a Text masker built on the classifier's
#    own tokenizer, so masking operates on the model's token boundaries.
masker = shap.maskers.Text(tokenizer)
# Output class names in label-id order, read straight from the model config
# instead of running a throwaway inference on a dummy sample at import time.
# The pipeline is built with return_all_scores=True, which emits one score per
# label id in this same order, so positional scores line up with these names.
class_labels = [model.config.id2label[i] for i in range(model.config.num_labels)]

explainer = shap.Explainer(
    predict_proba_shap,
    masker=masker,
    output_names=class_labels
)

# 7) Biomedical named-entity recogniser. aggregation_strategy="simple" merges
#    sub-word tokens into whole entity spans, each carrying an "entity_group".
ner_model_name = "d4data/biomedical-ner-all"
ner_tokenizer = AutoTokenizer.from_pretrained(ner_model_name)
ner_model = AutoModelForTokenClassification.from_pretrained(ner_model_name)
ner_model = ner_model.to(device)
ner_pipe = pipeline(
    task="ner",
    model=ner_model,
    tokenizer=ner_tokenizer,
    device=-1 if device.type != "cuda" else 0,
    aggregation_strategy="simple",
)

# 8) Background colour for each NER entity group when highlighting text.
#    Groups missing from this table get a fallback colour where it is used.
ENTITY_COLORS = dict(
    Severity="red",
    Sign_symptom="green",
    Medication="lightblue",
    # Demographic attributes share one colour.
    Age="yellow",
    Sex="yellow",
    Diagnostic_procedure="gray",
    Biological_structure="silver",
)

# 9) Full predict + explain + NER function
def _highlight_entities(text, ents):
    """Return *text* as HTML with each NER span wrapped in a coloured <mark>.

    Span text is sliced from *text* by the entity's character offsets, so the
    user's original casing and spacing are kept (the pipeline's decoded
    ``word`` may differ). All user text is HTML-escaped. Spans that overlap an
    already-emitted span are skipped.
    """
    parts = []
    cursor = 0
    for ent in sorted(ents, key=lambda e: e["start"]):
        start, end = ent["start"], ent["end"]
        if start < cursor:
            # Overlaps the previous span; emitting it would duplicate text.
            continue
        color = ENTITY_COLORS.get(ent["entity_group"], "lightgray")
        parts.append(html.escape(text[cursor:start]))
        parts.append(
            f"<mark style='background-color:{color};'>"
            f"{html.escape(text[start:end])}</mark>"
        )
        cursor = end
    parts.append(html.escape(text[cursor:]))
    return "".join(parts)


def adr_predict(text: str):
    """Classify *text*, explain the prediction with SHAP, and highlight entities.

    Args:
        text: Free-text description entered by the user.

    Returns:
        A 3-tuple ``(prob_dict, shap_html, highlighted_html)``:
        ``prob_dict`` maps each class label to its probability,
        ``shap_html`` is the HTML string produced by ``shap.plots.text``, and
        ``highlighted_html`` is the escaped input with biomedical entities
        wrapped in coloured ``<mark>`` tags. Blank input (empty, whitespace,
        or ``None``) yields ``({}, "", "")`` without running any model.
    """
    if not text or not text.strip():
        return {}, "", ""

    # a) Class probabilities, keyed by the explainer's label names.
    probs = predict_proba([text])[0]
    prob_dict = {label: float(p) for label, p in zip(class_labels, probs)}

    # b) SHAP explanation. With display=False, shap.plots.text returns an HTML
    #    string (not a Matplotlib figure), so it needs an HTML output component.
    shap_values = explainer([text])
    shap_html = shap.plots.text(shap_values[0], display=False)

    # c) NER highlighting
    highlighted = _highlight_entities(text, ner_pipe(text))

    return prob_dict, shap_html, highlighted

# 10) Build Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## Welcome to **ADR Detector** 🪐")
    gr.Markdown(
        "Predicts the likelihood your text describes a **severe** vs. **non-severe** adverse reaction.  \n"
        "_(Not for medical or diagnostic use.)_"
    )

    txt = gr.Textbox(
        label="Enter Your Text Here:",
        lines=3,
        placeholder="Type a sentence about an adverse reaction…"
    )
    btn = gr.Button("Analyze")

    with gr.Row():
        label_out = gr.Label(label="Predicted Probabilities")
        # adr_predict returns the SHAP explanation as the HTML string from
        # shap.plots.text(..., display=False); gr.Plot only accepts figure
        # objects, so an HTML component is required to render it.
        shap_out = gr.HTML(label="SHAP Explanation")
        ner_out = gr.HTML(label="Biomedical Entities Highlighted")

    # Order must match adr_predict's return tuple.
    outputs = [label_out, shap_out, ner_out]

    btn.click(
        fn=adr_predict,
        inputs=txt,
        outputs=outputs
    )

    gr.Examples(
        examples=[
            "A 35-year-old male experienced severe headache after taking Aspirin.",
            "A 35-year-old female had minor abdominal pain after Acetaminophen."
        ],
        inputs=txt,
        outputs=outputs,
        fn=adr_predict,
        # Runs adr_predict on the examples ahead of time so clicks are instant.
        cache_examples=True
    )

if __name__ == "__main__":
    demo.launch()