# Hugging Face Spaces page status (scraped page residue): "Spaces: Sleeping"
import time | |
import gradio as gr | |
import torch | |
import numpy as np | |
from PIL import Image | |
import time | |
import io | |
import subprocess | |
import sys | |
# Install required packages
def install_packages(packages=None):
    """Best-effort ``pip install`` of the app's runtime dependencies.

    Each package is installed in its own ``python -m pip install`` subprocess
    using the current interpreter (``sys.executable``), so one failure does not
    stop the remaining packages from being tried.

    Args:
        packages: Iterable of pip requirement strings. ``None`` (the default)
            installs the app's standard set: transformers, accelerate, timm
            and easyocr.

    Returns:
        None. Failures are reported with a printed warning, not raised.
    """
    if packages is None:
        packages = ["transformers", "accelerate", "timm", "easyocr"]
    for package in packages:
        try:
            subprocess.check_call([sys.executable, "-m", "pip", "install", package])
        except (subprocess.CalledProcessError, OSError):
            # Non-zero pip exit, or the subprocess could not be spawned: warn and
            # continue. A bare ``except`` here would also swallow Ctrl-C
            # (KeyboardInterrupt) and SystemExit.
            print(f"Warning: Could not install {package}")
# Run the installer once, at import time, before the transformers import
# below (transformers is one of the packages it installs).
install_packages()
from transformers import AutoProcessor, AutoModelForImageTextToText, AutoConfig

# Module-level handles; each stays None until load_model() assigns it.
processor = model = config = ocr_reader = None
def load_model():
    """Load the Gemma 3n config/processor/model and try to start EasyOCR.

    Assigns the module-level ``config``, ``processor``, ``model`` and
    ``ocr_reader`` globals. An EasyOCR failure is tolerated: ``ocr_reader`` is
    set to ``None`` and loading still counts as successful. Any other
    exception is caught and reported.

    Returns:
        bool: ``True`` if the Gemma components loaded, ``False`` otherwise.
    """
    global processor, model, config, ocr_reader

    def _build_ocr_reader():
        # EasyOCR is optional: report the problem and fall back to None.
        try:
            import easyocr
            reader = easyocr.Reader(['en'], gpu=False, verbose=False)
            print("β EasyOCR initialized")
            return reader
        except Exception as e:
            print(f"β οΈ EasyOCR not available: {e}")
            return None

    try:
        print("π Loading Gemma 3n model...")
        GEMMA_PATH = "google/gemma-3n-e2b-it"
        remote_code = {"trust_remote_code": True}

        config = AutoConfig.from_pretrained(GEMMA_PATH, **remote_code)
        print("β Config loaded")

        processor = AutoProcessor.from_pretrained(GEMMA_PATH, **remote_code)
        print("β Processor loaded")

        model = AutoModelForImageTextToText.from_pretrained(
            GEMMA_PATH,
            config=config,
            torch_dtype="auto",
            device_map="auto",
            **remote_code,
        )
        print("β Model loaded successfully!")

        # Make torch.compile/dynamo failures non-fatal.
        import torch._dynamo
        torch._dynamo.config.suppress_errors = True

        ocr_reader = _build_ocr_reader()
        return True
    except Exception as e:
        print(f"β Model loading failed: {e}")
        return False
def generate_soap_note(text):
    """Ask the loaded Gemma 3n model to rewrite *text* as a SOAP note.

    Args:
        text: Free-text clinical notes, interpolated verbatim into the prompt.

    Returns:
        str: The decoded model output (newly generated tokens only, stripped).
        On failure, returns a message string instead of raising:
        - a "Model not loaded" message when ``model``/``processor`` are unset;
        - a "SOAP generation failed" message when templating or generation raises.

    Sampling is enabled (``do_sample=True``, temperature 0.1, top_p 0.95), so
    output may differ between calls; at most 400 new tokens are produced.
    """
    if model is None or processor is None:
        return "β Model not loaded. Please wait for initialization."

    prompt_lines = [
        "You are a medical AI assistant. Convert the following medical notes into a properly formatted SOAP note.",
        "Medical notes:",
        f"{text}",
        "Please format as:",
        "S - SUBJECTIVE: (chief complaint, history of present illness, past medical history, medications, allergies)",
        "O - OBJECTIVE: (vital signs, physical examination findings)",
        "A - ASSESSMENT: (diagnosis/clinical impression)",
        "P - PLAN: (treatment plan, follow-up instructions)",
        "Generate a complete, professional SOAP note:",
    ]
    soap_prompt = "\n".join(prompt_lines)

    def _text_turn(role, body):
        # Chat-template message holding a single text content item.
        return {"role": role, "content": [{"type": "text", "text": body}]}

    messages = [
        _text_turn("system", "You are an expert medical AI assistant specialized in creating SOAP notes from medical documentation."),
        _text_turn("user", soap_prompt),
    ]

    try:
        encoded = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        )
        encoded = encoded.to(model.device)
        prompt_length = encoded["input_ids"].shape[-1]
        sampling = dict(
            max_new_tokens=400,
            do_sample=True,
            temperature=0.1,
            top_p=0.95,
            pad_token_id=processor.tokenizer.eos_token_id,
            disable_compile=True,
        )
        with torch.no_grad():
            generated = model.generate(**encoded, **sampling)
        # Drop the echoed prompt tokens; decode only the continuation.
        continuation = generated[:, prompt_length:]
        decoded = processor.batch_decode(continuation, skip_special_tokens=True)
        return decoded[0].strip()
    except Exception as e:
        return f"β SOAP generation failed: {e!s}"
def extract_text_from_image(image):
    """Run EasyOCR over *image* and return the recognised text.

    Args:
        image: A PIL image (converted to RGB first) or any object that
            ``np.array`` accepts.

    Returns:
        str: Space-joined paragraphs from EasyOCR. A message string is
        returned instead when OCR is unavailable, nothing is detected, or
        OCR raises.
    """
    if ocr_reader is None:
        return "β OCR not available"
    try:
        rgb = image.convert('RGB') if hasattr(image, 'convert') else image
        fragments = ocr_reader.readtext(np.array(rgb), detail=0, paragraph=True)
        if not fragments:
            return "β No text detected in image"
        return ' '.join(fragments).strip()
    except Exception as e:
        return f"β OCR failed: {e!s}"
def process_medical_input(image, text):
    """Gradio submit handler: OCR an image or take typed notes, then build a SOAP note.

    Exactly one of *image* / *text* should be provided.

    Args:
        image: Uploaded image (PIL, from ``gr.Image(type="pil")``) or ``None``.
        text: Typed notes. ``None`` is treated like an empty string; Gradio
            normally sends ``""``, but API callers may send ``None``.

    Returns:
        tuple[str, str]: ``(text for the "Extracted/Input Text" box,
        generated SOAP note)``. The second item is ``""`` when nothing was
        generated.
    """
    cleaned_text = (text or "").strip()

    if image is not None and cleaned_text:
        return "β οΈ Please provide either an image OR text, not both.", ""

    if image is not None:
        # Process image
        print("π Extracting text from image...")
        extracted_text = extract_text_from_image(image)
        # extract_text_from_image reports failures as messages starting with
        # this marker character. Genuine OCR text starting with it would also
        # be treated as an error.
        if extracted_text.startswith('β'):
            return extracted_text, ""
        print("π€ Generating SOAP note...")
        return extracted_text, generate_soap_note(extracted_text)

    if cleaned_text:
        # Process text directly
        print("π€ Generating SOAP note from text...")
        return cleaned_text, generate_soap_note(cleaned_text)

    return "β Please provide either an image or text input.", ""
def create_demo():
    """Build (but do not launch) the Gradio Blocks UI for the SOAP generator.

    Layout, top to bottom:
    - an HTML instructions header;
    - a row with inputs on the left (image upload, text box, "Generate SOAP
      Note" button) and outputs on the right (extracted/input text, SOAP note);
    - a "Try Sample Medical Text" button;
    - an "About" markdown footer.

    Returns:
        The ``gr.Blocks`` instance.
    """
    # Sample text for demonstration
    sample_text = """Patient: John Smith, 45yo male
CC: Chest pain
Vitals: BP 140/90, HR 88, RR 16, O2 98%, Temp 98.6F
HPI: Patient reports crushing chest pain x 2 hours, radiating to left arm
PMH: HTN, DM Type 2
Current Meds: Lisinopril 10mg daily, Metformin 500mg BID
PE: Diaphoretic, anxious appearance
EKG: ST elevation in leads II, III, aVF"""
    with gr.Blocks(title="Medical OCR SOAP Generator", theme=gr.themes.Soft()) as demo:
        # NOTE(review): the text below says the first generation is slow "as
        # model loads", but the __main__ launcher calls load_model() before
        # create_demo(). Confirm the wording is still accurate.
        gr.HTML("""
<h1>π₯ Medical OCR SOAP Generator - LIVE DEMO</h1>
<h2>π― For Competition Judges - Quick 2-Minute Demo:</h2>
<div style="background-color: #e6f3ff; padding: 15px; border-radius: 10px; margin: 10px 0;">
<h3>π SAMPLE IMAGE PROVIDED:</h3>
<p><strong>π Download "docs-note-to-upload.jpg" from the Files tab above, then upload it below</strong></p>
<p><strong>OR</strong> click "Try Sample Medical Text" button for instant text demo</p>
</div>
<h3>Demo Steps:</h3>
<ol>
<li><strong>Upload the sample image</strong> (docs-note-to-upload.jpg from Files tab) <strong>OR</strong> click sample text button</li>
<li><strong>Click "Generate SOAP Note"</strong></li>
<li><strong>Wait ~2 minutes</strong> for AI processing (first time only)</li>
<li><strong>See professional SOAP note</strong> generated by Gemma 3n</li>
</ol>
<h3>β What This Demo Shows:</h3>
<ul>
<li><strong>Real OCR</strong> extraction from handwritten medical notes</li>
<li><strong>AI-powered medical reasoning</strong> with Gemma 3n</li>
<li><strong>Professional SOAP formatting</strong> (Subjective, Objective, Assessment, Plan)</li>
<li><strong>HIPAA-compliant</strong> local processing</li>
</ul>
<p><strong>β οΈ Note:</strong> First generation takes ~2 minutes as model loads. Subsequent ones are faster.</p>
<hr>
""")
        with gr.Row():
            # Left column: inputs and the submit button.
            with gr.Column():
                # type="pil" means the handler receives a PIL image (or None).
                image_input = gr.Image(
                    type="pil",
                    label="π· Upload Medical Image",
                    height=300
                )
                # The sample is only a placeholder hint; it is not submitted
                # unless the user types it or clicks the sample button.
                text_input = gr.Textbox(
                    label="π Or Enter Medical Text",
                    placeholder=sample_text,
                    lines=8,
                    max_lines=15
                )
                submit_btn = gr.Button(
                    "Generate SOAP Note",
                    variant="primary",
                    size="lg"
                )
            # Right column: outputs filled by process_medical_input.
            with gr.Column():
                extracted_output = gr.Textbox(
                    label="π Extracted/Input Text",
                    lines=6,
                    max_lines=10
                )
                soap_output = gr.Textbox(
                    label="π₯ Generated SOAP Note",
                    lines=12,
                    max_lines=20
                )
        # Example section
        gr.Markdown("### π Quick Test Example")
        example_btn = gr.Button("Try Sample Medical Text", variant="secondary")

        def load_example():
            # Fill the text box and clear the image, because
            # process_medical_input rejects requests that carry both.
            return sample_text, None

        example_btn.click(
            load_example,
            outputs=[text_input, image_input]
        )
        # Process function
        submit_btn.click(
            process_medical_input,
            inputs=[image_input, text_input],
            outputs=[extracted_output, soap_output]
        )
        gr.Markdown("""
---
**About:** This application uses Google's Gemma 3n model for medical text understanding and EasyOCR for handwriting recognition.
All processing is done locally for HIPAA compliance.
**Competition Entry:** Medical AI Innovation Challenge 2024
""")
    return demo
# Initialize the application
if __name__ == "__main__":
    print("π Starting Medical OCR SOAP Generator...")
    # Load model (also initialises the optional OCR reader).
    model_loaded = load_model()
    # Both launch paths share the same bind address/port, so the fallback UI
    # is reachable wherever the real one would have been.
    # NOTE(review): share=True asks Gradio for a public share link when run
    # outside HF Spaces. Reconsider it given the UI's "local processing" claim.
    launch_kwargs = dict(share=True, server_name="0.0.0.0", server_port=7860)
    if model_loaded:
        print("β All systems ready!")
        demo = create_demo()
        demo.launch(**launch_kwargs)
    else:
        print("β Failed to load model. Creating fallback demo...")

        def fallback_demo(image=None, text=None):
            # gr.Interface passes one positional argument per input component
            # (image, text). The original zero-argument signature raised
            # TypeError on every submit.
            return "β Model loading failed. Please check the logs.", "β Model not available."

        demo = gr.Interface(
            fn=fallback_demo,
            inputs=[
                gr.Image(type="pil", label="Upload Medical Image"),
                gr.Textbox(label="Enter Medical Text", lines=5)
            ],
            outputs=[
                gr.Textbox(label="Status"),
                gr.Textbox(label="Error Message")
            ],
            title="β Medical OCR - Model Loading Failed"
        )
        demo.launch(**launch_kwargs)
import io
import subprocess
import sys


# Install required packages
def install_packages(packages=None):
    """Best-effort ``pip install`` of the app's runtime dependencies.

    Each package is installed in its own ``python -m pip install`` subprocess
    using the current interpreter (``sys.executable``), so one failure does not
    stop the remaining packages from being tried.

    Args:
        packages: Iterable of pip requirement strings. ``None`` (the default)
            installs the app's standard set: transformers, accelerate, timm,
            easyocr and opencv-python.

    Returns:
        None. Failures are reported with a printed warning, not raised.
    """
    if packages is None:
        packages = ["transformers", "accelerate", "timm", "easyocr", "opencv-python"]
    for package in packages:
        try:
            subprocess.check_call([sys.executable, "-m", "pip", "install", package])
        except (subprocess.CalledProcessError, OSError):
            # Non-zero pip exit, or the subprocess could not be spawned: warn and
            # continue. A bare ``except`` would also swallow Ctrl-C/SystemExit.
            print(f"Warning: Could not install {package}")


# Install packages at startup
install_packages()

# cv2 must be imported *after* install_packages(): opencv-python is in the
# install list above, and importing cv2 before it runs raises ImportError
# wherever OpenCV is not preinstalled.
import cv2  # noqa: E402
from transformers import AutoProcessor, AutoModelForImageTextToText, AutoConfig  # noqa: E402

# Global variables for model
processor = None
model = None
config = None
ocr_reader = None
def load_model():
    """Load the Gemma 3n config/processor/model and try to start EasyOCR.

    Assigns the module-level ``config``, ``processor``, ``model`` and
    ``ocr_reader`` globals. An EasyOCR failure is tolerated: ``ocr_reader`` is
    set to ``None`` and loading still counts as successful. Any other
    exception is caught and reported.

    Returns:
        bool: ``True`` if the Gemma components loaded, ``False`` otherwise.
    """
    global processor, model, config, ocr_reader

    def _build_ocr_reader():
        # EasyOCR is optional: report the problem and fall back to None.
        try:
            import easyocr
            reader = easyocr.Reader(['en'], gpu=False, verbose=False)
            print("β EasyOCR initialized")
            return reader
        except Exception as e:
            print(f"β οΈ EasyOCR not available: {e}")
            return None

    try:
        print("π Loading Gemma 3n model...")
        GEMMA_PATH = "google/gemma-3n-e2b-it"
        remote_code = {"trust_remote_code": True}

        config = AutoConfig.from_pretrained(GEMMA_PATH, **remote_code)
        print("β Config loaded")

        processor = AutoProcessor.from_pretrained(GEMMA_PATH, **remote_code)
        print("β Processor loaded")

        model = AutoModelForImageTextToText.from_pretrained(
            GEMMA_PATH,
            config=config,
            torch_dtype="auto",
            device_map="auto",
            **remote_code,
        )
        print("β Model loaded successfully!")

        # Make torch.compile/dynamo failures non-fatal.
        import torch._dynamo
        torch._dynamo.config.suppress_errors = True

        ocr_reader = _build_ocr_reader()
        return True
    except Exception as e:
        print(f"β Model loading failed: {e}")
        return False
def generate_soap_note(text):
    """Ask the loaded Gemma 3n model to rewrite *text* as a SOAP note.

    Args:
        text: Free-text clinical notes, interpolated verbatim into the prompt.

    Returns:
        str: The decoded model output (newly generated tokens only, stripped).
        On failure, returns a message string instead of raising:
        - a "Model not loaded" message when ``model``/``processor`` are unset;
        - a "SOAP generation failed" message when templating or generation raises.

    Sampling is enabled (``do_sample=True``, temperature 0.1, top_p 0.95), so
    output may differ between calls; at most 400 new tokens are produced.
    """
    if model is None or processor is None:
        return "β Model not loaded. Please wait for initialization."

    prompt_lines = [
        "You are a medical AI assistant. Convert the following medical notes into a properly formatted SOAP note.",
        "Medical notes:",
        f"{text}",
        "Please format as:",
        "S - SUBJECTIVE: (chief complaint, history of present illness, past medical history, medications, allergies)",
        "O - OBJECTIVE: (vital signs, physical examination findings)",
        "A - ASSESSMENT: (diagnosis/clinical impression)",
        "P - PLAN: (treatment plan, follow-up instructions)",
        "Generate a complete, professional SOAP note:",
    ]
    soap_prompt = "\n".join(prompt_lines)

    def _text_turn(role, body):
        # Chat-template message holding a single text content item.
        return {"role": role, "content": [{"type": "text", "text": body}]}

    messages = [
        _text_turn("system", "You are an expert medical AI assistant specialized in creating SOAP notes from medical documentation."),
        _text_turn("user", soap_prompt),
    ]

    try:
        encoded = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        )
        encoded = encoded.to(model.device)
        prompt_length = encoded["input_ids"].shape[-1]
        sampling = dict(
            max_new_tokens=400,
            do_sample=True,
            temperature=0.1,
            top_p=0.95,
            pad_token_id=processor.tokenizer.eos_token_id,
            disable_compile=True,
        )
        with torch.no_grad():
            generated = model.generate(**encoded, **sampling)
        # Drop the echoed prompt tokens; decode only the continuation.
        continuation = generated[:, prompt_length:]
        decoded = processor.batch_decode(continuation, skip_special_tokens=True)
        return decoded[0].strip()
    except Exception as e:
        return f"β SOAP generation failed: {e!s}"
def preprocess_image_for_ocr(image):
    """Produce a contrast-enhanced, binarised grayscale array for OCR.

    Pipeline:
    1. Convert to RGB (if the object has ``convert``).
    2. Convert to grayscale (3-D arrays only).
    3. Cubic-upscale so both sides reach 300 px when either is smaller.
    4. Apply a 3x3 median blur.
    5. Apply CLAHE (clipLimit 2.0, 8x8 tiles).
    6. Apply Otsu thresholding.

    Args:
        image: A PIL image or anything ``np.array`` accepts.

    Returns:
        numpy.ndarray: The single-channel result. If any step raises, a
        warning is printed and ``np.array(image)`` of the (possibly
        RGB-converted) input is returned instead.
    """
    try:
        if hasattr(image, 'convert'):
            image = image.convert('RGB')
        pixels = np.array(image)
        gray = cv2.cvtColor(pixels, cv2.COLOR_RGB2GRAY) if pixels.ndim == 3 else pixels

        rows, cols = gray.shape
        if rows < 300 or cols < 300:
            factor = max(300/rows, 300/cols)
            target_size = (int(cols * factor), int(rows * factor))  # cv2 wants (width, height)
            gray = cv2.resize(gray, target_size, interpolation=cv2.INTER_CUBIC)

        gray = cv2.medianBlur(gray, 3)
        gray = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(gray)
        _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        return binary
    except Exception as e:
        print(f"β οΈ Image preprocessing failed: {e}")
        # Fallback to original image if preprocessing fails
        return np.array(image)
def extract_text_from_image(image):
    """Preprocess *image* for OCR, run EasyOCR on it and return the text.

    Args:
        image: Input accepted by ``preprocess_image_for_ocr`` (typically PIL).

    Returns:
        str: Space-joined paragraphs from EasyOCR. A message string is
        returned instead when OCR is unavailable, nothing is detected, or
        OCR raises.
    """
    if ocr_reader is None:
        return "β OCR not available"
    try:
        prepared = preprocess_image_for_ocr(image)
        fragments = ocr_reader.readtext(prepared, detail=0, paragraph=True)
        return ' '.join(fragments).strip() if fragments else "β No text detected in image"
    except Exception as e:
        return f"β OCR failed: {e!s}"
def process_medical_input(image, text):
    """Gradio submit handler: OCR an image or take typed notes, then build a SOAP note.

    Exactly one of *image* / *text* should be provided.

    Args:
        image: Uploaded image (PIL, from ``gr.Image(type="pil")``) or ``None``.
        text: Typed notes. ``None`` is treated like an empty string; Gradio
            normally sends ``""``, but API callers may send ``None``.

    Returns:
        tuple[str, str]: ``(text for the "Extracted/Input Text" box,
        generated SOAP note)``. The second item is ``""`` when nothing was
        generated.
    """
    cleaned_text = (text or "").strip()

    if image is not None and cleaned_text:
        return "β οΈ Please provide either an image OR text, not both.", ""

    if image is not None:
        # Process image
        print("π Extracting text from image...")
        extracted_text = extract_text_from_image(image)
        # extract_text_from_image reports failures as messages starting with
        # this marker character. Genuine OCR text starting with it would also
        # be treated as an error.
        if extracted_text.startswith('β'):
            return extracted_text, ""
        print("π€ Generating SOAP note...")
        return extracted_text, generate_soap_note(extracted_text)

    if cleaned_text:
        # Process text directly
        print("π€ Generating SOAP note from text...")
        return cleaned_text, generate_soap_note(cleaned_text)

    return "β Please provide either an image or text input.", ""
def create_demo():
    """Build (but do not launch) the Gradio Blocks UI for the SOAP generator.

    Layout, top to bottom:
    - an HTML instructions header;
    - a row with inputs on the left (image upload, text box, "Generate SOAP
      Note" button) and outputs on the right (extracted/input text, SOAP note);
    - a "Try Sample Medical Text" button;
    - an "About" markdown footer.

    Returns:
        The ``gr.Blocks`` instance.
    """
    # Sample text for demonstration
    sample_text = """Patient: John Smith, 45yo male
CC: Chest pain
Vitals: BP 140/90, HR 88, RR 16, O2 98%, Temp 98.6F
HPI: Patient reports crushing chest pain x 2 hours, radiating to left arm
PMH: HTN, DM Type 2
Current Meds: Lisinopril 10mg daily, Metformin 500mg BID
PE: Diaphoretic, anxious appearance
EKG: ST elevation in leads II, III, aVF"""
    with gr.Blocks(title="Medical OCR SOAP Generator", theme=gr.themes.Soft()) as demo:
        # NOTE(review): the text below says the first generation is slow "as
        # model loads", but the __main__ launcher calls load_model() before
        # create_demo(). Confirm the wording is still accurate.
        gr.HTML("""
<h1>π₯ Medical OCR SOAP Generator - LIVE DEMO</h1>
<h2>π― For Competition Judges - Quick 2-Minute Demo:</h2>
<div style="background-color: #e6f3ff; padding: 15px; border-radius: 10px; margin: 10px 0;">
<h3>π SAMPLE IMAGE PROVIDED:</h3>
<p><strong>π Download "docs-note-to-upload.jpg" from the Files tab above, then upload it below</strong></p>
<p><strong>OR</strong> click "Try Sample Medical Text" button for instant text demo</p>
</div>
<h3>Demo Steps:</h3>
<ol>
<li><strong>Upload the sample image</strong> (docs-note-to-upload.jpg from Files tab) <strong>OR</strong> click sample text button</li>
<li><strong>Click "Generate SOAP Note"</strong></li>
<li><strong>Wait ~60-90 seconds</strong> for AI processing (first time only)</li>
<li><strong>See professional SOAP note</strong> generated by Gemma 3n</li>
</ol>
<h3>β What This Demo Shows:</h3>
<ul>
<li><strong>Real OCR</strong> extraction from handwritten medical notes</li>
<li><strong>AI-powered medical reasoning</strong> with Gemma 3n</li>
<li><strong>Professional SOAP formatting</strong> (Subjective, Objective, Assessment, Plan)</li>
<li><strong>HIPAA-compliant</strong> local processing</li>
</ul>
<p><strong>β οΈ Note:</strong> First generation takes ~60-90 seconds as model loads. Subsequent ones are faster.</p>
<hr>
""")
        with gr.Row():
            # Left column: inputs and the submit button.
            with gr.Column():
                # type="pil" means the handler receives a PIL image (or None).
                image_input = gr.Image(
                    type="pil",
                    label="π· Upload Medical Image",
                    height=300
                )
                # The sample is only a placeholder hint; it is not submitted
                # unless the user types it or clicks the sample button.
                text_input = gr.Textbox(
                    label="π Or Enter Medical Text",
                    placeholder=sample_text,
                    lines=8,
                    max_lines=15
                )
                submit_btn = gr.Button(
                    "Generate SOAP Note",
                    variant="primary",
                    size="lg"
                )
            # Right column: outputs filled by process_medical_input.
            with gr.Column():
                extracted_output = gr.Textbox(
                    label="π Extracted/Input Text",
                    lines=6,
                    max_lines=10
                )
                soap_output = gr.Textbox(
                    label="π₯ Generated SOAP Note",
                    lines=12,
                    max_lines=20
                )
        # Example section
        gr.Markdown("### π Quick Test Example")
        example_btn = gr.Button("Try Sample Medical Text", variant="secondary")

        def load_example():
            # Fill the text box and clear the image, because
            # process_medical_input rejects requests that carry both.
            return sample_text, None

        example_btn.click(
            load_example,
            outputs=[text_input, image_input]
        )
        # Process function
        submit_btn.click(
            process_medical_input,
            inputs=[image_input, text_input],
            outputs=[extracted_output, soap_output]
        )
        gr.Markdown("""
---
**About:** This application uses Google's Gemma 3n model for medical text understanding and EasyOCR for handwriting recognition.
All processing is done locally for HIPAA compliance.
**Competition Entry:** Medical AI Innovation Challenge 2024
""")
    return demo
# Initialize the application
if __name__ == "__main__":
    print("π Starting Medical OCR SOAP Generator...")
    # Load model (also initialises the optional OCR reader).
    model_loaded = load_model()
    # Both launch paths share the same bind address/port, so the fallback UI
    # is reachable wherever the real one would have been.
    # NOTE(review): share=True asks Gradio for a public share link when run
    # outside HF Spaces. Reconsider it given the UI's "local processing" claim.
    launch_kwargs = dict(share=True, server_name="0.0.0.0", server_port=7860)
    if model_loaded:
        print("β All systems ready!")
        demo = create_demo()
        demo.launch(**launch_kwargs)
    else:
        print("β Failed to load model. Creating fallback demo...")

        def fallback_demo(image=None, text=None):
            # gr.Interface passes one positional argument per input component
            # (image, text). The original zero-argument signature raised
            # TypeError on every submit.
            return "β Model loading failed. Please check the logs.", "β Model not available."

        demo = gr.Interface(
            fn=fallback_demo,
            inputs=[
                gr.Image(type="pil", label="Upload Medical Image"),
                gr.Textbox(label="Enter Medical Text", lines=5)
            ],
            outputs=[
                gr.Textbox(label="Status"),
                gr.Textbox(label="Error Message")
            ],
            title="β Medical OCR - Model Loading Failed"
        )
        demo.launch(**launch_kwargs)