File size: 3,808 Bytes
261ea5b
e669abf
b6e3578
 
68fbe9e
b6e3578
55ecc4c
b6e3578
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8a754d8
 
b6e3578
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7c8cdc0
b6e3578
 
e669abf
cde5ee9
b6e3578
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e669abf
b6e3578
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import html

import gradio as gr
import numpy as np
import shap
import torch
from transformers import (
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    AutoTokenizer,
    pipeline,
)

# 1) Device setup: use the GPU when CUDA is available, otherwise the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# 2) Load the adverse-reaction (ADR) classifier and its tokenizer.
model_name = "paragon-analytics/ADRv1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model = model.to(device)

# 3) Hugging Face text-classification pipeline. return_all_scores=True makes
# it emit a score for every label instead of only the top one.
pred_pipeline = pipeline(
    task="text-classification",
    model=model,
    tokenizer=tokenizer,
    device=-1 if device != "cuda" else 0,
    return_all_scores=True,
)

# 4) Wrapper: str | iterable of str -> np.ndarray of shape (n, n_classes)
def predict_proba(texts):
    """Return per-class scores for one or more texts.

    Args:
        texts: A single string, or any iterable of strings. SHAP's Text
            masker calls this with a NumPy array of (masked) strings, which
            the Hugging Face pipeline does not treat as a batch. The input is
            therefore normalised to a plain ``list[str]`` before inference.

    Returns:
        ``np.ndarray`` of floats with one row per input text and one column
        per label. Columns follow the order of the per-label list that
        ``pred_pipeline`` returns for each sample.
    """
    if isinstance(texts, str):
        batch = [texts]
    else:
        # np.ndarray / tuple / generator -> list[str]; the pipeline only
        # batches real lists and its tokenizer rejects ndarray input.
        batch = [str(t) for t in texts]
    results = pred_pipeline(batch)
    # results is List[List[{"label": ..., "score": ...}]]
    return np.array(
        [[d["score"] for d in sample] for sample in results], dtype=float
    )

# 5) SHAP explainer for the probability wrapper.
# The Text masker hides tokens using the classifier's own tokenizer.
masker = shap.maskers.Text(tokenizer)
# One throwaway prediction yields the label names. They come in the same order
# as the score columns of predict_proba, because both read pred_pipeline's
# per-label list in order.
example = pred_pipeline(["test"])[0]
class_labels = [entry["label"] for entry in example]
explainer = shap.Explainer(predict_proba, masker=masker, output_names=class_labels)

# 6) Biomedical NER pipeline. aggregation_strategy="simple" groups sub-word
# predictions into entity spans; adr_predict reads their "entity_group",
# "start" and "end" fields.
ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
ner_pipe = pipeline(
    task="ner",
    tokenizer=ner_tokenizer,
    model=ner_model,
    device=-1 if device != "cuda" else 0,
    aggregation_strategy="simple",
)

# 7) Single-text prediction + SHAP + NER

# Highlight colour per NER entity group; unlisted groups fall back to lightgray.
_ENTITY_COLORS = {
    "Severity": "red",
    "Sign_symptom": "green",
    "Medication": "lightblue",
    "Age": "yellow",
    "Sex": "yellow",
    "Diagnostic_procedure": "gray",
    "Biological_structure": "silver",
}


def _highlight_entities(text, entities):
    """Return *text* as HTML with each entity span wrapped in a coloured <mark>.

    Args:
        text: The original input string.
        entities: NER results. Each item must provide ``start`` and ``end``
            (character offsets into *text*) and ``entity_group``.

    Returns:
        An HTML string. All user-supplied text is HTML-escaped because the
        result is rendered as raw HTML and the input is untrusted.
    """
    pieces = []
    cursor = 0
    for ent in entities:
        start, end = ent["start"], ent["end"]
        # Skip spans without offsets, or spans overlapping the previous one:
        # slicing with them would duplicate or reorder the text.
        if start is None or end is None or start < cursor:
            continue
        color = _ENTITY_COLORS.get(ent["entity_group"], "lightgray")
        pieces.append(html.escape(text[cursor:start]))
        # Slice the original text rather than using ent["word"]: the decoded
        # token string can differ from the input (casing, sub-word spacing).
        pieces.append(
            f"<mark style='background-color:{color};'>"
            f"{html.escape(text[start:end])}</mark>"
        )
        cursor = end
    pieces.append(html.escape(text[cursor:]))
    return "".join(pieces)


def adr_predict(text):
    """Classify *text*, explain the prediction with SHAP, and highlight entities.

    Args:
        text: Free-text description of a reaction, as typed in the UI.

    Returns:
        A 3-tuple ``(prob_dict, shap_html, highlighted)``:

        - ``prob_dict``: ``{label: probability}`` for the class labels.
        - ``shap_html``: the HTML string returned by ``shap.plots.text`` with
          ``display=False``. This is HTML, not a Matplotlib figure.
        - ``highlighted``: *text* as escaped HTML with NER entities marked.

        For empty or whitespace-only input, returns ``(None, None, "")``
        without running any model.
    """
    if not text or not text.strip():
        return None, None, ""

    # a) Class probabilities; score columns line up positionally with class_labels.
    probs = predict_proba(text)[0]
    prob_dict = {label: float(p) for label, p in zip(class_labels, probs)}

    # b) SHAP token attributions, rendered to an HTML string.
    shap_values = explainer([text])
    shap_html = shap.plots.text(shap_values[0], display=False)

    # c) Biomedical entities highlighted inline.
    highlighted = _highlight_entities(text, ner_pipe(text))

    return prob_dict, shap_html, highlighted

# 8) Gradio UI: one text input, three outputs (class probabilities, SHAP
# explanation, NER highlighting), all produced by adr_predict.
with gr.Blocks() as demo:
    gr.Markdown("## Welcome to **ADR Detector** 🪐")
    gr.Markdown(
        "Predicts the likelihood your text describes a severe vs. non-severe adverse reaction. "
        "_(Not for medical diagnosis.)_"
    )

    txt = gr.Textbox(label="Enter Your Text Here:", lines=3, placeholder="Type a sentence about a reaction…")
    btn = gr.Button("Analyze")

    with gr.Row():
        lbl = gr.Label(label="Predicted Probabilities")
        # adr_predict returns the SHAP view as the HTML string produced by
        # shap.plots.text(..., display=False). gr.Plot only renders figure
        # objects, so an HTML component is needed here.
        shp = gr.HTML(label="SHAP Explanation")
        ner = gr.HTML(label="Biomedical Entities Highlighted")

    btn.click(fn=adr_predict, inputs=txt, outputs=[lbl, shp, ner])

    # Examples feed the same three outputs; cache_examples=True precomputes
    # their results.
    gr.Examples(
        examples=[
            "A 35-year-old male experienced severe headache after taking Aspirin.",
            "A 35-year-old female had minor abdominal pain after Acetaminophen."
        ],
        inputs=txt,
        outputs=[lbl, shp, ner],
        fn=adr_predict,
        cache_examples=True
    )

demo.launch()