import gradio as gr import torch import numpy as np from PIL import Image import time import io import subprocess import sys # Install required packages def install_packages(): packages = [ "transformers", "accelerate", "timm", "easyocr" ] for package in packages: try: subprocess.check_call([sys.executable, "-m", "pip", "install", package]) except: print(f"Warning: Could not install {package}") # Install packages at startup install_packages() from transformers import AutoProcessor, AutoModelForImageTextToText, AutoConfig # Global variables for model processor = None model = None config = None ocr_reader = None def load_model(): """Load the Gemma 3n model""" global processor, model, config, ocr_reader try: print("🚀 Loading Gemma 3n model...") GEMMA_PATH = "google/gemma-3n-e2b-it" # Load configuration config = AutoConfig.from_pretrained(GEMMA_PATH, trust_remote_code=True) print("✅ Config loaded") # Load processor processor = AutoProcessor.from_pretrained(GEMMA_PATH, trust_remote_code=True) print("✅ Processor loaded") # Load model model = AutoModelForImageTextToText.from_pretrained( GEMMA_PATH, config=config, torch_dtype="auto", device_map="auto", trust_remote_code=True ) print("✅ Model loaded successfully!") # Set up compilation fix import torch._dynamo torch._dynamo.config.suppress_errors = True # Initialize OCR try: import easyocr ocr_reader = easyocr.Reader(['en'], gpu=False, verbose=False) print("✅ EasyOCR initialized") except Exception as e: print(f"⚠️ EasyOCR not available: {e}") ocr_reader = None return True except Exception as e: print(f"❌ Model loading failed: {e}") return False def generate_soap_note(text): """Generate SOAP note using Gemma 3n""" if model is None or processor is None: return "❌ Model not loaded. Please wait for initialization." soap_prompt = f"""You are a medical AI assistant. Convert the following medical notes into a properly formatted SOAP note. Medical notes: {text} Please format as: S - SUBJECTIVE: (chief complaint, history of present illness, past medical history, medications, allergies) O - OBJECTIVE: (vital signs, physical examination findings) A - ASSESSMENT: (diagnosis/clinical impression) P - PLAN: (treatment plan, follow-up instructions) Generate a complete, professional SOAP note:""" messages = [{ "role": "system", "content": [{"type": "text", "text": "You are an expert medical AI assistant specialized in creating SOAP notes from medical documentation."}] }, { "role": "user", "content": [{"type": "text", "text": soap_prompt}] }] try: inputs = processor.apply_chat_template( messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" ).to(model.device) input_len = inputs["input_ids"].shape[-1] with torch.no_grad(): outputs = model.generate( **inputs, max_new_tokens=400, do_sample=True, temperature=0.1, top_p=0.95, pad_token_id=processor.tokenizer.eos_token_id, disable_compile=True ) response = processor.batch_decode( outputs[:, input_len:], skip_special_tokens=True )[0].strip() return response except Exception as e: return f"❌ SOAP generation failed: {str(e)}" def extract_text_from_image(image): """Extract text using EasyOCR""" if ocr_reader is None: return "❌ OCR not available" try: if hasattr(image, 'convert'): image = image.convert('RGB') img_array = np.array(image) results = ocr_reader.readtext(img_array, detail=0, paragraph=True) if results: return ' '.join(results).strip() else: return "❌ No text detected in image" except Exception as e: return f"❌ OCR failed: {str(e)}" def process_medical_input(image, text): """Main processing function for the Gradio interface""" if image is not None and text.strip(): return "⚠️ Please provide either an image OR text, not both.", "" if image is not None: # Process image print("🔍 Extracting text from image...") extracted_text = extract_text_from_image(image) if extracted_text.startswith('❌'): return extracted_text, "" print("🤖 Generating SOAP note...") soap_note = generate_soap_note(extracted_text) return extracted_text, soap_note elif text.strip(): # Process text directly print("🤖 Generating SOAP note from text...") soap_note = generate_soap_note(text.strip()) return text.strip(), soap_note else: return "❌ Please provide either an image or text input.", "" def create_demo(): """Create the Gradio demo interface""" # Sample text for demonstration sample_text = """Patient: John Smith, 45yo male CC: Chest pain Vitals: BP 140/90, HR 88, RR 16, O2 98%, Temp 98.6F HPI: Patient reports crushing chest pain x 2 hours, radiating to left arm PMH: HTN, DM Type 2 Current Meds: Lisinopril 10mg daily, Metformin 500mg BID PE: Diaphoretic, anxious appearance EKG: ST elevation in leads II, III, aVF""" with gr.Blocks(title="Medical OCR SOAP Generator", theme=gr.themes.Soft()) as demo: gr.HTML("""
Option A (OCR Demo): Download "docs-note-to-upload.jpg" from Files tab above, then upload it below
Option B (Text Demo): Click "Try Sample Medical Text" button for instant text-to-SOAP demo
⚠️ Note: First generation takes ~2-3 minutes as model loads. Subsequent ones are faster.