Spaces:
Sleeping
Sleeping
File size: 4,363 Bytes
261ea5b e669abf b6e3578 41b32e2 b6e3578 55ecc4c b6e3578 41b32e2 b6e3578 41b32e2 b6e3578 41b32e2 b6e3578 41b32e2 b6e3578 41b32e2 b6e3578 41b32e2 b6e3578 41b32e2 b6e3578 41b32e2 b6e3578 41b32e2 b6e3578 41b32e2 b6e3578 41b32e2 b6e3578 41b32e2 b6e3578 41b32e2 b6e3578 41b32e2 b6e3578 41b32e2 b6e3578 41b32e2 b6e3578 41b32e2 b6e3578 41b32e2 b6e3578 7c8cdc0 41b32e2 b6e3578 e669abf cde5ee9 41b32e2 b6e3578 41b32e2 b6e3578 41b32e2 b6e3578 e669abf b6e3578 41b32e2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
import html

import numpy as np
import torch
import shap
from transformers import (
    pipeline,
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
)
import gradio as gr
# 1) Device setup: use CUDA when torch reports it available, else the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# 2) ADR sequence-classification checkpoint and its matching tokenizer.
model_name = "paragon-analytics/ADRv1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model = model.to(device)

# 3) Text-classification pipeline wrapping the model above.
# NOTE(review): newer transformers releases deprecate `return_all_scores`
# in favour of `top_k=None`, which (I believe) sorts scores by value.
# predict_proba and class_labels rely on a fixed per-label column order, so
# confirm the ordering before switching flags.
pred_pipeline = pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    return_all_scores=True,
    # Pipelines take a GPU ordinal here: 0 for the first GPU, -1 for CPU.
    device=0 if device.type == "cuda" else -1,
)
# 4) Base predict_proba: List[str] → np.ndarray of shape (n_samples, n_classes)
def predict_proba(texts):
    """Return class probabilities for one or more texts.

    Args:
        texts: A single string, or any iterable of strings (list, tuple,
            numpy array, generator, ...).

    Returns:
        A float array of shape ``(n_samples, n_classes)``. Column ``j`` holds
        the score of the ``j``-th label in the order ``pred_pipeline`` emits
        its per-label scores. An empty batch yields shape ``(0, n_classes)``.
    """
    if isinstance(texts, str):
        texts = [texts]
    else:
        # Hand the pipeline a plain list: it may not treat tuples or numpy
        # arrays as a batch of inputs.
        texts = list(texts)
    if not texts:
        # Keep the documented 2-D shape; np.array([]) would be 1-D.
        return np.zeros((0, pred_pipeline.model.config.num_labels))
    results = pred_pipeline(texts)
    # results: List[List[{"label": ..., "score": ...}]], one inner list per text
    return np.array([[d["score"] for d in sample] for sample in results])
# 5) SHAP-compatible wrapper: joins token sequences back into strings
def predict_proba_shap(inputs):
    """Adapt SHAP's (masked) model inputs to ``predict_proba``.

    Args:
        inputs: Iterable whose items are either strings or token sequences
            (list, tuple or numpy array of strings). Token sequences are
            joined with single spaces; strings pass through unchanged.

    Returns:
        The ``(n_samples, n_classes)`` probability array from ``predict_proba``.
    """
    texts = [
        " ".join(x) if isinstance(x, (list, tuple, np.ndarray)) else x
        for x in inputs
    ]
    return predict_proba(texts)
# 6) SHAP explainer: the probability wrapper is the model, and text is masked
#    using the ADR classifier's own tokenizer.
masker = shap.maskers.Text(tokenizer)

# Discover the label order by scoring one throw-away sample. This assumes
# the pipeline returns labels in the same order on every call, which is the
# order predict_proba uses for its columns.
_example = pred_pipeline(["test"])[0]
class_labels = [entry["label"] for entry in _example]

explainer = shap.Explainer(
    predict_proba_shap,
    masker=masker,
    output_names=class_labels,
)
# 7) Biomedical NER: a token-classification checkpoint wrapped in an "ner"
#    pipeline. With aggregation_strategy="simple" it yields grouped entities;
#    adr_predict reads their `entity_group`, `start` and `end` fields.
ner_model_name = "d4data/biomedical-ner-all"
ner_tokenizer = AutoTokenizer.from_pretrained(ner_model_name)
ner_model = (
    AutoModelForTokenClassification
    .from_pretrained(ner_model_name)
    .to(device)
)
ner_pipe = pipeline(
    task="ner",
    model=ner_model,
    tokenizer=ner_tokenizer,
    aggregation_strategy="simple",
    # GPU ordinal 0 when running on CUDA, -1 selects the CPU.
    device=0 if device.type == "cuda" else -1,
)
# 8) CSS background colour for each NER entity group used when highlighting.
#    Groups missing from this mapping are given a default colour by the caller.
ENTITY_COLORS = dict(
    Severity="red",
    Sign_symptom="green",
    Medication="lightblue",
    Age="yellow",
    Sex="yellow",
    Diagnostic_procedure="gray",
    Biological_structure="silver",
)
# 9) Full predict + explain + NER function
def _highlight_entities(text, ents):
    """Return *text* as HTML with each NER entity wrapped in a coloured <mark>.

    Args:
        text: The original user text.
        ents: Entities from ``ner_pipe``; each needs ``start``/``end``
            character offsets into *text* and an ``entity_group`` label.

    Returns:
        An HTML string. Every piece of user text is HTML-escaped, because the
        result is rendered as raw HTML by the UI.
    """
    # Entities without character offsets cannot be placed in the text.
    spans = [
        e for e in ents
        if e.get("start") is not None and e.get("end") is not None
    ]
    pieces = []
    cursor = 0
    for ent in sorted(spans, key=lambda e: e["start"]):
        start, end = ent["start"], ent["end"]
        if start < cursor:
            # Overlaps the previous entity; skip it instead of duplicating text.
            continue
        color = ENTITY_COLORS.get(ent["entity_group"], "lightgray")
        pieces.append(html.escape(text[cursor:start]))
        # Slice the user's own text rather than using ent["word"], the
        # tokenizer-decoded form, which may differ in case or spacing.
        pieces.append(
            f"<mark style='background-color:{color};'>"
            f"{html.escape(text[start:end])}</mark>"
        )
        cursor = end
    pieces.append(html.escape(text[cursor:]))
    return "".join(pieces)


def adr_predict(text: str):
    """Classify *text*, explain the top prediction with SHAP, highlight entities.

    Args:
        text: Free-text description of a (possible) adverse reaction.

    Returns:
        A 3-tuple ``(prob_dict, shap_html, highlighted_html)``:

        * ``prob_dict`` maps each class label to its probability.
        * ``shap_html`` is the HTML produced by ``shap.plots.text`` for the
          highest-probability class. With ``display=False`` SHAP returns an
          HTML string, not a figure.
        * ``highlighted_html`` is *text* as escaped HTML with NER entities
          marked.

        Blank input returns ``({}, "", "")`` without running any model.
    """
    if not text or not text.strip():
        return {}, "", ""

    # a) Predict probabilities
    probs = predict_proba([text])[0]
    prob_dict = {label: float(p) for label, p in zip(class_labels, probs)}

    # b) SHAP explanation → HTML for the top class only. shap_values[0] holds
    #    one column per class; selecting a single column gives a single-output
    #    explanation that the text plot renders as one heat map.
    top = int(np.argmax(probs))
    shap_values = explainer([text])
    top_class_html = shap.plots.text(shap_values[0][:, top], display=False) or ""
    shap_html = (
        "<p>Token contributions toward "
        f"<b>{html.escape(str(class_labels[top]))}</b></p>" + top_class_html
    )

    # c) NER highlighting
    highlighted = _highlight_entities(text, ner_pipe(text))
    return prob_dict, shap_html, highlighted
# 10) Build Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## Welcome to **ADR Detector** 🪐")
    gr.Markdown(
        "Predicts the likelihood your text describes a **severe** vs. **non-severe** adverse reaction. \n"
        "_(Not for medical or diagnostic use.)_"
    )
    txt = gr.Textbox(
        label="Enter Your Text Here:",
        lines=3,
        placeholder="Type a sentence about an adverse reaction…",
    )
    btn = gr.Button("Analyze")
    with gr.Row():
        label_out = gr.Label(label="Predicted Probabilities")
        # shap.plots.text(..., display=False) produces an HTML string, not a
        # figure, so it needs an HTML component; gr.Plot cannot render it.
        shap_out = gr.HTML(label="SHAP Explanation")
    ner_out = gr.HTML(label="Biomedical Entities Highlighted")
    btn.click(
        fn=adr_predict,
        inputs=txt,
        outputs=[label_out, shap_out, ner_out],
    )
    # cache_examples=True runs adr_predict on these examples when the app is built.
    gr.Examples(
        examples=[
            "A 35-year-old male experienced severe headache after taking Aspirin.",
            "A 35-year-old female had minor abdominal pain after Acetaminophen.",
        ],
        inputs=txt,
        outputs=[label_out, shap_out, ner_out],
        fn=adr_predict,
        cache_examples=True,
    )

if __name__ == "__main__":
    demo.launch()
|