Spaces:
Sleeping
Sleeping
File size: 4,052 Bytes
261ea5b e669abf b6e3578 41b32e2 b6e3578 55ecc4c d2b9b3e 41b32e2 b6e3578 d2b9b3e b6e3578 41b32e2 b6e3578 d2b9b3e b6e3578 41b32e2 d2b9b3e 41b32e2 d2b9b3e 41b32e2 b6e3578 41b32e2 b6e3578 d2b9b3e b6e3578 41b32e2 b6e3578 41b32e2 d2b9b3e 41b32e2 d2b9b3e 41b32e2 d2b9b3e 41b32e2 d2b9b3e 41b32e2 d2b9b3e 41b32e2 d2b9b3e b6e3578 d2b9b3e b6e3578 d2b9b3e 41b32e2 b6e3578 7c8cdc0 41b32e2 d2b9b3e 41b32e2 b6e3578 e669abf cde5ee9 d2b9b3e b6e3578 41b32e2 d2b9b3e 41b32e2 b6e3578 d2b9b3e b6e3578 d2b9b3e e669abf b6e3578 41b32e2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
import numpy as np
import torch
import shap
from transformers import (
pipeline,
AutoTokenizer,
AutoModelForSequenceClassification,
AutoModelForTokenClassification
)
import gradio as gr
# ───────── 1) Device setup ─────────
# Prefer the GPU when CUDA is usable; otherwise everything runs on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# ───────── 2) ADR classifier ─────────
# Checkpoint for the ADR classifier; tokenizer and weights are loaded from the
# same name so they stay in sync.
model_name = "paragon-analytics/ADRv1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model = model.to(device)

# return_all_scores=True is kept on purpose: predict_proba reads one score per
# class from every sample, and class_labels is taken from the same per-class
# list, so the full list (not just the top label) is required.
pred_pipeline = pipeline(
    task="text-classification",
    model=model,
    tokenizer=tokenizer,
    return_all_scores=True,
    # 0 selects the first GPU, -1 selects the CPU.
    device=-1 if device.type != "cuda" else 0,
)
def predict_proba(texts):
    """Return the classifier's per-class scores for one text or a batch.

    Args:
        texts: A single string or a sequence of strings. A bare string is
            wrapped into a one-element batch.

    Returns:
        A NumPy array with one row per input text; each row holds the
        ``"score"`` values in the order ``pred_pipeline`` reports them.
    """
    batch = [texts] if isinstance(texts, str) else texts
    scored = pred_pipeline(batch)

    rows = []
    for sample in scored:
        rows.append([entry["score"] for entry in sample])
    return np.array(rows)
def predict_proba_shap(inputs):
    """Adapter used as the model function for the SHAP explainer.

    Each element of *inputs* that is a ``list`` is joined with single spaces
    into one string; every other element is passed through unchanged. The
    resulting batch is scored with ``predict_proba``.

    Args:
        inputs: Iterable of strings and/or lists of string pieces.

    Returns:
        Whatever ``predict_proba`` returns for the normalised batch.
    """
    normalised = []
    for item in inputs:
        if isinstance(item, list):
            normalised.append(" ".join(item))
        else:
            normalised.append(item)
    return predict_proba(normalised)
# ───────── 3) SHAP explainer ─────────
# Text masker built on the classifier's own tokenizer.
masker = shap.maskers.Text(tokenizer)

# One throwaway prediction, used only to read the label names in the order the
# pipeline reports its scores.
_example = pred_pipeline(["test"])[0]
class_labels = [entry["label"] for entry in _example]

# The explainer calls predict_proba_shap on masked variants of the input and
# names its outputs after the classifier's labels.
explainer = shap.Explainer(predict_proba_shap, masker=masker, output_names=class_labels)
# ───────── 4) Biomedical NER ─────────
# Checkpoint for the biomedical NER model; it has its own tokenizer, separate
# from the classifier's.
ner_name = "d4data/biomedical-ner-all"
ner_tokenizer = AutoTokenizer.from_pretrained(ner_name)
ner_model = AutoModelForTokenClassification.from_pretrained(ner_name)
ner_model = ner_model.to(device)

# aggregation_strategy="simple" is what makes each result carry an
# "entity_group" key, which adr_predict reads together with "start"/"end".
ner_pipe = pipeline(
    task="ner",
    model=ner_model,
    tokenizer=ner_tokenizer,
    aggregation_strategy="simple",
    # 0 selects the first GPU, -1 selects the CPU.
    device=-1 if device.type != "cuda" else 0,
)
# CSS background colour per NER entity group. Groups missing from this table
# are given a fallback colour by the code that does the highlighting.
ENTITY_COLORS = {
    "Severity": "red",
    "Sign_symptom": "green",
    "Medication": "lightblue",
    "Age": "yellow",
    "Sex": "yellow",
    "Diagnostic_procedure": "gray",
    "Biological_structure": "silver",
}
# ───────── 5) Prediction + SHAP + NER ─────────
def _highlight_entities(text, ents):
    """Render *text* as HTML with every entity span wrapped in a coloured ``<mark>``.

    Args:
        text: The raw input string the entities were extracted from.
        ents: Iterable of entity dicts carrying ``"start"``, ``"end"`` and
            ``"entity_group"`` keys.

    Returns:
        An HTML string. Everything copied from *text* is HTML-escaped, so
        characters such as ``<`` or ``&`` typed by the user are shown
        literally instead of being interpreted as markup.
    """
    from html import escape  # local import keeps this block self-contained

    pieces = []
    cursor = 0
    for ent in ents:
        start, end = ent["start"], ent["end"]
        # Skip spans that have no character offsets or that begin inside text
        # already emitted; placing them would duplicate or misplace input.
        if start is None or end is None or start < cursor:
            continue
        color = ENTITY_COLORS.get(ent["entity_group"], "lightgray")
        pieces.append(escape(text[cursor:start]))
        # Slice the entity out of the input itself rather than using
        # ent["word"], so the user's exact spelling is what gets displayed.
        pieces.append(
            f"<mark style='background-color:{color};'>{escape(text[start:end])}</mark>"
        )
        cursor = end
    pieces.append(escape(text[cursor:]))
    return "".join(pieces)


def adr_predict(text: str):
    """Classify *text*, explain the prediction with SHAP and highlight entities.

    Args:
        text: Free-text description to analyse.

    Returns:
        A 3-tuple ``(prob_dict, shap_html, highlighted)``:

        * ``prob_dict`` maps each entry of ``class_labels`` to its score as a
          plain ``float``.
        * ``shap_html`` is the value returned by
          ``shap.plots.text(..., display=False)`` for the first explanation.
        * ``highlighted`` is the input rendered as HTML with NER spans wrapped
          in ``<mark>`` tags coloured via ``ENTITY_COLORS``.
    """
    # Probabilities
    probs = predict_proba([text])[0]
    prob_dict = {label: float(score) for label, score in zip(class_labels, probs)}

    # SHAP
    shap_vals = explainer([text])
    shap_html = shap.plots.text(shap_vals[0], display=False)

    # NER highlight
    highlighted = _highlight_entities(text, ner_pipe(text))
    return prob_dict, shap_html, highlighted
# ───────── 6) Gradio UI ─────────
# Build the UI: one text input, an "Analyze" button, and three outputs that
# line up with the 3-tuple returned by adr_predict.
with gr.Blocks() as demo:
    # NOTE(review): the characters after "**ADR Detector**" look like a
    # mis-decoded emoji. They are left unchanged because the intended glyph
    # cannot be determined from this file - restore it from the original.
    gr.Markdown("## Welcome to **ADR Detector** πͺ")
    gr.Markdown(
        "Predicts how likely your text describes a **severe** vs. **non-severe** adverse reaction. \n"
        "_(Not for medical or diagnostic use.)_"
    )

    txt = gr.Textbox(
        label="Enter Your Text Here:",
        lines=3,
        placeholder="Type a sentence about an adverse reaction…",
    )
    btn = gr.Button("Analyze")

    with gr.Row():
        out_prob = gr.Label(label="Predicted Probabilities")
        # adr_predict's second value comes from
        # shap.plots.text(..., display=False), which yields HTML markup rather
        # than a figure object, so it is rendered with gr.HTML instead of
        # gr.Plot.
        out_shap = gr.HTML(label="SHAP Explanation")
    out_ner = gr.HTML(label="Biomedical Entities Highlighted")

    btn.click(
        fn=adr_predict,
        inputs=txt,
        outputs=[out_prob, out_shap, out_ner],
    )

    gr.Examples(
        examples=[
            "A 35-year-old male experienced severe headache after taking Aspirin.",
            "A 35-year-old female had minor abdominal pain after Acetaminophen.",
        ],
        inputs=txt,
        outputs=[out_prob, out_shap, out_ner],
        fn=adr_predict,
        # Examples are computed on demand rather than pre-run at startup.
        cache_examples=False,
    )

if __name__ == "__main__":
    demo.launch()
|