import gradio as gr import torch import numpy as np from PIL import Image import time import io import subprocess import sys import cv2 # Install required packages def install_packages(): packages = [ "transformers", "accelerate", "timm", "easyocr", "opencv-python" ] for package in packages: try: subprocess.check_call([sys.executable, "-m", "pip", "install", package]) except: print(f"Warning: Could not install {package}") # Install packages at startup install_packages() from transformers import AutoProcessor, AutoModelForImageTextToText, AutoConfig # Global variables for model processor = None model = None config = None ocr_reader = None def load_model(): """Load the Gemma 3n model""" global processor, model, config, ocr_reader try: print("🚀 Loading Gemma 3n model...") GEMMA_PATH = "google/gemma-3n-e2b-it" # Load configuration config = AutoConfig.from_pretrained(GEMMA_PATH, trust_remote_code=True) print("✅ Config loaded") # Load processor processor = AutoProcessor.from_pretrained(GEMMA_PATH, trust_remote_code=True) print("✅ Processor loaded") # Load model model = AutoModelForImageTextToText.from_pretrained( GEMMA_PATH, config=config, torch_dtype="auto", device_map="auto", trust_remote_code=True ) print("✅ Model loaded successfully!") # Set up compilation fix import torch._dynamo torch._dynamo.config.suppress_errors = True # Initialize OCR try: import easyocr ocr_reader = easyocr.Reader(['en'], gpu=False, verbose=False) print("✅ EasyOCR initialized") except Exception as e: print(f"⚠️ EasyOCR not available: {e}") ocr_reader = None return True except Exception as e: print(f"❌ Model loading failed: {e}") return False def generate_soap_note(text): """Generate SOAP note using Gemma 3n""" if model is None or processor is None: return "❌ Model not loaded. Please wait for initialization." soap_prompt = f"""You are a medical AI assistant. Convert the following medical notes into a properly formatted SOAP note. Medical notes: {text} Please format as: S - SUBJECTIVE: (chief complaint, history of present illness, past medical history, medications, allergies) O - OBJECTIVE: (vital signs, physical examination findings) A - ASSESSMENT: (diagnosis/clinical impression) P - PLAN: (treatment plan, follow-up instructions) Generate a complete, professional SOAP note:""" messages = [{ "role": "system", "content": [{"type": "text", "text": "You are an expert medical AI assistant specialized in creating SOAP notes from medical documentation."}] }, { "role": "user", "content": [{"type": "text", "text": soap_prompt}] }] try: inputs = processor.apply_chat_template( messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" ).to(model.device) input_len = inputs["input_ids"].shape[-1] with torch.no_grad(): outputs = model.generate( **inputs, max_new_tokens=400, do_sample=True, temperature=0.1, top_p=0.95, pad_token_id=processor.tokenizer.eos_token_id, disable_compile=True ) response = processor.batch_decode( outputs[:, input_len:], skip_special_tokens=True )[0].strip() return response except Exception as e: return f"❌ SOAP generation failed: {str(e)}" def preprocess_image_for_ocr(image): """Preprocess image for better OCR results using CLAHE""" try: if hasattr(image, 'convert'): image = image.convert('RGB') img_array = np.array(image) # Convert to grayscale if len(img_array.shape) == 3: gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY) else: gray = img_array # Resize if too small height, width = gray.shape if height < 300 or width < 300: scale = max(300/height, 300/width) new_height = int(height * scale) new_width = int(width * scale) gray = cv2.resize(gray, (new_width, new_height), interpolation=cv2.INTER_CUBIC) # Enhance image with CLAHE gray = cv2.medianBlur(gray, 3) clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8)) gray = clahe.apply(gray) _, gray = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) return gray except Exception as e: print(f"⚠️ Image preprocessing failed: {e}") # Fallback to original image if preprocessing fails return np.array(image) def extract_text_from_image(image): """Extract text using EasyOCR with CLAHE preprocessing""" if ocr_reader is None: return "❌ OCR not available" try: # Apply CLAHE preprocessing for better OCR processed_img = preprocess_image_for_ocr(image) results = ocr_reader.readtext(processed_img, detail=0, paragraph=True) if results: return ' '.join(results).strip() else: return "❌ No text detected in image" except Exception as e: return f"❌ OCR failed: {str(e)}" def process_medical_input(image, text): """Main processing function for the Gradio interface""" if image is not None and text.strip(): return "⚠️ Please provide either an image OR text, not both.", "" if image is not None: # Process image print("🔍 Extracting text from image...") extracted_text = extract_text_from_image(image) if extracted_text.startswith('❌'): return extracted_text, "" print("🤖 Generating SOAP note...") soap_note = generate_soap_note(extracted_text) return extracted_text, soap_note elif text.strip(): # Process text directly print("🤖 Generating SOAP note from text...") soap_note = generate_soap_note(text.strip()) return text.strip(), soap_note else: return "❌ Please provide either an image or text input.", "" def create_demo(): """Create the Gradio demo interface""" # Sample text for demonstration sample_text = """Patient: John Smith, 45yo male CC: Chest pain Vitals: BP 140/90, HR 88, RR 16, O2 98%, Temp 98.6F HPI: Patient reports crushing chest pain x 2 hours, radiating to left arm PMH: HTN, DM Type 2 Current Meds: Lisinopril 10mg daily, Metformin 500mg BID PE: Diaphoretic, anxious appearance EKG: ST elevation in leads II, III, aVF""" with gr.Blocks(title="Medical OCR SOAP Generator", theme=gr.themes.Soft()) as demo: gr.HTML("""
👆 Download "docs-note-to-upload.jpg" from the Files tab above, then upload it below
OR click "Try Sample Medical Text" button for instant text demo
⚠️ Note: First generation takes ~60-90 seconds as model loads. Subsequent ones are faster.