ADR_Detector / app.py
paragon-analytics's picture
Update app.py
d2b9b3e verified
raw
history blame
4.05 kB
import html

import gradio as gr
import numpy as np
import shap
import torch
from transformers import (
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    AutoTokenizer,
    pipeline,
)
# --------- 1) Device setup ---------
# Run on the GPU when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# --------- 2) ADR classifier ---------
model_name = "paragon-analytics/ADRv1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model = model.to(device)

# The pipeline is configured to return a score for every class (not only the
# top one); predict_proba turns those per-class entries into matrix columns.
# NOTE(review): `return_all_scores` is reportedly deprecated in newer
# transformers releases in favour of `top_k=None`; that replacement may sort
# the entries by score, which would scramble the column order predict_proba
# relies on -- verify before migrating.
pred_pipeline = pipeline(
    task="text-classification",
    model=model,
    tokenizer=tokenizer,
    return_all_scores=True,
    # Pipelines take a device index here: 0 for the first GPU, -1 for CPU.
    device=0 if device.type == "cuda" else -1,
)
def predict_proba(texts):
    """Return the classifier scores for one text or a batch of texts.

    Args:
        texts: A single string, or a sequence of strings.

    Returns:
        A NumPy array with one row per input text; each row holds the
        ``"score"`` values of the entries ``pred_pipeline`` returned for that
        text, in the order the pipeline returned them.
    """
    # A bare string is treated as a batch of one.
    batch = [texts] if isinstance(texts, str) else texts

    score_rows = []
    for sample in pred_pipeline(batch):
        score_rows.append([entry["score"] for entry in sample])
    return np.array(score_rows)
def predict_proba_shap(inputs):
    """Score *inputs* with ``predict_proba`` after normalising each item.

    Items that are lists are joined into a single space-separated string;
    every other item is passed through unchanged.

    Args:
        inputs: An iterable of strings and/or lists of string pieces.

    Returns:
        Whatever ``predict_proba`` returns for the normalised batch.
    """
    normalised = []
    for item in inputs:
        if isinstance(item, list):
            item = " ".join(item)
        normalised.append(item)
    return predict_proba(normalised)
# --------- 3) SHAP explainer ---------
# Run the classifier once on a throwaway input purely to read back the label
# names, in the order the pipeline reports them.
_example = pred_pipeline(["test"])[0]
class_labels = [entry["label"] for entry in _example]

# Text masker built on the classifier's own tokenizer.
masker = shap.maskers.Text(tokenizer)
explainer = shap.Explainer(
    predict_proba_shap,
    masker=masker,
    output_names=class_labels,
)
# --------- 4) Biomedical NER ---------
ner_name = "d4data/biomedical-ner-all"
ner_tokenizer = AutoTokenizer.from_pretrained(ner_name)
ner_model = AutoModelForTokenClassification.from_pretrained(ner_name)
ner_model = ner_model.to(device)
ner_pipe = pipeline(
    task="ner",
    model=ner_model,
    tokenizer=ner_tokenizer,
    aggregation_strategy="simple",
    # Pipelines take a device index here: 0 for the first GPU, -1 for CPU.
    device=0 if device.type == "cuda" else -1,
)

# Highlight colour per entity group; adr_predict falls back to "lightgray"
# for any group that is not listed here.
ENTITY_COLORS = {
    "Severity": "red",
    "Sign_symptom": "green",
    "Medication": "lightblue",
    "Age": "yellow",
    "Sex": "yellow",
    "Diagnostic_procedure": "gray",
    "Biological_structure": "silver",
}
# --------- 5) Prediction + SHAP + NER ---------
def _highlight_entities(text, entities):
    """Return *text* as HTML with every entity span wrapped in a ``<mark>``.

    All text taken from the input is HTML-escaped, so characters such as
    ``<`` or ``&`` typed by the user are shown literally instead of being
    interpreted as markup. The highlighted text is the original
    ``text[start:end]`` slice, so the user's own casing and spacing are kept.

    Args:
        text: The raw input string the entity offsets refer to.
        entities: Iterable of dicts with ``"start"``, ``"end"`` and
            ``"entity_group"`` keys, in left-to-right order.

    Returns:
        An HTML string.
    """
    pieces = []
    cursor = 0
    for ent in entities:
        start, end = ent["start"], ent["end"]
        color = ENTITY_COLORS.get(ent["entity_group"], "lightgray")
        # Plain text between the previous entity and this one.
        pieces.append(html.escape(text[cursor:start]))
        pieces.append(
            f"<mark style='background-color:{color};'>"
            f"{html.escape(text[start:end])}</mark>"
        )
        cursor = end
    # Whatever follows the last entity.
    pieces.append(html.escape(text[cursor:]))
    return "".join(pieces)


def adr_predict(text: str):
    """Classify *text*, explain the prediction, and highlight its entities.

    Args:
        text: Free-form text entered by the user.

    Returns:
        A 3-tuple ``(prob_dict, fig, highlighted)`` where

        * ``prob_dict`` maps each entry of ``class_labels`` to its score as a
          plain ``float``;
        * ``fig`` is the value returned by
          ``shap.plots.text(..., display=False)`` for this input;
        * ``highlighted`` is the input rendered as HTML with the entities
          found by ``ner_pipe`` wrapped in coloured ``<mark>`` tags.
    """
    # Probabilities
    probs = predict_proba([text])[0]
    prob_dict = {label: float(score) for label, score in zip(class_labels, probs)}

    # SHAP
    shap_vals = explainer([text])
    fig = shap.plots.text(shap_vals[0], display=False)

    # NER highlight
    highlighted = _highlight_entities(text, ner_pipe(text))

    return prob_dict, fig, highlighted
# --------- 6) Gradio UI ---------
with gr.Blocks() as demo:
    # NOTE(review): the characters after "**ADR Detector**" look like a
    # mis-encoded emoji; left untouched here -- confirm the intended glyph.
    gr.Markdown("## Welcome to **ADR Detector** πŸͺ")
    gr.Markdown(
        "Predicts how likely your text describes a **severe** vs. **non-severe** adverse reaction. \n"
        "_(Not for medical or diagnostic use.)_"
    )

    txt = gr.Textbox(
        label="Enter Your Text Here:",
        lines=3,
        placeholder="Type a sentence about an adverse reaction…",
    )
    btn = gr.Button("Analyze")

    # NOTE(review): which outputs sit inside this row is a best guess --
    # adjust the nesting if a different layout was intended.
    with gr.Row():
        out_prob = gr.Label(label="Predicted Probabilities")
        # adr_predict passes on the result of
        # shap.plots.text(..., display=False), which is an HTML string rather
        # than a figure object, so it is rendered with an HTML component
        # instead of gr.Plot.
        out_shap = gr.HTML(label="SHAP Explanation")
    out_ner = gr.HTML(label="Biomedical Entities Highlighted")

    # All three outputs are produced by a single adr_predict call.
    btn.click(
        fn=adr_predict,
        inputs=txt,
        outputs=[out_prob, out_shap, out_ner],
    )

    gr.Examples(
        examples=[
            "A 35-year-old male experienced severe headache after taking Aspirin.",
            "A 35-year-old female had minor abdominal pain after Acetaminophen.",
        ],
        inputs=txt,
        outputs=[out_prob, out_shap, out_ner],
        fn=adr_predict,
        cache_examples=False,  # do not pre-compute example outputs at startup
    )

if __name__ == "__main__":
    demo.launch()