# Spaces: Sleeping
import html

import gradio as gr
import numpy as np
import shap
import torch
from transformers import (
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    AutoTokenizer,
    pipeline,
)
# 1) Device setup: use CUDA when torch reports it as available, else the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# 2) Load ADR classifier model & tokenizer
model_name = "paragon-analytics/ADRv1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model = model.to(device)

# 3) Build HF text-classification pipeline
# return_all_scores=True makes the pipeline emit one {"label", "score"} dict
# per class for every input instead of only the top prediction.
pred_pipeline = pipeline(
    task="text-classification",
    model=model,
    tokenizer=tokenizer,
    device=-1 if device.type != "cuda" else 0,
    return_all_scores=True,
)
# 4) Base predict_proba: List[str] → np.ndarray of shape (n_samples, n_classes)
def predict_proba(texts):
    """Return the per-class scores of the ADR classifier for one or more texts.

    Args:
        texts: A single string, or an iterable of strings (a list or a NumPy
            array of strings both work).

    Returns:
        A float ``np.ndarray`` of shape ``(n_samples, n_classes)``. Columns
        follow the order in which ``pred_pipeline`` lists the labels for each
        sample. Empty input yields an array of shape ``(0, n_classes)``.
    """
    if isinstance(texts, str):
        texts = [texts]
    # Normalise to a plain list of built-in str so that array inputs and
    # np.str_ items are handled the same way as ordinary lists of strings.
    texts = [str(t) for t in texts]
    if not texts:
        # Keep the result 2-D even when there is nothing to score.
        return np.empty((0, pred_pipeline.model.config.num_labels))
    # truncation=True is forwarded to the tokenizer so that inputs longer than
    # the model's maximum sequence length are cut instead of raising.
    results = pred_pipeline(texts, truncation=True)
    # results: List[List[{"label":…, "score":…}]]
    return np.array(
        [[entry["score"] for entry in sample] for sample in results],
        dtype=float,
    )
# 5) SHAP-compatible wrapper: joins token lists back into strings
def predict_proba_shap(inputs):
    """Adapt ``predict_proba`` to inputs that may arrive as token lists.

    Args:
        inputs: An iterable whose items are either strings or lists of string
            tokens.

    Returns:
        Whatever ``predict_proba`` returns for the resulting list of strings.
    """

    def _as_text(item):
        # A list item is rebuilt into one space-separated string; anything
        # else is passed through untouched.
        if isinstance(item, list):
            return " ".join(item)
        return item

    return predict_proba(list(map(_as_text, inputs)))
# 6) Instantiate SHAP explainer with a Text masker
# Run the classifier once on a dummy sample to read the class labels in the
# order the pipeline reports them.
_example = pred_pipeline(["test"])[0]
class_labels = [entry["label"] for entry in _example]

# The masker reuses the classifier's tokenizer.
masker = shap.maskers.Text(tokenizer)
explainer = shap.Explainer(
    predict_proba_shap,
    masker=masker,
    output_names=class_labels,
)
# 7) Load biomedical NER model & pipeline
ner_model_name = "d4data/biomedical-ner-all"
ner_tokenizer = AutoTokenizer.from_pretrained(ner_model_name)
ner_model = AutoModelForTokenClassification.from_pretrained(ner_model_name)
ner_model = ner_model.to(device)

# aggregation_strategy="simple" is what makes each result carry an
# "entity_group" key, which adr_predict reads when picking a highlight colour.
ner_pipe = pipeline(
    task="ner",
    model=ner_model,
    tokenizer=ner_tokenizer,
    device=-1 if device.type != "cuda" else 0,
    aggregation_strategy="simple",
)
# 8) Mapping for entity highlight colors
# Keys are NER entity-group names; values are CSS colour names used as the
# <mark> background. Groups that are not listed fall back to a default colour
# chosen by the caller.
ENTITY_COLORS = {
    # Clinical findings
    "Severity": "red",
    "Sign_symptom": "green",
    # Treatment
    "Medication": "lightblue",
    # Patient demographics share one colour
    "Age": "yellow",
    "Sex": "yellow",
    # Procedures and anatomy
    "Diagnostic_procedure": "gray",
    "Biological_structure": "silver",
}
# 9) Full predict + explain + NER function
def _highlight_entities(text, entities):
    """Return *text* as HTML with each NER entity wrapped in a coloured <mark>."""
    pieces = []
    cursor = 0
    for ent in entities:
        start, end = ent.get("start"), ent.get("end")
        # Skip entities without character offsets, and entities that start
        # inside a span already emitted (would duplicate text in the output).
        if start is None or end is None or start < cursor:
            continue
        color = ENTITY_COLORS.get(ent["entity_group"], "lightgray")
        # Escape everything that comes from the user so the text cannot inject
        # markup, and show the original span rather than the tokenizer's
        # reconstructed "word" so casing and spacing are preserved.
        pieces.append(html.escape(text[cursor:start]))
        pieces.append(
            f"<mark style='background-color:{color};'>"
            f"{html.escape(text[start:end])}</mark>"
        )
        cursor = end
    pieces.append(html.escape(text[cursor:]))
    return "".join(pieces)


def adr_predict(text: str):
    """Classify *text*, explain the prediction with SHAP, and highlight entities.

    Args:
        text: Free-form text to analyse.

    Returns:
        A 3-tuple ``(prob_dict, shap_html, highlighted)``:

        * ``prob_dict`` maps every label in ``class_labels`` to its score as a
          Python float.
        * ``shap_html`` is the value returned by
          ``shap.plots.text(..., display=False)``, i.e. the explanation
          rendered as an HTML string (not a Matplotlib figure).
        * ``highlighted`` is *text* as HTML-escaped markup in which each entity
          found by ``ner_pipe`` is wrapped in a coloured ``<mark>`` element.
    """
    # a) Predict probabilities
    probs = predict_proba([text])[0]
    prob_dict = {label: float(probs[i]) for i, label in enumerate(class_labels)}

    # b) SHAP explanation → HTML string
    shap_values = explainer([text])
    shap_html = shap.plots.text(shap_values[0], display=False)

    # c) NER highlighting
    highlighted = _highlight_entities(text, ner_pipe(text))
    return prob_dict, shap_html, highlighted
# 10) Build Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## Welcome to **ADR Detector** 🪐")
    gr.Markdown(
        "Predicts the likelihood your text describes a **severe** vs. **non-severe** adverse reaction. \n"
        "_(Not for medical or diagnostic use.)_"
    )
    txt = gr.Textbox(
        label="Enter Your Text Here:",
        lines=3,
        placeholder="Type a sentence about an adverse reaction…",
    )
    btn = gr.Button("Analyze")
    with gr.Row():
        label_out = gr.Label(label="Predicted Probabilities")
        # adr_predict passes on the result of shap.plots.text(..., display=False),
        # which is an HTML string, so it is shown in an HTML component rather
        # than a Plot component.
        shap_out = gr.HTML(label="SHAP Explanation")
    ner_out = gr.HTML(label="Biomedical Entities Highlighted")

    # Run the full analysis when the button is pressed.
    btn.click(
        fn=adr_predict,
        inputs=txt,
        outputs=[label_out, shap_out, ner_out],
    )
    # Clickable sample sentences; cache_examples=True stores their outputs so
    # they do not have to be recomputed on every click.
    gr.Examples(
        examples=[
            "A 35-year-old male experienced severe headache after taking Aspirin.",
            "A 35-year-old female had minor abdominal pain after Acetaminophen.",
        ],
        inputs=txt,
        outputs=[label_out, shap_out, ner_out],
        fn=adr_predict,
        cache_examples=True,
    )

if __name__ == "__main__":
    demo.launch()