# Hugging Face Spaces status when this file was captured: Runtime error
import gradio as gr | |
from transformers import pipeline | |
from sentence_transformers import SentenceTransformer, util | |
import numpy as np | |
# Hugging Face Hub id of the text-classification model loaded into `classifier`.
# NOTE(review): confirm this repository exists on the Hub and that its output
# label strings are the specialty names listed in CLASSES; otherwise the
# classifier's vote in predict_specialty never matches any specialty.
MODEL_NAME = "d4data/biomedical-clinical-trials-bert"

# Every specialty the app can rank and suggest.
CLASSES = [
    "Cardiology",
    "Neurology",
    "Oncology",
    "Pediatrics",
    "Orthopedics",
    "Dermatology",
    "Gastroenterology",
    "Endocrinology",
    "Psychiatry",
    "Pulmonology",
]

# Few-shot reference phrases per specialty, compared against the user's text
# by semantic similarity. Specialties missing here contribute no similarity.
# TODO: add example phrases for the remaining specialties in CLASSES.
EXAMPLES = {
    "Cardiology": [
        "chest pain",
        "shortness of breath",
        "palpitations",
        "high blood pressure",
        "irregular heartbeat",
    ],
    "Neurology": [
        "headache",
        "dizziness",
        "numbness in limbs",
        "seizures",
        "memory problems",
    ],
    "Oncology": [
        "unexplained weight loss",
        "persistent lumps",
        "unusual bleeding",
        "chronic fatigue",
        "skin changes",
    ],
}
# Load both models once at import time; predict_specialty reuses them.
# Text-classification pipeline whose model and tokenizer both come from MODEL_NAME.
classifier = pipeline(task="text-classification", model=MODEL_NAME, tokenizer=MODEL_NAME)
# Sentence-embedding model used for the few-shot similarity signal.
embedder = SentenceTransformer(model_name_or_path="all-MiniLM-L6-v2")
# Memoized example-phrase embeddings, keyed by (specialty, phrases).
_EXAMPLE_EMBEDDING_CACHE = {}


def _example_embeddings(specialty, examples):
    """Return the embedding tensor for one specialty's example phrases.

    Encoding the few-shot examples is identical work on every request, so the
    result is cached. The key includes the phrases themselves, so an edit to
    EXAMPLES at runtime is re-encoded rather than served stale.
    """
    key = (specialty, tuple(examples))
    cached = _EXAMPLE_EMBEDDING_CACHE.get(key)
    if cached is None:
        cached = embedder.encode(examples, convert_to_tensor=True)
        _EXAMPLE_EMBEDDING_CACHE[key] = cached
    return cached


def predict_specialty(symptoms):
    """
    Predict the most relevant medical specialty based on symptoms.

    Two signals are combined for every specialty in CLASSES:
    * a classification vote: 1.0 if the pipeline's top label equals the
      specialty name, else 0.0;
    * the mean cosine similarity between the symptom text and that
      specialty's phrases in EXAMPLES (0.0 when it has no phrases).
    A specialty's combined score is the average of the two.

    Args:
        symptoms: Free-text description of the symptoms.

    Returns:
        dict with keys:
            "Primary Specialty": name of the highest-scoring specialty,
            "Confidence": its combined score as a percentage string,
            "Alternative Suggestions": names of the next two specialties.
    """
    # truncation=True keeps long descriptions within the model's maximum
    # input length instead of letting the pipeline raise on them.
    pred = classifier(symptoms, truncation=True)
    predicted_class = pred[0]["label"]
    # NOTE(review): the vote only counts if the model's label strings are
    # exactly the names in CLASSES; generic labels such as "LABEL_0" never
    # match. Confirm against the model's config.id2label.

    # Semantic similarity to the few-shot example phrases.
    symptom_embedding = embedder.encode(symptoms, convert_to_tensor=True)
    similarities = {}
    for specialty, examples in EXAMPLES.items():
        if not examples:
            # No phrases means no meaningful similarity; leave the specialty
            # at the 0.0 default used for specialties absent from EXAMPLES.
            continue
        similarities[specialty] = util.cos_sim(
            symptom_embedding, _example_embeddings(specialty, examples)
        ).mean().item()

    # Simple average of the two signals (demo-grade weighting).
    combined_scores = {
        specialty: (
            (1.0 if specialty == predicted_class else 0.0)
            + similarities.get(specialty, 0.0)
        ) / 2
        for specialty in CLASSES
    }

    # sorted() is stable, so ties keep CLASSES order.
    ranked = sorted(combined_scores.items(), key=lambda item: item[1], reverse=True)
    top_3 = ranked[:3]

    return {
        "Primary Specialty": top_3[0][0],
        "Confidence": f"{top_3[0][1]*100:.1f}%",
        "Alternative Suggestions": [name for name, _ in top_3[1:]],
    }
# Create Gradio interface.
# With several output components, gr.Interface assigns the function's return
# values to `outputs` in order, so the wrapped function must return exactly
# three values. A dict keyed by plain strings is not mapped onto them.


def _predict_for_ui(symptoms):
    """Run predict_specialty and unpack its result for the three UI outputs.

    Args:
        symptoms: Text from the symptoms textbox.

    Returns:
        tuple ordered like the interface's `outputs`: (primary specialty,
        confidence string, list of alternative specialty names).
    """
    result = predict_specialty(symptoms)
    return (
        result["Primary Specialty"],
        result["Confidence"],
        result["Alternative Suggestions"],
    )


demo = gr.Interface(
    fn=_predict_for_ui,
    inputs=gr.Textbox(label="Describe your symptoms", placeholder="e.g., chest pain and shortness of breath..."),
    outputs=[
        gr.Label(label="Primary Specialty"),
        gr.Textbox(label="Confidence"),
        gr.JSON(label="Alternative Suggestions"),
    ],
    examples=[
        ["chest pain and dizziness"],
        ["persistent headaches with nausea"],
        ["unexplained weight loss and fatigue"],
        ["skin rash and itching"],
    ],
    title="Medical Specialty Classifier",
    description="Enter your symptoms to find the most relevant medical specialty. Note: This is for educational purposes only and not a substitute for professional medical advice.",
)
demo.launch()