ADR_Detector / app.py
paragon-analytics's picture
Update app.py
b6e3578 verified
raw
history blame
3.81 kB
import numpy as np
import torch
import shap
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification
import gradio as gr
# 1) Device setup: use CUDA when PyTorch reports it as available, else CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# 2) Load the ADR classifier and its matching tokenizer by model name.
model_name = "paragon-analytics/ADRv1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model = model.to(device)

# 3) Hugging Face text-classification pipeline. return_all_scores=True makes
#    it report a score for every class instead of only the top label; the
#    pipeline is handed device index 0 on CUDA and -1 otherwise.
pred_pipeline = pipeline(
    task="text-classification",
    model=model,
    tokenizer=tokenizer,
    device=-1 if device != "cuda" else 0,
    return_all_scores=True,
)
# 4) Wrapper: str or iterable of str -> np.ndarray of shape (n, n_classes)
def predict_proba(texts):
    """Return the per-class scores of ``pred_pipeline`` for one or more texts.

    Args:
        texts: A single string, or any iterable of strings (list, tuple,
            NumPy array of str, ...).

    Returns:
        np.ndarray with one row per input text and one column per entry in
        the pipeline's per-sample score list, in the order the pipeline
        reports them.
    """
    if isinstance(texts, str):
        texts = [texts]
    else:
        # SHAP's text masker passes the masked samples as a NumPy array of
        # strings; the pipeline expects a plain list of ``str``, so normalize
        # every non-str iterable here instead of forwarding it unchanged.
        texts = [str(item) for item in texts]

    # results is List[List[{"label": ..., "score": ...}]]
    results = pred_pipeline(texts)
    return np.array([[entry["score"] for entry in sample] for sample in results])
# 5) Build SHAP explainer
# Text masker built on the classifier's tokenizer (for text explainability).
masker = shap.maskers.Text(tokenizer)
# Run the classifier once on a dummy input purely to read the label names,
# in the same order the pipeline lists its scores.
example = pred_pipeline(["test"])[0]
class_labels = [entry["label"] for entry in example]
explainer = shap.Explainer(predict_proba, masker=masker, output_names=class_labels)

# 6) Load biomedical NER pipeline
ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
# aggregation_strategy="simple" is what provides the grouped "entity_group"
# results that adr_predict reads; device index mirrors the classifier's.
ner_pipe = pipeline(
    task="ner",
    model=ner_model,
    tokenizer=ner_tokenizer,
    device=-1 if device != "cuda" else 0,
    aggregation_strategy="simple",
)
# 7) Single-text prediction + SHAP + NER
def _highlight_entities(text, entities):
    """Return ``text`` as HTML with each NER entity span wrapped in ``<mark>``.

    Args:
        text: The original input string the entities were extracted from.
        entities: Iterable of dicts with ``"start"``, ``"end"`` (character
            offsets into ``text``) and ``"entity_group"`` keys.

    Returns:
        An HTML string. All text taken from ``text`` is HTML-escaped, because
        it is user input that ends up rendered as markup.
    """
    import html  # local import keeps this block self-contained

    colors = {
        "Severity": "red",
        "Sign_symptom": "green",
        "Medication": "lightblue",
        "Age": "yellow",
        "Sex": "yellow",
        "Diagnostic_procedure": "gray",
        "Biological_structure": "silver",
    }
    pieces = []
    cursor = 0
    for ent in entities:
        start, end = ent["start"], ent["end"]
        if start < cursor:
            # Span overlaps text already emitted; skipping it avoids
            # duplicating characters in the output.
            continue
        color = colors.get(ent["entity_group"], "lightgray")
        pieces.append(html.escape(text[cursor:start]))
        # Slice the original text rather than using the pipeline's "word",
        # which is rebuilt from tokens and may not match the input
        # character-for-character (e.g. casing or spacing).
        pieces.append(
            f"<mark style='background-color:{color};'>"
            f"{html.escape(text[start:end])}</mark>"
        )
        cursor = end
    pieces.append(html.escape(text[cursor:]))
    return "".join(pieces)


def adr_predict(text):
    """Classify ``text``, explain the prediction with SHAP, and highlight entities.

    Args:
        text: The sentence to analyze.

    Returns:
        A 3-tuple ``(prob_dict, shap_html, highlighted)``:
        * ``prob_dict`` maps each label in ``class_labels`` to its score.
        * ``shap_html`` is the value returned by
          ``shap.plots.text(..., display=False)``. NOTE: that is an HTML
          string, not a Matplotlib figure, so it must be shown in an
          HTML-capable output component.
        * ``highlighted`` is the input text as HTML with NER entities marked.
    """
    # a) Predict probabilities
    probs = predict_proba(text)[0]
    prob_dict = {label: float(score) for label, score in zip(class_labels, probs)}

    # b) SHAP explanation, rendered to an HTML string
    shap_values = explainer([text])
    shap_html = shap.plots.text(shap_values[0], display=False)

    # c) NER highlighting
    highlighted = _highlight_entities(text, ner_pipe(text))

    return prob_dict, shap_html, highlighted
# 8) Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## Welcome to **ADR Detector** 🪐")
    gr.Markdown(
        "Predicts the likelihood your text describes a severe vs. non-severe adverse reaction. "
        "_(Not for medical diagnosis.)_"
    )
    txt = gr.Textbox(label="Enter Your Text Here:", lines=3, placeholder="Type a sentence about a reaction…")
    btn = gr.Button("Analyze")
    with gr.Row():
        lbl = gr.Label(label="Predicted Probabilities")
        # adr_predict's second output comes from
        # shap.plots.text(..., display=False), which yields an HTML string.
        # gr.Plot expects a figure object, so render it with gr.HTML instead.
        shp = gr.HTML(label="SHAP Explanation")
    ner = gr.HTML(label="Biomedical Entities Highlighted")

    # The button and the cached examples share one callback and output set.
    btn.click(fn=adr_predict, inputs=txt, outputs=[lbl, shp, ner])
    gr.Examples(
        examples=[
            "A 35-year-old male experienced severe headache after taking Aspirin.",
            "A 35-year-old female had minor abdominal pain after Acetaminophen.",
        ],
        inputs=txt,
        outputs=[lbl, shp, ner],
        fn=adr_predict,
        cache_examples=True,
    )

demo.launch()