# Page header captured along with the file (not program code):
# SumayyaMalik's picture | Create app.py | ce729b9 verified | raw | history blame | 3.35 kB
import gradio as gr
from transformers import pipeline
from sentence_transformers import SentenceTransformer, util
import numpy as np
# Hugging Face checkpoint handed to the text-classification pipeline.
# A smaller, freely available medical model is used to keep inference fast in the Space.
MODEL_NAME = "d4data/biomedical-clinical-trials-bert"

# Specialties the app can suggest; predict_specialty scores every entry of this list.
CLASSES = [
    "Cardiology",
    "Neurology",
    "Oncology",
    "Pediatrics",
    "Orthopedics",
    "Dermatology",
    "Gastroenterology",
    "Endocrinology",
    "Psychiatry",
    "Pulmonology",
]
# Few-shot reference phrases for each specialty.
# predict_specialty compares the user's text against these phrases; a specialty
# with no entry here always receives a similarity score of 0.0, so every
# specialty offered by the app needs at least one phrase.
# NOTE: the phrases are illustrative demo data, not a clinically validated list.
EXAMPLES = {
    "Cardiology": [
        "chest pain", "shortness of breath", "palpitations",
        "high blood pressure", "irregular heartbeat",
    ],
    "Neurology": [
        "headache", "dizziness", "numbness in limbs",
        "seizures", "memory problems",
    ],
    "Oncology": [
        "unexplained weight loss", "persistent lumps",
        "unusual bleeding", "chronic fatigue", "skin changes",
    ],
    "Pediatrics": [
        "fever in a child", "poor feeding in an infant",
        "delayed growth", "childhood rash", "ear infection in a toddler",
    ],
    "Orthopedics": [
        "joint pain", "back pain", "bone fracture",
        "swollen knee", "limited range of motion",
    ],
    "Dermatology": [
        "skin rash", "itching", "acne",
        "dry scaly patches", "hives",
    ],
    "Gastroenterology": [
        "abdominal pain", "heartburn", "nausea and vomiting",
        "diarrhea", "constipation",
    ],
    "Endocrinology": [
        "excessive thirst", "frequent urination",
        "unexplained weight gain", "heat or cold intolerance",
        "excessive sweating",
    ],
    "Psychiatry": [
        "persistent sadness", "anxiety", "panic attacks",
        "insomnia", "loss of interest in activities",
    ],
    "Pulmonology": [
        "chronic cough", "wheezing", "coughing up blood",
        "difficulty breathing at night", "chest tightness",
    ],
}
# Build both models once at import time; predict_specialty reuses these objects.
# NOTE(review): predict_specialty compares the pipeline's predicted label with
# the names in CLASSES -- confirm this checkpoint actually emits those names.
classifier = pipeline(
    task="text-classification",
    model=MODEL_NAME,
    tokenizer=MODEL_NAME,
)

# Sentence encoder used for the similarity-to-examples part of the score.
embedder = SentenceTransformer("all-MiniLM-L6-v2")
# Embeddings of example phrase lists, keyed by the tuple of phrases, so each
# list is encoded once per process instead of on every request.
_EXAMPLE_EMBEDDING_CACHE = {}


def _embed_examples(examples):
    """Return the embedding tensor for *examples*, encoding it on first use only."""
    key = tuple(examples)
    if key not in _EXAMPLE_EMBEDDING_CACHE:
        _EXAMPLE_EMBEDDING_CACHE[key] = embedder.encode(
            list(key), convert_to_tensor=True
        )
    return _EXAMPLE_EMBEDDING_CACHE[key]


def predict_specialty(symptoms):
    """
    Predict the most relevant medical specialty based on symptoms.

    The score of each specialty in CLASSES is the average of two terms:
    1.0 if the classifier's predicted label equals the specialty (else 0.0),
    and the mean cosine similarity between the input and that specialty's
    EXAMPLES phrases (0.0 when the specialty has no phrases).

    Args:
        symptoms: Free-text description of the symptoms.

    Returns:
        A dict with three keys: "Primary Specialty" (top-scoring specialty),
        "Confidence" (its score formatted as a percentage string) and
        "Alternative Suggestions" (list of the next two specialties).

    Raises:
        gr.Error: If *symptoms* is not a string or contains only whitespace.
    """
    # Reject blank input up front instead of returning an arbitrary prediction.
    if not isinstance(symptoms, str) or not symptoms.strip():
        raise gr.Error("Please describe your symptoms.")

    # truncation=True keeps long free-text input within the model's maximum
    # sequence length instead of failing inside the classifier.
    pred = classifier(symptoms, truncation=True)
    # NOTE(review): this only influences the score when the label string is
    # one of the names in CLASSES -- confirm the checkpoint's label names.
    predicted_class = pred[0]['label']

    # Semantic similarity between the input and each specialty's examples.
    symptom_embedding = embedder.encode(symptoms, convert_to_tensor=True)
    similarities = {}
    for specialty, examples in EXAMPLES.items():
        if not examples:
            # Nothing to compare against; treated as 0.0 via .get() below.
            continue
        example_embeddings = _embed_examples(examples)
        similarities[specialty] = (
            util.pytorch_cos_sim(symptom_embedding, example_embeddings)
            .mean()
            .item()
        )

    # Combine predictions (simple average for demo).
    combined_scores = {}
    for specialty in CLASSES:
        class_score = 1.0 if specialty == predicted_class else 0.0
        sim_score = similarities.get(specialty, 0.0)
        combined_scores[specialty] = (class_score + sim_score) / 2

    # Highest score first; ties keep the order of CLASSES (sorted is stable).
    ranked = sorted(combined_scores.items(), key=lambda item: item[1], reverse=True)
    top_3 = ranked[:3]

    return {
        "Primary Specialty": top_3[0][0],
        "Confidence": f"{top_3[0][1]*100:.1f}%",
        "Alternative Suggestions": [name for name, _ in top_3[1:]],
    }
def _predict_for_ui(symptoms):
    """Return predict_specialty's result as one value per output component.

    predict_specialty returns a dict with string keys, but Gradio interprets a
    dict return value as a {component: value} mapping. With three output
    components declared, the function must instead return three values in the
    same order as ``outputs``.
    """
    result = predict_specialty(symptoms)
    return (
        result["Primary Specialty"],
        result["Confidence"],
        result["Alternative Suggestions"],
    )


# Create Gradio interface
demo = gr.Interface(
    fn=_predict_for_ui,
    inputs=gr.Textbox(
        label="Describe your symptoms",
        placeholder="e.g., chest pain and shortness of breath...",
    ),
    # Order must match the tuple returned by _predict_for_ui.
    outputs=[
        gr.Label(label="Primary Specialty"),
        gr.Textbox(label="Confidence"),
        gr.JSON(label="Alternative Suggestions"),
    ],
    examples=[
        ["chest pain and dizziness"],
        ["persistent headaches with nausea"],
        ["unexplained weight loss and fatigue"],
        ["skin rash and itching"],
    ],
    title="Medical Specialty Classifier",
    description="Enter your symptoms to find the most relevant medical specialty. Note: This is for educational purposes only and not a substitute for professional medical advice.",
)

# Start the server only when run as a script, so importing this module
# (e.g. for testing) does not launch a blocking web server.
if __name__ == "__main__":
    demo.launch()