import gradio as gr
import torch
import numpy as np
from PIL import Image
import time
import io
import subprocess
import sys


# Install required packages
def install_packages():
    packages = [
        "transformers",
        "accelerate",
        "timm",
        "easyocr",
    ]
    for package in packages:
        try:
            subprocess.check_call([sys.executable, "-m", "pip", "install", package])
        except subprocess.CalledProcessError:
            print(f"Warning: Could not install {package}")


# Install packages at startup
install_packages()

from transformers import AutoProcessor, AutoModelForImageTextToText, AutoConfig

# Global variables for the model
processor = None
model = None
config = None
ocr_reader = None


def load_model():
    """Load the Gemma 3n model."""
    global processor, model, config, ocr_reader
    try:
        print("🚀 Loading Gemma 3n model...")
        GEMMA_PATH = "google/gemma-3n-e2b-it"

        # Load configuration
        config = AutoConfig.from_pretrained(GEMMA_PATH, trust_remote_code=True)
        print("✅ Config loaded")

        # Load processor
        processor = AutoProcessor.from_pretrained(GEMMA_PATH, trust_remote_code=True)
        print("✅ Processor loaded")

        # Load model
        model = AutoModelForImageTextToText.from_pretrained(
            GEMMA_PATH,
            config=config,
            torch_dtype="auto",
            device_map="auto",
            trust_remote_code=True,
        )
        print("✅ Model loaded successfully!")

        # Suppress torch.compile errors so generation falls back to eager mode
        import torch._dynamo
        torch._dynamo.config.suppress_errors = True

        # Initialize OCR
        try:
            import easyocr
            ocr_reader = easyocr.Reader(['en'], gpu=False, verbose=False)
            print("✅ EasyOCR initialized")
        except Exception as e:
            print(f"⚠️ EasyOCR not available: {e}")
            ocr_reader = None

        return True
    except Exception as e:
        print(f"❌ Model loading failed: {e}")
        return False


def generate_soap_note(text):
    """Generate a SOAP note using Gemma 3n."""
    if model is None or processor is None:
        return "❌ Model not loaded. Please wait for initialization."

    soap_prompt = f"""You are a medical AI assistant. Convert the following medical notes into a properly formatted SOAP note.
Medical notes: {text}

Please format as:
S - SUBJECTIVE: (chief complaint, history of present illness, past medical history, medications, allergies)
O - OBJECTIVE: (vital signs, physical examination findings)
A - ASSESSMENT: (diagnosis/clinical impression)
P - PLAN: (treatment plan, follow-up instructions)

Generate a complete, professional SOAP note:"""

    messages = [
        {
            "role": "system",
            "content": [{"type": "text", "text": "You are an expert medical AI assistant specialized in creating SOAP notes from medical documentation."}],
        },
        {
            "role": "user",
            "content": [{"type": "text", "text": soap_prompt}],
        },
    ]

    try:
        inputs = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(model.device)
        input_len = inputs["input_ids"].shape[-1]

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=400,
                do_sample=True,
                temperature=0.1,
                top_p=0.95,
                pad_token_id=processor.tokenizer.eos_token_id,
                disable_compile=True,
            )

        # Decode only the newly generated tokens
        response = processor.batch_decode(
            outputs[:, input_len:],
            skip_special_tokens=True,
        )[0].strip()
        return response
    except Exception as e:
        return f"❌ SOAP generation failed: {str(e)}"


def extract_text_from_image(image):
    """Extract text using EasyOCR."""
    if ocr_reader is None:
        return "❌ OCR not available"
    try:
        if hasattr(image, 'convert'):
            image = image.convert('RGB')
        img_array = np.array(image)
        results = ocr_reader.readtext(img_array, detail=0, paragraph=True)
        if results:
            return ' '.join(results).strip()
        else:
            return "❌ No text detected in image"
    except Exception as e:
        return f"❌ OCR failed: {str(e)}"


def process_medical_input(image, text):
    """Main processing function for the Gradio interface."""
    if image is not None and text.strip():
        return "⚠️ Please provide either an image OR text, not both.", ""

    if image is not None:
        # Process image
        print("🔍 Extracting text from image...")
        extracted_text = extract_text_from_image(image)
        if extracted_text.startswith('❌'):
            return extracted_text, ""
        print("🤖 Generating SOAP note...")
        soap_note = generate_soap_note(extracted_text)
        return extracted_text, soap_note
    elif text.strip():
        # Process text directly
        print("🤖 Generating SOAP note from text...")
        soap_note = generate_soap_note(text.strip())
        return text.strip(), soap_note
    else:
        return "❌ Please provide either an image or text input.", ""


def create_demo():
    """Create the Gradio demo interface."""
    # Sample text for demonstration
    sample_text = """Patient: John Smith, 45yo male
CC: Chest pain
Vitals: BP 140/90, HR 88, RR 16, O2 98%, Temp 98.6F
HPI: Patient reports crushing chest pain x 2 hours, radiating to left arm
PMH: HTN, DM Type 2
Current Meds: Lisinopril 10mg daily, Metformin 500mg BID
PE: Diaphoretic, anxious appearance
EKG: ST elevation in leads II, III, aVF"""

    with gr.Blocks(title="Medical OCR SOAP Generator", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
# 🏥 Medical OCR SOAP Generator
### Powered by Gemma 3n - Convert handwritten medical notes to professional SOAP format

**Instructions:**
- **Option 1:** Upload an image of handwritten medical notes
- **Option 2:** Enter medical text directly
- The system will generate a properly formatted SOAP note

⚠️ **Note:** The first generation may take ~60-90 seconds while the model loads.
""")

        with gr.Row():
            with gr.Column():
                image_input = gr.Image(
                    type="pil",
                    label="📷 Upload Medical Image",
                    height=300,
                )
                text_input = gr.Textbox(
                    label="📝 Or Enter Medical Text",
                    placeholder=sample_text,
                    lines=8,
                    max_lines=15,
                )
                submit_btn = gr.Button(
                    "Generate SOAP Note",
                    variant="primary",
                    size="lg",
                )
            with gr.Column():
                extracted_output = gr.Textbox(
                    label="📋 Extracted/Input Text",
                    lines=6,
                    max_lines=10,
                )
                soap_output = gr.Textbox(
                    label="🏥 Generated SOAP Note",
                    lines=12,
                    max_lines=20,
                )

        # Example section
        gr.Markdown("### 📋 Quick Test Example")
        example_btn = gr.Button("Try Sample Medical Text", variant="secondary")

        def load_example():
            return sample_text, None

        example_btn.click(
            load_example,
            outputs=[text_input, image_input],
        )

        # Process function
        submit_btn.click(
            process_medical_input,
            inputs=[image_input, text_input],
            outputs=[extracted_output, soap_output],
        )

        gr.Markdown("""
---
**About:** This application uses Google's Gemma 3n model for medical text understanding and EasyOCR for handwriting recognition. All processing is done locally for HIPAA compliance.

**Competition Entry:** Medical AI Innovation Challenge 2024
""")

    return demo


# Initialize the application
if __name__ == "__main__":
    print("🚀 Starting Medical OCR SOAP Generator...")

    # Load model
    model_loaded = load_model()

    if model_loaded:
        print("✅ All systems ready!")
        demo = create_demo()
        demo.launch(
            share=True,
            server_name="0.0.0.0",
            server_port=7860,
        )
    else:
        print("❌ Failed to load model. Creating fallback demo...")

        # Gradio passes each input component's value as a positional argument,
        # so the fallback callback must accept both inputs even though it ignores them.
        def fallback_demo(image, text):
            return "❌ Model loading failed. Please check the logs.", "❌ Model not available."

        demo = gr.Interface(
            fn=fallback_demo,
            inputs=[
                gr.Image(type="pil", label="Upload Medical Image"),
                gr.Textbox(label="Enter Medical Text", lines=5),
            ],
            outputs=[
                gr.Textbox(label="Status"),
                gr.Textbox(label="Error Message"),
            ],
            title="❌ Medical OCR - Model Loading Failed",
        )
        demo.launch(share=True)