"""Gradio demo: ADR text classification with SHAP explanations and NER highlights.

On import this script loads two Hugging Face models (a sequence classifier and
a biomedical token classifier), builds a SHAP explainer around the classifier,
and assembles a Gradio Blocks UI that is launched when run as ``__main__``.
"""
import html

import gradio as gr
import numpy as np
import shap
import torch
from transformers import (
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    AutoTokenizer,
    pipeline,
)

# 1) Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Device argument passed to both pipelines: 0 when CUDA is used, otherwise -1.
_PIPELINE_DEVICE = 0 if device.type == "cuda" else -1

# 2) Load ADR classifier model & tokenizer
model_name = "paragon-analytics/ADRv1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)

# 3) Build HF text-classification pipeline
# NOTE(review): `return_all_scores=True` is kept on purpose. The newer
# `top_k=None` spelling reportedly sorts results by score, which would break
# the column alignment between `predict_proba` and `class_labels` - confirm
# against the installed transformers version before migrating.
pred_pipeline = pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    return_all_scores=True,
    device=_PIPELINE_DEVICE,
)


# 4) Base predict_proba: List[str] -> np.ndarray of shape (n_samples, n_classes)
def predict_proba(texts):
    """Return the classifier's per-class scores for one or more texts.

    Args:
        texts: A single string, or a sequence of strings.

    Returns:
        ``np.ndarray`` with one row per input text and one column per class,
        in the order the pipeline reports the labels.
    """
    if isinstance(texts, str):
        texts = [texts]
    # results: List[List[{"label": ..., "score": ...}]]
    results = pred_pipeline(list(texts))
    return np.array([[d["score"] for d in sample] for sample in results])


# 5) SHAP-compatible wrapper: joins token lists back into strings
def predict_proba_shap(inputs):
    """Adapt :func:`predict_proba` to the inputs SHAP hands to the model.

    Args:
        inputs: An iterable whose items are either strings or sequences of
            tokens (list, tuple or NumPy array). Token sequences are joined
            with single spaces.

    Returns:
        The array produced by :func:`predict_proba` for the rebuilt strings.
    """
    texts = [
        " ".join(map(str, x)) if isinstance(x, (list, tuple, np.ndarray)) else str(x)
        for x in inputs
    ]
    return predict_proba(texts)


# 6) Instantiate SHAP explainer with a Text masker
masker = shap.maskers.Text(tokenizer)

# Grab output class labels from a dummy sample. Using the same pipeline as
# `predict_proba` keeps label order and probability columns aligned.
_example = pred_pipeline(["test"])[0]
class_labels = [d["label"] for d in _example]

explainer = shap.Explainer(
    predict_proba_shap,
    masker=masker,
    output_names=class_labels,
)

# 7) Load biomedical NER model & pipeline
ner_model_name = "d4data/biomedical-ner-all"
ner_tokenizer = AutoTokenizer.from_pretrained(ner_model_name)
ner_model = AutoModelForTokenClassification.from_pretrained(ner_model_name).to(device)
ner_pipe = pipeline(
    "ner",
    model=ner_model,
    tokenizer=ner_tokenizer,
    aggregation_strategy="simple",
    device=_PIPELINE_DEVICE,
)

# 8) Mapping for entity highlight colors (entity group -> CSS color name)
ENTITY_COLORS = {
    "Severity": "red",
    "Sign_symptom": "green",
    "Medication": "lightblue",
    "Age": "yellow",
    "Sex": "yellow",
    "Diagnostic_procedure": "gray",
    "Biological_structure": "silver",
}


def _highlight_entities(text, ents):
    """Render *text* as HTML with each entity span wrapped in a colored <mark>.

    Args:
        text: The raw input string the entities were extracted from.
        ents: Iterable of dicts carrying ``start``, ``end`` and
            ``entity_group`` keys, as produced by ``ner_pipe``.

    Returns:
        An HTML string. All text taken from *text* is HTML-escaped; entity
        groups missing from ``ENTITY_COLORS`` fall back to ``lightgray``.
    """
    pieces = []
    last_idx = 0
    for ent in ents:
        start, end = ent.get("start"), ent.get("end")
        # Skip entities without character offsets and any span that overlaps
        # the previous one; slicing with them would duplicate or drop text.
        if start is None or end is None or start < last_idx:
            continue
        color = ENTITY_COLORS.get(ent["entity_group"], "lightgray")
        pieces.append(html.escape(text[last_idx:start]))
        # Slice the original text rather than using ent["word"], which is the
        # tokenizer's reconstruction and may not match what the user typed.
        pieces.append(
            f'<mark style="background-color:{color};">'
            f"{html.escape(text[start:end])}</mark>"
        )
        last_idx = end
    pieces.append(html.escape(text[last_idx:]))
    return "".join(pieces)


# 9) Full predict + explain + NER function
def adr_predict(text: str):
    """Classify *text*, explain the prediction with SHAP and highlight entities.

    Args:
        text: Free-text input from the UI.

    Returns:
        A 3-tuple ``(prob_dict, shap_html, highlighted)``:
        ``prob_dict`` maps each class label to its probability as a float,
        ``shap_html`` is the value returned by ``shap.plots.text`` with
        ``display=False``, and ``highlighted`` is the input rendered as HTML
        with recognised entities marked.

    Raises:
        gr.Error: If *text* is empty or contains only whitespace.
    """
    if not text or not text.strip():
        raise gr.Error("Please enter some text to analyze.")

    # a) Predict probabilities
    probs = predict_proba([text])[0]
    prob_dict = {label: float(p) for label, p in zip(class_labels, probs)}

    # b) SHAP explanation. With display=False the text plot is returned
    #    instead of displayed; it is HTML markup, not a Matplotlib figure, so
    #    the UI shows it in an HTML component.
    shap_values = explainer([text])
    shap_html = shap.plots.text(shap_values[0], display=False)

    # c) NER highlighting
    highlighted = _highlight_entities(text, ner_pipe(text))

    return prob_dict, shap_html, highlighted


# 10) Build Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## Welcome to **ADR Detector** 🪐")
    gr.Markdown(
        "Predicts the likelihood your text describes a **severe** vs. **non-severe** adverse reaction. \n"
        "_(Not for medical or diagnostic use.)_"
    )

    txt = gr.Textbox(
        label="Enter Your Text Here:",
        lines=3,
        placeholder="Type a sentence about an adverse reaction…",
    )
    btn = gr.Button("Analyze")

    with gr.Row():
        label_out = gr.Label(label="Predicted Probabilities")
        # HTML (not Plot): the SHAP text explanation is an HTML string.
        shap_out = gr.HTML(label="SHAP Explanation")
    ner_out = gr.HTML(label="Biomedical Entities Highlighted")

    btn.click(
        fn=adr_predict,
        inputs=txt,
        outputs=[label_out, shap_out, ner_out],
    )

    gr.Examples(
        examples=[
            "A 35-year-old male experienced severe headache after taking Aspirin.",
            "A 35-year-old female had minor abdominal pain after Acetaminophen.",
        ],
        inputs=txt,
        outputs=[label_out, shap_out, ner_out],
        fn=adr_predict,
        cache_examples=True,
    )

if __name__ == "__main__":
    demo.launch()