File size: 3,345 Bytes
ce729b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import gradio as gr
from transformers import pipeline
from sentence_transformers import SentenceTransformer, util
import numpy as np

# Hugging Face Hub id of the model used for the text-classification step.
# A freely available, smaller biomedical model keeps inference fast in the Space.
MODEL_NAME = "d4data/biomedical-clinical-trials-bert"

# Medical specialties that predict_specialty() scores and ranks.
CLASSES = [
    "Cardiology",
    "Neurology",
    "Oncology",
    "Pediatrics",
    "Orthopedics",
    "Dermatology",
    "Gastroenterology",
    "Endocrinology",
    "Psychiatry",
    "Pulmonology",
]

# Few-shot reference phrases for each specialty.
# predict_specialty() embeds these and compares them with the user's text by
# cosine similarity. A specialty with no entry here gets a similarity of 0.0,
# so keep one entry for every name in CLASSES.
EXAMPLES = {
    "Cardiology": [
        "chest pain", "shortness of breath", "palpitations",
        "high blood pressure", "irregular heartbeat"
    ],
    "Neurology": [
        "headache", "dizziness", "numbness in limbs",
        "seizures", "memory problems"
    ],
    "Oncology": [
        "unexplained weight loss", "persistent lumps",
        "unusual bleeding", "chronic fatigue", "skin changes"
    ],
    "Pediatrics": [
        "fever in a child", "delayed growth in a child",
        "infant feeding problems", "ear infection in a toddler",
        "childhood developmental delay"
    ],
    "Orthopedics": [
        "joint pain", "back pain", "bone fracture",
        "swollen knee", "limited range of motion"
    ],
    "Dermatology": [
        "skin rash", "itching", "acne",
        "eczema", "suspicious mole"
    ],
    "Gastroenterology": [
        "abdominal pain", "heartburn", "nausea and vomiting",
        "diarrhea", "constipation"
    ],
    "Endocrinology": [
        "excessive thirst", "frequent urination",
        "heat or cold intolerance", "unexplained weight gain",
        "hormonal imbalance"
    ],
    "Psychiatry": [
        "persistent sadness", "anxiety", "panic attacks",
        "insomnia", "hallucinations"
    ],
    "Pulmonology": [
        "chronic cough", "wheezing", "coughing up blood",
        "difficulty breathing", "chest tightness"
    ],
}

# Build both models once at import time so every request reuses them.
# The classifier takes its weights and its tokenizer from the same model id.
classifier = pipeline(task="text-classification", model=MODEL_NAME, tokenizer=MODEL_NAME)

# Sentence encoder used for the semantic-similarity part of the prediction.
embedder = SentenceTransformer("all-MiniLM-L6-v2")

# Embeddings of the example phrases, keyed by the tuple of phrases.
# The phrases do not change between requests, so each list is encoded once
# instead of on every call to predict_specialty().
_EXAMPLE_EMBEDDINGS = {}


def _example_embeddings(examples):
    """Return the embedding tensor for *examples*, encoding it on first use only."""
    key = tuple(examples)
    if key not in _EXAMPLE_EMBEDDINGS:
        _EXAMPLE_EMBEDDINGS[key] = embedder.encode(examples, convert_to_tensor=True)
    return _EXAMPLE_EMBEDDINGS[key]


def predict_specialty(symptoms):
    """
    Predict the most relevant medical specialty based on symptoms.

    Two signals are averaged for every specialty in CLASSES:
    a 1.0/0.0 vote from the text classifier (1.0 only for the label it
    predicts) and the mean cosine similarity between the symptom text and that
    specialty's example phrases in EXAMPLES (0.0 when it has no examples).

    Args:
        symptoms: Free-text description of the symptoms.

    Returns:
        A dict with three keys:
        "Primary Specialty" (the top-ranked specialty name),
        "Confidence" (the top combined score formatted as a percentage string;
        it is a heuristic score, not a calibrated probability) and
        "Alternative Suggestions" (the names ranked second and third).

    Raises:
        gr.Error: If ``symptoms`` is empty or contains only whitespace.
    """
    # An empty description would still yield a "prediction"; refuse it instead
    # so the UI shows a clear message rather than a meaningless specialty.
    if not symptoms or not symptoms.strip():
        raise gr.Error("Please describe your symptoms.")

    # truncation=True keeps long free-text input within the model's maximum
    # sequence length instead of failing inside the model.
    pred = classifier(symptoms, truncation=True)
    predicted_class = pred[0]['label']
    # NOTE(review): the vote below only counts when the classifier's label
    # strings are spelled exactly like the entries in CLASSES — confirm against
    # the model's label mapping.

    # Semantic similarity between the input and each specialty's examples.
    symptom_embedding = embedder.encode(symptoms, convert_to_tensor=True)
    similarities = {
        specialty: util.pytorch_cos_sim(
            symptom_embedding, _example_embeddings(examples)
        ).mean().item()
        for specialty, examples in EXAMPLES.items()
    }

    # Combine predictions (simple average for demo).
    combined_scores = {}
    for specialty in CLASSES:
        class_score = 1.0 if specialty == predicted_class else 0.0
        sim_score = similarities.get(specialty, 0.0)
        combined_scores[specialty] = (class_score + sim_score) / 2

    # Rank by combined score; ties keep the order of CLASSES (sorted is stable).
    ranked = sorted(combined_scores.items(), key=lambda item: item[1], reverse=True)
    top_3 = ranked[:3]

    return {
        "Primary Specialty": top_3[0][0],
        "Confidence": f"{top_3[0][1]*100:.1f}%",
        "Alternative Suggestions": [name for name, _ in top_3[1:]],
    }

# Create Gradio interface
def _predict_for_ui(symptoms):
    """Unpack predict_specialty()'s result dict into one value per output component.

    gr.Interface assigns a function's return values to ``outputs`` by position,
    so the three components below need a 3-tuple, not a single dict keyed by
    label text.
    """
    result = predict_specialty(symptoms)
    return (
        result["Primary Specialty"],
        result["Confidence"],
        result["Alternative Suggestions"],
    )


demo = gr.Interface(
    fn=_predict_for_ui,
    inputs=gr.Textbox(label="Describe your symptoms", placeholder="e.g., chest pain and shortness of breath..."),
    # Order must match the tuple returned by _predict_for_ui.
    outputs=[
        gr.Label(label="Primary Specialty"),
        gr.Textbox(label="Confidence"),
        gr.JSON(label="Alternative Suggestions")
    ],
    examples=[
        ["chest pain and dizziness"],
        ["persistent headaches with nausea"],
        ["unexplained weight loss and fatigue"],
        ["skin rash and itching"]
    ],
    title="Medical Specialty Classifier",
    description="Enter your symptoms to find the most relevant medical specialty. Note: This is for educational purposes only and not a substitute for professional medical advice."
)

demo.launch()