# Hugging Face Space status captured with this listing: Sleeping
# -*- coding: utf-8 -*-
# 🏥 Gemma 3N SOAP Note Generator with Unsloth
# Optimized for offline medical documentation
import base64
import io
import os
import re
from datetime import datetime

import cv2
import easyocr
import gradio as gr
import numpy as np
import psutil
import torch
from PIL import Image, ImageDraw, ImageFont
# Unsloth is an optional dependency: UNSLOTH_AVAILABLE records whether the
# optimized Gemma 3n loader (FastModel) could be imported.
try:
    from unsloth import FastModel
except ImportError:
    print("β Unsloth not available. Install with: pip install unsloth")
    UNSLOTH_AVAILABLE = False
else:
    print("β Unsloth imported successfully")
    UNSLOTH_AVAILABLE = True
# Device setup
def setup_device():
    """Choose the torch device and print a short hardware summary.

    Returns:
        str: ``"cuda"`` when ``torch.cuda.is_available()`` is true,
        otherwise ``"cpu"``.
    """
    has_cuda = torch.cuda.is_available()
    device = "cuda" if has_cuda else "cpu"
    print(f"π₯οΈ Using device: {device}")

    if not has_cuda:
        print("β οΈ Running on CPU - will be slower but works offline")
        return device

    # Report the first CUDA device (index 0) only.
    print(f"π GPU: {torch.cuda.get_device_name(0)}")
    total_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"πΎ GPU Memory: {total_bytes / 1e9:.1f} GB")
    return device
# Load Unsloth Gemma 3n model
def load_unsloth_gemma_model(device):
    """Load the 4-bit Unsloth Gemma 3n checkpoint, or the fallback model.

    Args:
        device: Accepted for interface compatibility; this function does
            not read it.

    Returns:
        The ``(model, tokenizer)`` pair returned by
        ``FastModel.from_pretrained``. When Unsloth could not be imported,
        or loading raises, the result of ``load_fallback_model()`` instead.
    """
    if not UNSLOTH_AVAILABLE:
        print("β Unsloth not available. Using fallback method.")
        return load_fallback_model()

    # Use the 4-bit quantized model for efficiency
    model_name = "unsloth/gemma-3n-E4B-it-unsloth-bnb-4bit"
    load_options = {
        "model_name": model_name,
        "dtype": None,  # Auto-detect
        "max_seq_length": 1024,  # Good for medical notes
        "load_in_4bit": True,  # 4-bit quantization for efficiency
        "full_finetuning": False,
    }

    try:
        print("π‘ Loading Unsloth-optimized Gemma 3n model...")
        print(f"π§ Loading model: {model_name}")
        # Load with Unsloth optimizations
        model, tokenizer = FastModel.from_pretrained(**load_options)
        for status_line in (
            "β Unsloth Gemma 3n model loaded successfully!",
            f"π Model: {model_name}",
            "πΎ Memory optimized with 4-bit quantization",
            "π― Ready for medical SOAP note generation!",
        ):
            print(status_line)
    except Exception as e:
        print(f"β Error loading Unsloth model: {e}")
        print("π‘ Trying fallback model...")
        return load_fallback_model()
    return model, tokenizer
def load_fallback_model():
    """Load ``microsoft/DialoGPT-medium`` through plain transformers.

    Returns:
        ``(model, tokenizer)`` on success, or ``(None, None)`` when the
        import or either ``from_pretrained`` call raises.
    """
    try:
        from transformers import AutoTokenizer, AutoModelForCausalLM

        print("π Loading fallback model...")
        checkpoint = "microsoft/DialoGPT-medium"

        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        # Reuse EOS as the padding token when the tokenizer defines none.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Half precision only when a CUDA device is available.
        if torch.cuda.is_available():
            weights_dtype = torch.float16
        else:
            weights_dtype = torch.float32

        model = AutoModelForCausalLM.from_pretrained(
            checkpoint,
            torch_dtype=weights_dtype,
            low_cpu_mem_usage=True,
        )
        print("β Fallback model loaded!")
    except Exception as e:
        print(f"β Fallback model also failed: {e}")
        return None, None
    return model, tokenizer
# Enhanced SOAP Note Generation with Gemma 3n
def generate_soap_note_gemma(doctor_notes, model=None, tokenizer=None, include_timestamp=True):
    """Generate a SOAP note from unstructured notes with the loaded model.

    Args:
        doctor_notes: Free-text clinical notes. ``None``, empty, or
            whitespace-only input is rejected with a user-facing message.
        model: Object exposing ``generate(**inputs, ...)``. When ``None``,
            the template generator is used instead.
        tokenizer: Tokenizer paired with ``model``; must be callable on the
            prompt and provide ``decode`` and ``eos_token_id``. When
            ``None``, the template generator is used instead.
        include_timestamp: When true, prepend a header block containing the
            current local time.

    Returns:
        str: The SOAP note text (optionally preceded by the header), the
        template-based note when no model is available or generation
        raises, or an error message for empty input.
    """
    if not doctor_notes or not doctor_notes.strip():
        return "β Please enter some medical notes to process."
    if model is None or tokenizer is None:
        return generate_template_soap(doctor_notes, include_timestamp)

    # Medical-specific prompt for Gemma 3n. It ends with "SUBJECTIVE:" so the
    # model continues directly inside the first SOAP section.
    prompt = f"""<bos><start_of_turn>user
You are a medical AI assistant specialized in creating SOAP notes. Convert the following unstructured medical notes into a professional SOAP note format.
Medical Notes:
{doctor_notes}
Please create a structured SOAP note with these sections:
- SUBJECTIVE: Patient's reported symptoms, complaints, and relevant history
- OBJECTIVE: Physical examination findings, vital signs, and observable data
- ASSESSMENT: Clinical diagnosis, differential diagnosis, and medical reasoning
- PLAN: Treatment recommendations, medications, tests, and follow-up care
<end_of_turn>
<start_of_turn>model
SOAP NOTE:
SUBJECTIVE:"""

    try:
        # Tokenize input
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True,
        )

        # The prompt tensors must live on the same device as the model;
        # otherwise generation fails on a GPU-resident model and this
        # function would always fall through to the template below.
        target_device = getattr(model, "device", None)
        if target_device is not None:
            inputs = {
                name: tensor.to(target_device) if hasattr(tensor, "to") else tensor
                for name, tensor in inputs.items()
            }
        prompt_length = inputs["input_ids"].shape[-1]

        # Generate with optimized settings for medical text
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=400,
                temperature=0.2,  # Lower temperature for medical precision
                top_p=0.9,
                do_sample=True,
                repetition_penalty=1.1,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens. Searching the full decoded
        # text for "SOAP NOTE:" breaks when the notes contain that marker or
        # when truncation removed it, and slicing by len(prompt) is off
        # because special tokens are stripped during decoding.
        completion = tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        )
        # Restore the section header the prompt ended with.
        soap_response = clean_soap_response(f"SUBJECTIVE:{completion}".strip())

        if not include_timestamp:
            return soap_response

        # Add professional header
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        header = f"""π SOAP NOTE - Generated by Gemma 3n
π Timestamp: {timestamp}
π€ Model: Unsloth-optimized Gemma 3n (4-bit quantized)
π Processed locally on device
π₯ Medical Documentation Assistant
{'='*60}
"""
        return header + soap_response
    except Exception as e:
        # Best effort: any tokenization/generation failure degrades to the
        # keyword-based template instead of surfacing an error to the user.
        print(f"β Generation error: {e}")
        return generate_template_soap(doctor_notes, include_timestamp)
def clean_soap_response(response):
    """Normalize whitespace and section headers in a generated SOAP note.

    Blank lines are dropped and every remaining line is stripped. A line is
    treated as a section header, and gets a blank line in front of it, when:

    * the text before its first colon (or the whole line when it has no
      colon) is exactly SUBJECTIVE, OBJECTIVE, ASSESSMENT or PLAN,
      case-insensitively; or
    * it starts with one of those words and already ends with a colon
      (for example ``"Objective findings:"``).

    A colon is appended only to a bare header word, so content such as
    ``"SUBJECTIVE: chest pain"`` or ``"Plan to follow up"`` is left intact.

    Args:
        response: Raw model output; ``None`` or empty yields ``""``.

    Returns:
        str: The cleaned note without leading or trailing whitespace.
    """
    if not response:
        return ""

    section_names = ("SUBJECTIVE", "OBJECTIVE", "ASSESSMENT", "PLAN")
    cleaned_lines = []
    for raw_line in response.split("\n"):
        line = raw_line.strip()
        if not line:
            continue

        label, colon, _ = line.partition(":")
        exact_header = label.strip().upper() in section_names
        titled_header = line.endswith(":") and line.upper().startswith(section_names)

        if exact_header or titled_header:
            if not colon:
                line += ":"
            # Leading newline separates sections with an empty line.
            cleaned_lines.append(f"\n{line}")
        else:
            cleaned_lines.append(line)
    return "\n".join(cleaned_lines).strip()
# Template-based SOAP generation (enhanced fallback)
def generate_template_soap(doctor_notes, include_timestamp=True):
    """Build a SOAP note by sorting note lines into sections by keyword.

    Each input line is offered to all four sections, so a line may appear
    in more than one of them.

    Args:
        doctor_notes: Free-text notes, one observation per line.
        include_timestamp: When true, prepend a header block containing the
            current local time.

    Returns:
        str: The formatted SOAP note, optionally preceded by the header.
    """
    lines = doctor_notes.split('\n')

    # Enhanced keyword extraction: one keyword list per SOAP section, in
    # SUBJECTIVE / OBJECTIVE / ASSESSMENT / PLAN order.
    section_keywords = (
        [
            'complains', 'reports', 'states', 'denies', 'pain', 'symptoms',
            'history', 'onset', 'duration', 'patient says', 'chief complaint',
        ],
        [
            'vital signs', 'vs:', 'bp', 'hr', 'temp', 'examination', 'exam',
            'physical', 'inspection', 'palpation', 'auscultation', 'laboratory',
        ],
        [
            'diagnosis', 'impression', 'assessment', 'likely', 'possible',
            'rule out', 'differential', 'icd', 'condition',
        ],
        [
            'plan', 'treatment', 'medication', 'prescribe', 'follow', 'return',
            'therapy', 'intervention', 'monitoring', 'referral',
        ],
    )
    subjective_info, objective_info, assessment_info, plan_info = (
        extract_section_info(lines, keywords) for keywords in section_keywords
    )

    # Build comprehensive SOAP note
    soap_note = build_soap_sections(subjective_info, objective_info, assessment_info, plan_info)
    if not include_timestamp:
        return soap_note

    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    header = f"""π SOAP NOTE (Template-Enhanced)
π Timestamp: {timestamp}
π Processed locally - HIPAA compliant
π₯ Scribbled Docs Medical Assistant
{'='*60}
"""
    return header + soap_note
def extract_section_info(lines, keywords):
    """Return the stripped lines that mention any of ``keywords``.

    Matching is case-insensitive and anchored at the start of a word: a
    keyword counts only when it is not preceded by a letter or digit. This
    keeps short keywords such as ``'hr'`` or ``'temp'`` from matching inside
    unrelated words ("throat", "attempt") while still matching inflections
    ("prescribe" in "prescribed").

    Args:
        lines: Iterable of note lines.
        keywords: Keywords or phrases to look for; an empty collection
            selects nothing.

    Returns:
        list[str]: Matching lines in input order, each stripped.
    """
    if not keywords:
        return []

    alternatives = "|".join(re.escape(keyword.lower()) for keyword in keywords)
    # Negative lookbehind instead of \b so keywords that contain punctuation
    # (e.g. 'vs:') are handled the same way as plain words.
    keyword_at_word_start = re.compile(f"(?<![a-z0-9])(?:{alternatives})")
    return [
        line.strip()
        for line in lines
        if keyword_at_word_start.search(line.lower())
    ]
def build_soap_sections(subjective, objective, assessment, plan):
    """Render the four SOAP sections as one bulleted text block.

    Each section lists its first entries as bullets (at most 5, or 3 for
    ASSESSMENT). A section with no entries shows a fixed placeholder
    bullet instead. Sections are separated by one empty line.

    Args:
        subjective: Lines for the SUBJECTIVE section.
        objective: Lines for the OBJECTIVE section.
        assessment: Lines for the ASSESSMENT section.
        plan: Lines for the PLAN section.

    Returns:
        str: The formatted note body.
    """
    # (title, entries, maximum bullets shown, placeholder when empty)
    section_specs = (
        ("SUBJECTIVE", subjective, 5,
         "β’ Patient complaints and reported symptoms as documented"),
        ("OBJECTIVE", objective, 5,
         "β’ Physical examination findings and clinical observations as documented"),
        ("ASSESSMENT", assessment, 3,
         "β’ Clinical assessment based on presenting symptoms and examination findings"),
        ("PLAN", plan, 5,
         "β’ Treatment plan and follow-up care as clinically indicated"),
    )

    rendered = []
    for title, entries, limit, placeholder in section_specs:
        if entries:
            body = '\n'.join(f"β’ {entry}" for entry in entries[:limit])
        else:
            body = placeholder
        rendered.append(f"{title}:\n{body}")
    return "\n\n".join(rendered)
# OCR Functions (same as before but optimized)
def initialize_ocr():
    """Create the EasyOCR reader used for note images.

    Returns:
        An ``easyocr.Reader`` for English, with GPU use enabled when CUDA
        is available, or ``None`` when construction raises.
    """
    try:
        # Initialize with English and medical text optimization
        ocr_reader = easyocr.Reader(['en'], gpu=torch.cuda.is_available())
    except Exception as e:
        print(f"β οΈ EasyOCR initialization failed: {e}")
        return None
    else:
        print("β EasyOCR initialized for handwritten medical notes")
        return ocr_reader
def extract_text_from_image(image, ocr_reader=None):
    """Extract text from a note image with EasyOCR, then Tesseract.

    The image is preprocessed once. EasyOCR is tried first when a reader is
    supplied; its text is accepted only when longer than 10 characters.
    Otherwise Tesseract is tried and accepted when longer than 5 characters.

    Args:
        image: Image to read; ``None`` yields an error message.
        ocr_reader: Optional EasyOCR reader exposing ``readtext``.

    Returns:
        str: Cleaned extracted text, or a user-facing error message when no
        engine produced usable text or processing raised.
    """
    if image is None:
        return "β No image provided"

    try:
        # Preprocess specifically for medical handwriting
        processed_img = preprocess_medical_image(image)

        # Try EasyOCR (better for handwritten text)
        if ocr_reader is not None:
            try:
                fragments = ocr_reader.readtext(processed_img, detail=0, paragraph=True)
                easyocr_text = ' '.join(fragments) if fragments else ""
                if len(easyocr_text.strip()) > 10:
                    return clean_medical_text(easyocr_text)
            except Exception as e:
                print(f"EasyOCR failed: {e}")

        # Fallback to Tesseract with medical optimization
        try:
            import pytesseract

            # Medical-optimized Tesseract config
            custom_config = r'--oem 3 --psm 6 -c tessedit_char_whitelist=ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789.,;:()[]{}/-+= '
            tesseract_text = pytesseract.image_to_string(processed_img, config=custom_config)
            if len(tesseract_text.strip()) > 5:
                return clean_medical_text(tesseract_text)
        except Exception as e:
            print(f"Tesseract failed: {e}")

        return "β Could not extract text from image. Please ensure the image is clear and try again."
    except Exception as e:
        return f"β Error processing image: {str(e)}"
def preprocess_medical_image(image):
    """Prepare a note image for OCR and return a binarized array.

    Steps: convert to grayscale, upscale when either side is under 400
    pixels, denoise, apply CLAHE contrast enhancement, apply a
    morphological close, then adaptive Gaussian thresholding.

    Args:
        image: Anything ``np.array`` can convert to an image array.

    Returns:
        The thresholded image, or ``np.array(image)`` unchanged when any
        step raises.
    """
    try:
        pixels = np.array(image)

        # Convert to grayscale (three-dimensional arrays are treated as RGB)
        gray = cv2.cvtColor(pixels, cv2.COLOR_RGB2GRAY) if pixels.ndim == 3 else pixels

        # Resize for optimal OCR (medical notes are often small)
        height, width = gray.shape
        if height < 400 or width < 400:
            scale_factor = max(400/height, 400/width)
            target_size = (int(width * scale_factor), int(height * scale_factor))
            gray = cv2.resize(gray, target_size, interpolation=cv2.INTER_CUBIC)

        # Advanced preprocessing for handwritten medical text
        # 1. Noise reduction
        denoised = cv2.fastNlMeansDenoising(gray)

        # 2. Contrast enhancement specifically for handwriting
        contrast_equalizer = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8,8))
        enhanced = contrast_equalizer.apply(denoised)

        # 3. Morphological operations to clean up text
        structuring_element = cv2.getStructuringElement(cv2.MORPH_RECT, (1,1))
        cleaned = cv2.morphologyEx(enhanced, cv2.MORPH_CLOSE, structuring_element)

        # 4. Adaptive thresholding (better for varying lighting)
        return cv2.adaptiveThreshold(
            cleaned, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 11, 2
        )
    except Exception as e:
        print(f"β Image preprocessing error: {e}")
        return np.array(image)
def clean_medical_text(text):
    """Tidy OCR output: drop blank lines, fix artifacts, normalize abbreviations.

    Per line: surrounding whitespace is stripped, the OCR artifacts ``|``,
    ``_`` and ``~`` become ``l``, a space and ``-``, and lowercase
    abbreviations from the correction table are rewritten to their
    canonical form. Abbreviations are replaced only when they stand alone
    as a word, so ordinary words that merely contain the letters
    ("throat", "arrhythmia") are left untouched. Lines of one character or
    less are dropped.

    Args:
        text: Raw OCR text; ``None`` or empty yields ``""``.

    Returns:
        str: The cleaned lines joined with newlines.
    """
    if not text:
        return ""

    # Fix common medical abbreviations that OCR might misread. Only the
    # lowercase spelling of each key is rewritten to the value.
    medical_corrections = {
        'BP': 'BP', 'HR': 'HR', 'RR': 'RR', 'O2': 'O2',
        'mg': 'mg', 'ml': 'ml', 'cc': 'cc', 'cm': 'cm',
    }
    canonical_form = {wrong.lower(): correct for wrong, correct in medical_corrections.items()}
    standalone_abbreviation = re.compile(
        r"\b(?:" + "|".join(re.escape(key) for key in canonical_form) + r")\b"
    )
    # Remove obvious OCR artifacts in a single pass.
    artifact_table = str.maketrans({'|': 'l', '_': ' ', '~': '-'})

    cleaned_lines = []
    for raw_line in text.split('\n'):
        line = raw_line.strip()
        if not line:
            continue
        line = line.translate(artifact_table)
        line = standalone_abbreviation.sub(
            lambda match: canonical_form[match.group(0)], line
        )
        if len(line) > 1:  # Filter out single characters
            cleaned_lines.append(line)
    return '\n'.join(cleaned_lines)
# Enhanced Gradio Interface
def gradio_generate_soap(medical_notes, uploaded_image, model_data):
    """Gradio handler: merge typed notes with OCR text and build a SOAP note.

    Args:
        medical_notes: Text typed by the user; may be empty or ``None``.
        uploaded_image: Optional image whose text is extracted by OCR.
        model_data: ``(model, tokenizer)`` pair; a falsy value means no
            model is available.

    Returns:
        str: The generated SOAP note, or a user-facing error message.
    """
    model, tokenizer = model_data if model_data else (None, None)
    # The OCR reader is read from an attribute on this function object;
    # it is absent (None) until something assigns it.
    ocr_reader = getattr(gradio_generate_soap, 'ocr_reader', None)
    text_to_process = medical_notes.strip() if medical_notes else ""

    # Process uploaded image with enhanced OCR
    if uploaded_image is not None:
        try:
            print("π Extracting text from medical image...")
            extracted_text = extract_text_from_image(uploaded_image, ocr_reader)
            # OCR failures are reported as strings with this prefix; pass
            # them straight through to the user.
            if extracted_text.startswith("β"):
                return extracted_text
            if text_to_process:
                text_to_process = f"{text_to_process}\n\n--- Additional text from image ---\n{extracted_text}"
            else:
                text_to_process = f"--- Extracted from uploaded image ---\n{extracted_text}"
        except Exception as e:
            return f"β Error processing image: {str(e)}"

    if not text_to_process:
        return "β Please enter medical notes manually or upload an image with medical text"

    # Generate SOAP note using Gemma 3n
    try:
        return generate_soap_note_gemma(text_to_process, model, tokenizer)
    except Exception as e:
        return f"β Error generating SOAP note: {str(e)}"
# Example medical notes for testing. Each value is one newline-separated
# note; the keys are referenced by the interface's example list.
medical_examples = {
    'chest_pain': "\n".join((
        "Patient: John Smith, 45yo M",
        "CC: Chest pain x 2 hours",
        "HPI: Sudden onset sharp substernal chest pain 7/10, radiating to L arm. Associated SOB, diaphoresis. No N/V.",
        "PMH: HTN, no CAD",
        "VS: BP 150/90, HR 110, RR 22, O2 96% RA",
        "PE: Anxious, diaphoretic. RRR, no murmur. CTAB. No edema.",
        "A: Acute chest pain, r/o MI",
        "P: EKG, troponins, CXR, ASA 325mg, monitor",
    )),
    'diabetes': "\n".join((
        "Patient: Maria Garcia, 52yo F",
        "CC: Increased thirst, urination x 3 weeks",
        "HPI: Polyuria, polydipsia, 10lb weight loss. FH DM. No fever, abd pain.",
        "VS: BP 140/85, HR 88, BMI 28",
        "PE: Mild dehydration, dry MM. RRR. No diabetic foot changes.",
        "Labs: Random glucose 280, HbA1c pending",
        "A: New onset DM Type 2",
        "P: HbA1c, CMP, diabetic education, metformin, f/u 2 weeks",
    )),
    'pediatric': "\n".join((
        "Patient: Emma Thompson, 8yo F",
        "CC: Fever, sore throat x 2 days",
        "HPI: Fever 102F, sore throat, odynophagia, decreased appetite. No cough/rhinorrhea.",
        "VS: T 101.8F, HR 110, RR 20, O2 99%",
        "PE: Alert, mildly ill. Throat erythematous w/ tonsillar exudate. Anterior cervical LAD.",
        "A: Strep pharyngitis (probable)",
        "P: Rapid strep, throat culture, amoxicillin if +, supportive care, RTC PRN",
    )),
}
# Initialize everything
def initialize_app():
    """Set up the device, load the language model and create the OCR reader.

    The OCR reader is stored as the ``ocr_reader`` attribute of
    ``gradio_generate_soap``, which reads it back on every request.

    Returns:
        The ``(model, tokenizer)`` pair from ``load_unsloth_gemma_model``.
    """
    print("π Initializing Scribbled Docs SOAP Generator...")

    # Setup device, then load the model for it
    model, tokenizer = load_unsloth_gemma_model(setup_device())

    # Initialize OCR and hand the reader to the request handler
    gradio_generate_soap.ocr_reader = initialize_ocr()

    return model, tokenizer
# Create the main Gradio interface
def create_interface(model, tokenizer):
    """Build the Gradio ``Interface`` bound to the given model and tokenizer.

    Args:
        model: Language model passed through to the request handler.
        tokenizer: Tokenizer passed through to the request handler.

    Returns:
        gr.Interface: Interface with a notes textbox and an image upload as
        inputs and a single textbox as output.
    """
    def handle_request(notes, image):
        # Bind the loaded model pair so Gradio only supplies the two inputs.
        return gradio_generate_soap(notes, image, (model, tokenizer))

    notes_box = gr.Textbox(
        lines=8,
        placeholder="Enter medical notes here...\n\nExample:\nPatient: John Doe, 45yo M\nCC: Chest pain\nVS: BP 140/90, HR 88\n...",
        label="π Medical Notes (Manual Entry)",
        info="Enter unstructured medical notes or upload an image below",
    )
    image_box = gr.Image(
        type="pil",
        label="π· Upload Medical Image (Handwritten/Typed Notes)",
        sources=["upload", "webcam"],
        info="Upload PNG/JPG images of medical notes - handwritten or typed",
    )
    result_box = gr.Textbox(
        lines=20,
        label="π Generated SOAP Note",
        show_copy_button=True,
        info="Professional SOAP note generated from your input",
    )

    description = """
**Transform medical notes into professional SOAP documentation using Gemma 3n AI**
π **100% Offline & HIPAA Compliant** - All processing happens locally on your device
π€ **Powered by Unsloth-optimized Gemma 3n** - 4-bit quantized for efficiency
π **Supports handwritten & typed notes** - Advanced OCR for medical handwriting
**Instructions:**
1. Enter medical notes manually OR upload an image
2. Click Submit to generate a structured SOAP note
3. Copy the result for use in your medical records
**Perfect for:** Emergency medicine, family practice, internal medicine, pediatrics
"""

    # One example row per sample note; no example image is supplied.
    example_rows = [
        [medical_examples[example_key], None]
        for example_key in ('chest_pain', 'diabetes', 'pediatric')
    ]

    interface = gr.Interface(
        fn=handle_request,
        inputs=[notes_box, image_box],
        outputs=[result_box],
        title="π₯ Scribbled Docs - Medical SOAP Note Generator",
        description=description,
        examples=example_rows,
        theme=gr.themes.Soft(
            primary_hue="blue",
            secondary_hue="green",
        ),
        allow_flagging="never",
        analytics_enabled=False,
    )
    return interface
# Main execution
def main():
    """Initialize the app and serve the Gradio interface on port 7860.

    A public share link is created only when the environment variable
    ``SCRIBBLED_DOCS_SHARE`` is set to ``1``, ``true`` or ``yes``. Any
    exception during startup or launch is reported on stdout together with
    an installation hint.
    """
    try:
        # Initialize app
        model, tokenizer = initialize_app()

        # Create and launch interface
        interface = create_interface(model, tokenizer)
        print("\nπ― Scribbled Docs SOAP Generator Ready!")
        print("π± Features:")
        print(" β Offline processing (HIPAA compliant)")
        print(" β Unsloth-optimized Gemma 3n model")
        print(" β Handwritten note OCR")
        print(" β Professional SOAP formatting")
        print(" β Medical terminology aware")

        # NOTE(security): a share link publishes the app through a public
        # URL, which contradicts the advertised offline/local handling of
        # medical notes. It is therefore opt-in rather than always on.
        share_publicly = os.environ.get("SCRIBBLED_DOCS_SHARE", "").strip().lower() in (
            "1", "true", "yes",
        )

        # Launch interface
        interface.launch(
            share=share_publicly,
            server_port=7860,
            show_error=True,
            quiet=False,
        )
    except Exception as e:
        print(f"β Error launching application: {e}")
        print("π‘ Make sure you have installed: pip install unsloth gradio easyocr opencv-python")


if __name__ == "__main__":
    main()