import gradio as gr
import numpy as np
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline

# Load a freely available medical model from Hugging Face.
# We'll use a smaller model for faster inference in the Space.
MODEL_NAME = "d4data/biomedical-clinical-trials-bert"

# Specialties that can be suggested to the user.
CLASSES = [
    "Cardiology",
    "Neurology",
    "Oncology",
    "Pediatrics",
    "Orthopedics",
    "Dermatology",
    "Gastroenterology",
    "Endocrinology",
    "Psychiatry",
    "Pulmonology",
]

# Few-shot example symptoms for each specialty. Specialties without an entry
# here get a similarity score of 0.0 in predict_specialty().
EXAMPLES = {
    "Cardiology": [
        "chest pain",
        "shortness of breath",
        "palpitations",
        "high blood pressure",
        "irregular heartbeat",
    ],
    "Neurology": [
        "headache",
        "dizziness",
        "numbness in limbs",
        "seizures",
        "memory problems",
    ],
    "Oncology": [
        "unexplained weight loss",
        "persistent lumps",
        "unusual bleeding",
        "chronic fatigue",
        "skin changes",
    ],
    # Add more examples for other specialties...
}

# Initialize models once at import time so every request reuses them.
classifier = pipeline(
    "text-classification",
    model=MODEL_NAME,
    tokenizer=MODEL_NAME,
)
embedder = SentenceTransformer('all-MiniLM-L6-v2')

# The example symptoms never change, so embed them once here instead of
# re-encoding every example list on each prediction request.
EXAMPLE_EMBEDDINGS = {
    specialty: embedder.encode(examples, convert_to_tensor=True)
    for specialty, examples in EXAMPLES.items()
}


def predict_specialty(symptoms: str) -> tuple:
    """Predict the most relevant medical specialty for the given symptoms.

    Each specialty in ``CLASSES`` is scored as the average of two terms:
    a classification term (1.0 if the classifier's predicted label equals
    the specialty name, otherwise 0.0) and the mean cosine similarity
    between the input and that specialty's example symptoms.

    Args:
        symptoms: Free-text symptom description entered by the user.

    Returns:
        A 3-tuple matching the three Gradio output components, in order:
        the top-scoring specialty name, its combined score formatted as a
        percentage string (e.g. ``"42.0%"``), and a list with the names of
        the next two specialties.

    Raises:
        gr.Error: If ``symptoms`` is empty or contains only whitespace.
    """
    if not symptoms or not symptoms.strip():
        # Show a readable message in the UI instead of running the models
        # on empty input.
        raise gr.Error("Please describe your symptoms.")

    # Classification term.
    # NOTE(review): this only contributes when the model's label string is
    # one of the names in CLASSES -- confirm against the checkpoint's
    # id2label mapping; otherwise the score is driven by similarity alone.
    predicted_class = classifier(symptoms)[0]['label']

    # Similarity term: mean cosine similarity to each specialty's examples.
    symptom_embedding = embedder.encode(symptoms, convert_to_tensor=True)
    similarities = {
        specialty: util.pytorch_cos_sim(symptom_embedding, example_embeddings)
        .mean()
        .item()
        for specialty, example_embeddings in EXAMPLE_EMBEDDINGS.items()
    }

    # Combine predictions (simple average for demo).
    combined_scores = {}
    for specialty in CLASSES:
        class_score = 1.0 if specialty == predicted_class else 0.0
        sim_score = similarities.get(specialty, 0.0)
        combined_scores[specialty] = (class_score + sim_score) / 2

    # Rank all specialties and keep the best three.
    ranked = sorted(
        combined_scores.items(),
        key=lambda item: item[1],
        reverse=True,
    )
    (primary, primary_score), *alternatives = ranked[:3]

    # Return one value per output component; a dict keyed by plain strings
    # is not a valid return value for an interface with several outputs.
    return (
        primary,
        f"{primary_score * 100:.1f}%",
        [name for name, _ in alternatives],
    )


# Create Gradio interface
demo = gr.Interface(
    fn=predict_specialty,
    inputs=gr.Textbox(
        label="Describe your symptoms",
        placeholder="e.g., chest pain and shortness of breath...",
    ),
    outputs=[
        gr.Label(label="Primary Specialty"),
        gr.Textbox(label="Confidence"),
        gr.JSON(label="Alternative Suggestions"),
    ],
    examples=[
        ["chest pain and dizziness"],
        ["persistent headaches with nausea"],
        ["unexplained weight loss and fatigue"],
        ["skin rash and itching"],
    ],
    title="Medical Specialty Classifier",
    description="Enter your symptoms to find the most relevant medical specialty. Note: This is for educational purposes only and not a substitute for professional medical advice.",
)

demo.launch()