Bonosa2 committed
Commit 4828c8f · verified · 1 Parent(s): 96d91c7

Upload app.py

Files changed (1)
  1. app.py +309 -571
app.py CHANGED
@@ -1,571 +1,309 @@
- # -*- coding: utf-8 -*-
- # 🏥 Gemma 3N SOAP Note Generator with Unsloth
- # Optimized for offline medical documentation
-
- import torch
- import gradio as gr
- import io
- import base64
- from datetime import datetime
- import os
- import easyocr
- from PIL import Image, ImageDraw, ImageFont
- import cv2
- import numpy as np
- import psutil
-
- # Import Unsloth for optimized Gemma 3n
- try:
-     from unsloth import FastModel
-     print("✅ Unsloth imported successfully")
-     UNSLOTH_AVAILABLE = True
- except ImportError:
-     print("❌ Unsloth not available. Install with: pip install unsloth")
-     UNSLOTH_AVAILABLE = False
-
- # Device setup
- def setup_device():
-     device = "cuda" if torch.cuda.is_available() else "cpu"
-     print(f"🖥️ Using device: {device}")
-
-     if torch.cuda.is_available():
-         print(f"🚀 GPU: {torch.cuda.get_device_name(0)}")
-         print(f"💾 GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
-     else:
-         print("⚠️ Running on CPU - will be slower but works offline")
-
-     return device
-
- # Load Unsloth Gemma 3n model
- def load_unsloth_gemma_model(device):
-     """Load optimized Gemma 3n model using Unsloth"""
-
-     if not UNSLOTH_AVAILABLE:
-         print("❌ Unsloth not available. Using fallback method.")
-         return load_fallback_model()
-
-     try:
-         print("📡 Loading Unsloth-optimized Gemma 3n model...")
-
-         # Use the 4-bit quantized model for efficiency
-         model_name = "unsloth/gemma-3n-E4B-it-unsloth-bnb-4bit"
-
-         print(f"🔧 Loading model: {model_name}")
-
-         # Load with Unsloth optimizations
-         model, tokenizer = FastModel.from_pretrained(
-             model_name=model_name,
-             dtype=None, # Auto-detect
-             max_seq_length=1024, # Good for medical notes
-             load_in_4bit=True, # 4-bit quantization for efficiency
-             full_finetuning=False,
-         )
-
-         print("✅ Unsloth Gemma 3n model loaded successfully!")
-         print(f"📊 Model: {model_name}")
-         print(f"💾 Memory optimized with 4-bit quantization")
-         print(f"🎯 Ready for medical SOAP note generation!")
-
-         return model, tokenizer
-
-     except Exception as e:
-         print(f"❌ Error loading Unsloth model: {e}")
-         print("💡 Trying fallback model...")
-         return load_fallback_model()
-
- def load_fallback_model():
-     """Fallback model if Unsloth fails"""
-     try:
-         from transformers import AutoTokenizer, AutoModelForCausalLM
-
-         print("🔄 Loading fallback model...")
-         model_name = "microsoft/DialoGPT-medium"
-
-         tokenizer = AutoTokenizer.from_pretrained(model_name)
-         if tokenizer.pad_token is None:
-             tokenizer.pad_token = tokenizer.eos_token
-
-         model = AutoModelForCausalLM.from_pretrained(
-             model_name,
-             torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-             low_cpu_mem_usage=True
-         )
-
-         print("✅ Fallback model loaded!")
-         return model, tokenizer
-
-     except Exception as e:
-         print(f"❌ Fallback model also failed: {e}")
-         return None, None
-
- # Enhanced SOAP Note Generation with Gemma 3n
- def generate_soap_note_gemma(doctor_notes, model=None, tokenizer=None, include_timestamp=True):
-     """Generate SOAP note using Gemma 3n model"""
-
-     if not doctor_notes.strip():
-         return "❌ Please enter some medical notes to process."
-
-     if model is None or tokenizer is None:
-         return generate_template_soap(doctor_notes, include_timestamp)
-
-     # Medical-specific prompt for Gemma 3n
-     prompt = f"""<bos><start_of_turn>user
- You are a medical AI assistant specialized in creating SOAP notes. Convert the following unstructured medical notes into a professional SOAP note format.
-
- Medical Notes:
- {doctor_notes}
-
- Please create a structured SOAP note with these sections:
- - SUBJECTIVE: Patient's reported symptoms, complaints, and relevant history
- - OBJECTIVE: Physical examination findings, vital signs, and observable data
- - ASSESSMENT: Clinical diagnosis, differential diagnosis, and medical reasoning
- - PLAN: Treatment recommendations, medications, tests, and follow-up care
-
- <end_of_turn>
- <start_of_turn>model
- SOAP NOTE:
-
- SUBJECTIVE:"""
-
-     try:
-         # Tokenize input
-         inputs = tokenizer(
-             prompt,
-             return_tensors="pt",
-             truncation=True,
-             max_length=512,
-             padding=True
-         )
-
-         # Generate with optimized settings for medical text
-         with torch.no_grad():
-             outputs = model.generate(
-                 **inputs,
-                 max_new_tokens=400,
-                 temperature=0.2, # Lower temperature for medical precision
-                 top_p=0.9,
-                 do_sample=True,
-                 repetition_penalty=1.1,
-                 pad_token_id=tokenizer.eos_token_id,
-                 eos_token_id=tokenizer.eos_token_id
-             )
-
-         # Decode response
-         generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-         # Extract only the SOAP note part
-         if "SOAP NOTE:" in generated_text:
-             soap_response = generated_text.split("SOAP NOTE:")[1].strip()
-         else:
-             soap_response = generated_text[len(prompt):].strip()
-
-         # Clean up response
-         soap_response = clean_soap_response(soap_response)
-
-         # Add professional header
-         if include_timestamp:
-             timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-             header = f"""📋 SOAP NOTE - Generated by Gemma 3n
- 🕐 Timestamp: {timestamp}
- 🤖 Model: Unsloth-optimized Gemma 3n (4-bit quantized)
- 🔒 Processed locally on device
- 🏥 Medical Documentation Assistant
-
- {'='*60}
- """
-             return header + soap_response
-
-         return soap_response
-
-     except Exception as e:
-         print(f"❌ Generation error: {e}")
-         return generate_template_soap(doctor_notes, include_timestamp)
-
- def clean_soap_response(response):
-     """Clean and format SOAP note response"""
-
-     # Remove any incomplete sentences at the end
-     lines = response.split('\n')
-     cleaned_lines = []
-
-     for line in lines:
-         line = line.strip()
-         if line:
-             # Ensure proper SOAP section headers
-             if line.upper().startswith(('SUBJECTIVE', 'OBJECTIVE', 'ASSESSMENT', 'PLAN')):
-                 if not line.endswith(':'):
-                     line += ':'
-                 cleaned_lines.append(f"\n{line}")
-             else:
-                 cleaned_lines.append(line)
-
-     return '\n'.join(cleaned_lines).strip()
-
- # Template-based SOAP generation (enhanced fallback)
- def generate_template_soap(doctor_notes, include_timestamp=True):
-     """Enhanced template-based SOAP note generation"""
-
-     notes_lower = doctor_notes.lower()
-     lines = doctor_notes.split('\n')
-
-     # Enhanced keyword extraction
-     subjective_info = extract_section_info(lines, [
-         'complains', 'reports', 'states', 'denies', 'pain', 'symptoms',
-         'history', 'onset', 'duration', 'patient says', 'chief complaint'
-     ])
-
-     objective_info = extract_section_info(lines, [
-         'vital signs', 'vs:', 'bp', 'hr', 'temp', 'examination', 'exam',
-         'physical', 'inspection', 'palpation', 'auscultation', 'laboratory'
-     ])
-
-     assessment_info = extract_section_info(lines, [
-         'diagnosis', 'impression', 'assessment', 'likely', 'possible',
-         'rule out', 'differential', 'icd', 'condition'
-     ])
-
-     plan_info = extract_section_info(lines, [
-         'plan', 'treatment', 'medication', 'prescribe', 'follow', 'return',
-         'therapy', 'intervention', 'monitoring', 'referral'
-     ])
-
-     # Build comprehensive SOAP note
-     soap_note = build_soap_sections(subjective_info, objective_info, assessment_info, plan_info)
-
-     if include_timestamp:
-         timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-         header = f"""📋 SOAP NOTE (Template-Enhanced)
- 🕐 Timestamp: {timestamp}
- 🔒 Processed locally - HIPAA compliant
- 🏥 Scribbled Docs Medical Assistant
-
- {'='*60}
- """
-         return header + soap_note
-
-     return soap_note
-
- def extract_section_info(lines, keywords):
-     """Extract relevant lines for each SOAP section"""
-     relevant_lines = []
-     for line in lines:
-         if any(keyword in line.lower() for keyword in keywords):
-             relevant_lines.append(line.strip())
-     return relevant_lines
-
- def build_soap_sections(subjective, objective, assessment, plan):
-     """Build formatted SOAP sections"""
-
-     soap = "SUBJECTIVE:\n"
-     if subjective:
-         soap += '\n'.join(f"• {line}" for line in subjective[:5]) # Limit to 5 most relevant
-     else:
-         soap += "• Patient complaints and reported symptoms as documented"
-
-     soap += "\n\nOBJECTIVE:\n"
-     if objective:
-         soap += '\n'.join(f"• {line}" for line in objective[:5])
-     else:
-         soap += "• Physical examination findings and clinical observations as documented"
-
-     soap += "\n\nASSESSMENT:\n"
-     if assessment:
-         soap += '\n'.join(f"• {line}" for line in assessment[:3])
-     else:
-         soap += "• Clinical assessment based on presenting symptoms and examination findings"
-
-     soap += "\n\nPLAN:\n"
-     if plan:
-         soap += '\n'.join(f"• {line}" for line in plan[:5])
-     else:
-         soap += "• Treatment plan and follow-up care as clinically indicated"
-
-     return soap
-
- # OCR Functions (same as before but optimized)
- def initialize_ocr():
-     """Initialize OCR reader for handwritten notes"""
-     try:
-         # Initialize with English and medical text optimization
-         reader = easyocr.Reader(['en'], gpu=torch.cuda.is_available())
-         print("✅ EasyOCR initialized for handwritten medical notes")
-         return reader
-     except Exception as e:
-         print(f"⚠️ EasyOCR initialization failed: {e}")
-         return None
-
- def extract_text_from_image(image, ocr_reader=None):
-     """Enhanced OCR for medical handwriting"""
-     if image is None:
-         return "❌ No image provided"
-
-     try:
-         # Preprocess specifically for medical handwriting
-         processed_img = preprocess_medical_image(image)
-
-         extracted_text = ""
-
-         # Try EasyOCR (better for handwritten text)
-         if ocr_reader is not None:
-             try:
-                 results = ocr_reader.readtext(processed_img, detail=0, paragraph=True)
-                 if results:
-                     extracted_text = ' '.join(results)
-                     if len(extracted_text.strip()) > 10:
-                         return clean_medical_text(extracted_text)
-             except Exception as e:
-                 print(f"EasyOCR failed: {e}")
-
-         # Fallback to Tesseract with medical optimization
-         try:
-             import pytesseract
-
-             # Medical-optimized Tesseract config
-             custom_config = r'--oem 3 --psm 6 -c tessedit_char_whitelist=ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789.,;:()[]{}/-+= '
-
-             tesseract_text = pytesseract.image_to_string(processed_img, config=custom_config)
-
-             if len(tesseract_text.strip()) > 5:
-                 return clean_medical_text(tesseract_text)
-
-         except Exception as e:
-             print(f"Tesseract failed: {e}")
-
-         return "❌ Could not extract text from image. Please ensure the image is clear and try again."
-
-     except Exception as e:
-         return f"❌ Error processing image: {str(e)}"
-
- def preprocess_medical_image(image):
-     """Optimized preprocessing for medical handwriting"""
-     try:
-         img_array = np.array(image)
-
-         # Convert to grayscale
-         if len(img_array.shape) == 3:
-             gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
-         else:
-             gray = img_array
-
-         # Resize for optimal OCR (medical notes are often small)
-         height, width = gray.shape
-         if height < 400 or width < 400:
-             scale_factor = max(400/height, 400/width)
-             new_width = int(width * scale_factor)
-             new_height = int(height * scale_factor)
-             gray = cv2.resize(gray, (new_width, new_height), interpolation=cv2.INTER_CUBIC)
-
-         # Advanced preprocessing for handwritten medical text
-         # 1. Noise reduction
-         denoised = cv2.fastNlMeansDenoising(gray)
-
-         # 2. Contrast enhancement specifically for handwriting
-         clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8,8))
-         enhanced = clahe.apply(denoised)
-
-         # 3. Morphological operations to clean up text
-         kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1,1))
-         cleaned = cv2.morphologyEx(enhanced, cv2.MORPH_CLOSE, kernel)
-
-         # 4. Adaptive thresholding (better for varying lighting)
-         thresh = cv2.adaptiveThreshold(
-             cleaned, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 11, 2
-         )
-
-         return thresh
-
-     except Exception as e:
-         print(f"❌ Image preprocessing error: {e}")
-         return np.array(image)
-
- def clean_medical_text(text):
-     """Clean extracted text with medical context awareness"""
-     # Remove excessive whitespace and empty lines
-     lines = [line.strip() for line in text.split('\n') if line.strip()]
-
-     # Medical text cleaning
-     cleaned_lines = []
-     for line in lines:
-         # Remove obvious OCR artifacts
-         line = line.replace('|', 'l').replace('_', ' ').replace('~', '-')
-
-         # Fix common medical abbreviations that OCR might misread
-         medical_corrections = {
-             'BP': 'BP', 'HR': 'HR', 'RR': 'RR', 'O2': 'O2',
-             'mg': 'mg', 'ml': 'ml', 'cc': 'cc', 'cm': 'cm'
-         }
-
-         for wrong, correct in medical_corrections.items():
-             line = line.replace(wrong.lower(), correct)
-
-         if len(line) > 1: # Filter out single characters
-             cleaned_lines.append(line)
-
-     return '\n'.join(cleaned_lines)
-
- # Enhanced Gradio Interface
- def gradio_generate_soap(medical_notes, uploaded_image, model_data):
-     """Main Gradio interface function"""
-     model, tokenizer = model_data if model_data else (None, None)
-     ocr_reader = getattr(gradio_generate_soap, 'ocr_reader', None)
-
-     text_to_process = medical_notes.strip() if medical_notes else ""
-
-     # Process uploaded image with enhanced OCR
-     if uploaded_image is not None:
-         try:
-             print("🔍 Extracting text from medical image...")
-             extracted_text = extract_text_from_image(uploaded_image, ocr_reader)
-
-             if not extracted_text.startswith("❌"):
-                 if not text_to_process:
-                     text_to_process = f"--- Extracted from uploaded image ---\n{extracted_text}"
-                 else:
-                     text_to_process = f"{text_to_process}\n\n--- Additional text from image ---\n{extracted_text}"
-             else:
-                 return extracted_text
-
-         except Exception as e:
-             return f"❌ Error processing image: {str(e)}"
-
-     if not text_to_process:
-         return "❌ Please enter medical notes manually or upload an image with medical text"
-
-     # Generate SOAP note using Gemma 3n
-     try:
-         return generate_soap_note_gemma(text_to_process, model, tokenizer)
-     except Exception as e:
-         return f"❌ Error generating SOAP note: {str(e)}"
-
- # Example medical notes for testing
- medical_examples = {
-     'chest_pain': """Patient: John Smith, 45yo M
- CC: Chest pain x 2 hours
- HPI: Sudden onset sharp substernal chest pain 7/10, radiating to L arm. Associated SOB, diaphoresis. No N/V.
- PMH: HTN, no CAD
- VS: BP 150/90, HR 110, RR 22, O2 96% RA
- PE: Anxious, diaphoretic. RRR, no murmur. CTAB. No edema.
- A: Acute chest pain, r/o MI
- P: EKG, troponins, CXR, ASA 325mg, monitor""",
-
-     'diabetes': """Patient: Maria Garcia, 52yo F
- CC: Increased thirst, urination x 3 weeks
- HPI: Polyuria, polydipsia, 10lb weight loss. FH DM. No fever, abd pain.
- VS: BP 140/85, HR 88, BMI 28
- PE: Mild dehydration, dry MM. RRR. No diabetic foot changes.
- Labs: Random glucose 280, HbA1c pending
- A: New onset DM Type 2
- P: HbA1c, CMP, diabetic education, metformin, f/u 2 weeks""",
-
-     'pediatric': """Patient: Emma Thompson, 8yo F
- CC: Fever, sore throat x 2 days
- HPI: Fever 102F, sore throat, odynophagia, decreased appetite. No cough/rhinorrhea.
- VS: T 101.8F, HR 110, RR 20, O2 99%
- PE: Alert, mildly ill. Throat erythematous w/ tonsillar exudate. Anterior cervical LAD.
- A: Strep pharyngitis (probable)
- P: Rapid strep, throat culture, amoxicillin if +, supportive care, RTC PRN"""
- }
-
- # Initialize everything
- def initialize_app():
-     """Initialize the complete application"""
-     print("🚀 Initializing Scribbled Docs SOAP Generator...")
-
-     # Setup device
-     device = setup_device()
-
-     # Load model
-     model, tokenizer = load_unsloth_gemma_model(device)
-
-     # Initialize OCR
-     ocr_reader = initialize_ocr()
-     gradio_generate_soap.ocr_reader = ocr_reader
-
-     return model, tokenizer
-
- # Create the main Gradio interface
- def create_interface(model, tokenizer):
-     """Create the main Gradio interface"""
-
-     interface = gr.Interface(
-         fn=lambda notes, image: gradio_generate_soap(notes, image, (model, tokenizer)),
-         inputs=[
-             gr.Textbox(
-                 lines=8,
-                 placeholder="Enter medical notes here...\n\nExample:\nPatient: John Doe, 45yo M\nCC: Chest pain\nVS: BP 140/90, HR 88\n...",
-                 label="📝 Medical Notes (Manual Entry)",
-                 info="Enter unstructured medical notes or upload an image below"
-             ),
-             gr.Image(
-                 type="pil",
-                 label="📷 Upload Medical Image (Handwritten/Typed Notes)",
-                 sources=["upload", "webcam"],
-                 info="Upload PNG/JPG images of medical notes - handwritten or typed"
-             )
-         ],
-         outputs=[
-             gr.Textbox(
-                 lines=20,
-                 label="📋 Generated SOAP Note",
-                 show_copy_button=True,
-                 info="Professional SOAP note generated from your input"
-             )
-         ],
-         title="🏥 Scribbled Docs - Medical SOAP Note Generator",
-         description="""
- **Transform medical notes into professional SOAP documentation using Gemma 3n AI**
-
- 🔒 **100% Offline & HIPAA Compliant** - All processing happens locally on your device
- 🤖 **Powered by Unsloth-optimized Gemma 3n** - 4-bit quantized for efficiency
- 📝 **Supports handwritten & typed notes** - Advanced OCR for medical handwriting
-
- **Instructions:**
- 1. Enter medical notes manually OR upload an image
- 2. Click Submit to generate a structured SOAP note
- 3. Copy the result for use in your medical records
-
- **Perfect for:** Emergency medicine, family practice, internal medicine, pediatrics
- """,
-         examples=[
-             [medical_examples['chest_pain'], None],
-             [medical_examples['diabetes'], None],
-             [medical_examples['pediatric'], None]
-         ],
-         theme=gr.themes.Soft(
-             primary_hue="blue",
-             secondary_hue="green"
-         ),
-         allow_flagging="never",
-         analytics_enabled=False
-     )
-
-     return interface
-
- # Main execution
- if __name__ == "__main__":
-     try:
-         # Initialize app
-         model, tokenizer = initialize_app()
-
-         # Create and launch interface
-         interface = create_interface(model, tokenizer)
-
-         print("\n🎯 Scribbled Docs SOAP Generator Ready!")
-         print("📱 Features:")
-         print(" ✅ Offline processing (HIPAA compliant)")
-         print(" ✅ Unsloth-optimized Gemma 3n model")
-         print(" ✅ Handwritten note OCR")
-         print(" ✅ Professional SOAP formatting")
-         print(" ✅ Medical terminology aware")
-
-         # Launch interface
-         interface.launch(
-             share=True, # Creates public link
-             server_port=7860,
-             show_error=True,
-             quiet=False
-         )
-
-     except Exception as e:
-         print(f"❌ Error launching application: {e}")
-         print("💡 Make sure you have installed: pip install unsloth gradio easyocr opencv-python")
 
+ import gradio as gr
+ import torch
+ import numpy as np
+ from PIL import Image
+ import time
+ import io
+ import subprocess
+ import sys
+
+ # Install required packages
+ def install_packages():
+     packages = [
+         "transformers",
+         "accelerate",
+         "timm",
+         "easyocr"
+     ]
+     for package in packages:
+         try:
+             subprocess.check_call([sys.executable, "-m", "pip", "install", package])
+         except:
+             print(f"Warning: Could not install {package}")
+
+ # Install packages at startup
+ install_packages()
+
+ from transformers import AutoProcessor, AutoModelForImageTextToText, AutoConfig
+
+ # Global variables for model
+ processor = None
+ model = None
+ config = None
+ ocr_reader = None
+
+ def load_model():
+     """Load the Gemma 3n model"""
+     global processor, model, config, ocr_reader
+
+     try:
+         print("🚀 Loading Gemma 3n model...")
+         GEMMA_PATH = "google/gemma-3n-e2b-it"
+
+         # Load configuration
+         config = AutoConfig.from_pretrained(GEMMA_PATH, trust_remote_code=True)
+         print("✅ Config loaded")
+
+         # Load processor
+         processor = AutoProcessor.from_pretrained(GEMMA_PATH, trust_remote_code=True)
+         print("✅ Processor loaded")
+
+         # Load model
+         model = AutoModelForImageTextToText.from_pretrained(
+             GEMMA_PATH,
+             config=config,
+             torch_dtype="auto",
+             device_map="auto",
+             trust_remote_code=True
+         )
+         print("✅ Model loaded successfully!")
+
+         # Set up compilation fix
+         import torch._dynamo
+         torch._dynamo.config.suppress_errors = True
+
+         # Initialize OCR
+         try:
+             import easyocr
+             ocr_reader = easyocr.Reader(['en'], gpu=False, verbose=False)
+             print("✅ EasyOCR initialized")
+         except Exception as e:
+             print(f"⚠️ EasyOCR not available: {e}")
+             ocr_reader = None
+
+         return True
+
+     except Exception as e:
+         print(f"❌ Model loading failed: {e}")
+         return False
+
+ def generate_soap_note(text):
+     """Generate SOAP note using Gemma 3n"""
+     if model is None or processor is None:
+         return "❌ Model not loaded. Please wait for initialization."
+
+     soap_prompt = f"""You are a medical AI assistant. Convert the following medical notes into a properly formatted SOAP note.
+
+ Medical notes:
+ {text}
+
+ Please format as:
+ S - SUBJECTIVE: (chief complaint, history of present illness, past medical history, medications, allergies)
+ O - OBJECTIVE: (vital signs, physical examination findings)
+ A - ASSESSMENT: (diagnosis/clinical impression)
+ P - PLAN: (treatment plan, follow-up instructions)
+
+ Generate a complete, professional SOAP note:"""
+
+     messages = [{
+         "role": "system",
+         "content": [{"type": "text", "text": "You are an expert medical AI assistant specialized in creating SOAP notes from medical documentation."}]
+     }, {
+         "role": "user",
+         "content": [{"type": "text", "text": soap_prompt}]
+     }]
+
+     try:
+         inputs = processor.apply_chat_template(
+             messages,
+             add_generation_prompt=True,
+             tokenize=True,
+             return_dict=True,
+             return_tensors="pt"
+         ).to(model.device)
+
+         input_len = inputs["input_ids"].shape[-1]
+
+         with torch.no_grad():
+             outputs = model.generate(
+                 **inputs,
+                 max_new_tokens=400,
+                 do_sample=True,
+                 temperature=0.1,
+                 top_p=0.95,
+                 pad_token_id=processor.tokenizer.eos_token_id,
+                 disable_compile=True
+             )
+
+         response = processor.batch_decode(
+             outputs[:, input_len:],
+             skip_special_tokens=True
+         )[0].strip()
+
+         return response
+
+     except Exception as e:
+         return f"❌ SOAP generation failed: {str(e)}"
+
+ def extract_text_from_image(image):
+     """Extract text using EasyOCR"""
+     if ocr_reader is None:
+         return "❌ OCR not available"
+
+     try:
+         if hasattr(image, 'convert'):
+             image = image.convert('RGB')
+         img_array = np.array(image)
+
+         results = ocr_reader.readtext(img_array, detail=0, paragraph=True)
+         if results:
+             return ' '.join(results).strip()
+         else:
+             return "❌ No text detected in image"
+
+     except Exception as e:
+         return f"❌ OCR failed: {str(e)}"
+
+ def process_medical_input(image, text):
+     """Main processing function for the Gradio interface"""
+
+     if image is not None and text.strip():
+         return "⚠️ Please provide either an image OR text, not both.", ""
+
+     if image is not None:
+         # Process image
+         print("🔍 Extracting text from image...")
+         extracted_text = extract_text_from_image(image)
+
+         if extracted_text.startswith('❌'):
+             return extracted_text, ""
+
+         print("🤖 Generating SOAP note...")
+         soap_note = generate_soap_note(extracted_text)
+
+         return extracted_text, soap_note
+
+     elif text.strip():
+         # Process text directly
+         print("🤖 Generating SOAP note from text...")
+         soap_note = generate_soap_note(text.strip())
+         return text.strip(), soap_note
+
+     else:
+         return "❌ Please provide either an image or text input.", ""
+
+ def create_demo():
+     """Create the Gradio demo interface"""
+
+     # Sample text for demonstration
+     sample_text = """Patient: John Smith, 45yo male
+ CC: Chest pain
+ Vitals: BP 140/90, HR 88, RR 16, O2 98%, Temp 98.6F
+ HPI: Patient reports crushing chest pain x 2 hours, radiating to left arm
+ PMH: HTN, DM Type 2
+ Current Meds: Lisinopril 10mg daily, Metformin 500mg BID
+ PE: Diaphoretic, anxious appearance
+ EKG: ST elevation in leads II, III, aVF"""
+
+     with gr.Blocks(title="Medical OCR SOAP Generator", theme=gr.themes.Soft()) as demo:
+
+         gr.Markdown("""
+ # 🏥 Medical OCR SOAP Generator
+ ### Powered by Gemma 3n - Convert handwritten medical notes to professional SOAP format
+
+ **Instructions:**
+ - **Option 1:** Upload an image of handwritten medical notes
+ - **Option 2:** Enter medical text directly
+ - The system will generate a properly formatted SOAP note
+
+ ⚠️ **Note:** First generation may take ~60-90 seconds as the model loads
+         """)
+
+         with gr.Row():
+             with gr.Column():
+                 image_input = gr.Image(
+                     type="pil",
+                     label="📷 Upload Medical Image",
+                     height=300
+                 )
+
+                 text_input = gr.Textbox(
+                     label="📝 Or Enter Medical Text",
+                     placeholder=sample_text,
+                     lines=8,
+                     max_lines=15
+                 )
+
+                 submit_btn = gr.Button(
+                     "Generate SOAP Note",
+                     variant="primary",
+                     size="lg"
+                 )
+
+             with gr.Column():
+                 extracted_output = gr.Textbox(
+                     label="📋 Extracted/Input Text",
+                     lines=6,
+                     max_lines=10
+                 )
+
+                 soap_output = gr.Textbox(
+                     label="🏥 Generated SOAP Note",
+                     lines=12,
+                     max_lines=20
+                 )
+
+         # Example section
+         gr.Markdown("### 📋 Quick Test Example")
+         example_btn = gr.Button("Try Sample Medical Text", variant="secondary")
+
+         def load_example():
+             return sample_text, None
+
+         example_btn.click(
+             load_example,
+             outputs=[text_input, image_input]
+         )
+
+         # Process function
+         submit_btn.click(
+             process_medical_input,
+             inputs=[image_input, text_input],
+             outputs=[extracted_output, soap_output]
+         )
+
+         gr.Markdown("""
+ ---
+ **About:** This application uses Google's Gemma 3n model for medical text understanding and EasyOCR for handwriting recognition.
+ All processing is done locally for HIPAA compliance.
+
+ **Competition Entry:** Medical AI Innovation Challenge 2024
+         """)
+
+     return demo
+
+ # Initialize the application
+ if __name__ == "__main__":
+     print("🚀 Starting Medical OCR SOAP Generator...")
+
+     # Load model
+     model_loaded = load_model()
+
+     if model_loaded:
+         print("✅ All systems ready!")
+         demo = create_demo()
+         demo.launch(
+             share=True,
+             server_name="0.0.0.0",
+             server_port=7860
+         )
+     else:
+         print("❌ Failed to load model. Creating fallback demo...")
+
+         def fallback_demo(image, text):
+             return "❌ Model loading failed. Please check the logs.", "❌ Model not available."
+
+         demo = gr.Interface(
+             fn=fallback_demo,
+             inputs=[
+                 gr.Image(type="pil", label="Upload Medical Image"),
+                 gr.Textbox(label="Enter Medical Text", lines=5)
+             ],
+             outputs=[
+                 gr.Textbox(label="Status"),
+                 gr.Textbox(label="Error Message")
+             ],
+             title="❌ Medical OCR - Model Loading Failed"
+         )
+
+         demo.launch(share=True)
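
A minimal headless smoke test of the version of app.py added in this commit could look like the sketch below. It is illustrative only: it assumes the new app.py is importable from the working directory, that its dependencies (torch, gradio, transformers, accelerate, timm, easyocr) are already installed, and that the environment can download google/gemma-3n-e2b-it from Hugging Face. The file name smoke_test.py and the sample note are hypothetical.

# smoke_test.py - illustrative sketch; file name and sample note are made up.
# Importing app runs its top-level install_packages() call but not the __main__ block.
import app

# Hypothetical unstructured note used only to exercise the text path.
sample = (
    "Patient: Jane Doe, 60yo F\n"
    "CC: Shortness of breath x 1 day\n"
    "VS: BP 135/85, HR 95, RR 22, O2 93% RA"
)

if app.load_model():                        # loads google/gemma-3n-e2b-it and EasyOCR
    print(app.generate_soap_note(sample))   # returns the generated SOAP note as a string
else:
    print("Model failed to load; check memory and Hugging Face access.")

Because generate_soap_note reads the module-level model and processor that load_model sets via global, calling the two functions in this order is enough to exercise the text pipeline without launching the Gradio UI.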