scdrand23 committed
Commit b816c7b · 1 Parent(s): 3bf392c

integrated biomedllama

Files changed (1)
  1. app.py  +146 -92
app.py CHANGED
@@ -16,31 +16,15 @@ from tqdm import tqdm
 import sys
 from pathlib import Path
 from huggingface_hub import login
-# from dotenv import load_dotenv
 from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
 
-# For Hugging Face Spaces, secrets are automatically loaded as environment variables
+
 token = os.getenv("HF_TOKEN")
 if token:
     login(token=token)
-# Clear Hugging Face cache
-# cache_dirs = [
-#     "/home/user/.cache/huggingface/",
-#     "/home/user/.cache/torch/",
-#     "/home/user/.cache/pip/"
-# ]
-
-# for cache_dir in cache_dirs:
-#     if os.path.exists(cache_dir):
-#         print(f"Clearing cache: {cache_dir}")
-#         shutil.rmtree(cache_dir, ignore_errors=True)
-# Add the current directory to Python path
 current_dir = Path(__file__).parent
 sys.path.append(str(current_dir))
-# sys.path.append("./BiomedParse/")
-# BIOMEDPARSE_PATH = Path(__file__).parent / "BiomedParse"
-# sys.path.append(str(BIOMEDPARSE_PATH))
-# sys.path.append(str(BIOMEDPARSE_PATH / "BiomedParse"))  # Add the inner BiomedParse directory
+
 from modeling.BaseModel import BaseModel
 from modeling import build_model
 from utilities.arguments import load_opt_from_config_files
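
Housekeeping note: the dead code above (the dotenv import, the cache-clearing block, the sys.path experiments) is deleted outright rather than left commented. As the removed comment said, Spaces injects the HF_TOKEN secret as an environment variable; a local run has to supply it before launch. A minimal stand-in, with a placeholder value:

    # Local stand-in for the Space secret (placeholder, not a real token):
    import os
    os.environ.setdefault("HF_TOKEN", "hf_xxxxxxxxxxxx")
    # app.py then reads it via os.getenv("HF_TOKEN") and calls login(token=token).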
@@ -51,7 +35,7 @@ from inference_utils.processing_utils import read_rgb
 
 import spaces
 
-# breakpoint()
+
 MARKDOWN = """
 <div align="center" style="padding: 20px 0;">
     <h1 style="font-size: 3em; margin: 0;">
@@ -154,6 +138,68 @@ MODALITY_PROMPTS = {
     "OCT": ["edema"] }
 
 
+def extract_modality_from_llm(llm_output):
+    """Extract modality from LLM output and map it to BIOMEDPARSE_MODES"""
+    llm_output = llm_output.lower()
+
+    # Direct modality mapping
+    modality_keywords = {
+        'ct': {
+            'abdomen': 'CT-Abdomen',
+            'chest': 'CT-Chest',
+            'liver': 'CT-Liver'
+        },
+        'mri': {
+            'abdomen': 'MRI-Abdomen',
+            'cardiac': 'MRI-Cardiac',
+            'heart': 'MRI-Cardiac',
+            'flair': 'MRI-FLAIR-Brain',
+            't1': 'MRI-T1-Gd-Brain',
+            'contrast': 'MRI-T1-Gd-Brain',
+            'brain': 'MRI-FLAIR-Brain'  # default to FLAIR if just "brain" is mentioned
+        },
+        'x-ray': {'chest': 'X-Ray-Chest'},
+        'ultrasound': {'cardiac': 'Ultrasound-Cardiac', 'heart': 'Ultrasound-Cardiac'},
+        'endoscopy': {'': 'Endoscopy'},
+        'fundus': {'': 'Fundus'},
+        'dermoscopy': {'': 'Dermoscopy'},
+        'oct': {'': 'OCT'},
+        'pathology': {'': 'Pathology'}
+    }
+
+    for modality, subtypes in modality_keywords.items():
+        if modality in llm_output:
+            # For modalities with subtypes, try to find the specific subtype
+            if subtypes:
+                for keyword, specific_modality in subtypes.items():
+                    if not keyword or keyword in llm_output:
+                        return specific_modality
+            # For modalities without subtypes, return the direct mapping
+            return next(iter(subtypes.values()))
+
+    return None
+
+def extract_clinical_findings(llm_output, modality):
+    """Extract relevant clinical findings that match available anatomical sites in BIOMEDPARSE_MODES"""
+    available_sites = BIOMEDPARSE_MODES.get(modality, [])
+    findings = []
+
+    # Convert sites to lowercase for case-insensitive matching
+    sites_lower = {site.lower(): site for site in available_sites}
+
+    # Look for each available site in the LLM output
+    for site_lower, original_site in sites_lower.items():
+        if site_lower in llm_output.lower():
+            findings.append(original_site)
+
+    # Add additional findings from MODALITY_PROMPTS if available
+    if modality in MODALITY_PROMPTS:
+        for prompt in MODALITY_PROMPTS[modality]:
+            if prompt.lower() in llm_output.lower() and prompt not in findings:
+                findings.append(prompt)
+
+    return findings
+
 def on_mode_dropdown_change(selected_mode):
     if selected_mode in IMAGE_INFERENCE_MODES:
         # Show modality dropdown and hide other inputs initially
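
Reviewer note on the two new helpers: both rely on plain substring matching, so a quick sanity check looks like this (the sample strings are invented; BIOMEDPARSE_MODES and MODALITY_PROMPTS are the dictionaries defined earlier in app.py):

    sample = "This is a chest CT. There is a small nodule in the left lung."
    extract_modality_from_llm(sample)              # -> 'CT-Chest'
    extract_clinical_findings(sample, 'CT-Chest')
    # -> the site names from BIOMEDPARSE_MODES['CT-Chest'] that occur verbatim
    #    in the text, plus any matching MODALITY_PROMPTS['CT-Chest'] entries

One caveat worth flagging: 'ct' is a substring of 'oct', and dict iteration tries 'ct' first, so an OCT description that mentions no CT subtype keyword falls through to the first CT value:

    extract_modality_from_llm("an oct scan of the retina")  # -> 'CT-Abdomen', not 'OCT'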
@@ -231,72 +277,68 @@ def update_example_prompts(modality):
 @spaces.GPU
 @torch.inference_mode()
 @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
-def process_image(image_path, text_prompts, modality):
+def process_image(image_path, user_prompt, modality=None):
     try:
-        # Input validation
         if not image_path:
             raise ValueError("Please upload an image")
-        if not text_prompts or text_prompts.strip() == "":
-            raise ValueError("Please enter prompts for analysis")
-        if not modality:
-            raise ValueError("Please select a modality")
-
-        # Original BiomedParse processing
+
         image = read_rgb(image_path)
-        text_prompts = [prompt.strip() for prompt in text_prompts.split(',')]
-        pred_masks = interactive_infer_image(model, Image.fromarray(image), text_prompts)
+        pil_image = Image.fromarray(image)
+
+        # Step 1: Get LLM analysis
+        question = f"Analyze this medical image considering the following context: {user_prompt}. Include modality, anatomical structures, and any abnormalities."
+        msgs = [{'role': 'user', 'content': [pil_image, question]}]
 
-        # Prepare outputs
+        llm_response = ""
+        for new_text in llm_model.chat(
+            image=pil_image,
+            msgs=msgs,
+            tokenizer=llm_tokenizer,
+            sampling=True,
+            temperature=0.95,
+            stream=True
+        ):
+            llm_response += new_text
+
+        # Step 2: Extract modality from LLM output
+        detected_modality = extract_modality_from_llm(llm_response)
+        if not detected_modality:
+            raise ValueError("Could not determine image modality from LLM output")
+
+        # Step 3: Extract relevant clinical findings
+        clinical_findings = extract_clinical_findings(llm_response, detected_modality)
+
+        # Step 4: Generate masks for each finding
         results = []
         analysis_results = []
+        colors = [(255,0,0), (0,255,0), (0,0,255), (255,255,0), (255,0,255)]  # Different colors for different findings
 
-        # Process with BiomedParse
-        for i, prompt in enumerate(text_prompts):
-            p_value = check_mask_stats(image, pred_masks[i] * 255, modality, prompt)
-            analysis_results.append(f"P-value for '{prompt}' ({modality}): {p_value:.4f}")
+        for idx, finding in enumerate(clinical_findings):
+            pred_mask = interactive_infer_image(model, pil_image, [finding])[0]
+            p_value = check_mask_stats(image, pred_mask * 255, detected_modality, finding)
+            analysis_results.append(f"P-value for '{finding}' ({detected_modality}): {p_value:.4f}")
 
+            # Create colored overlay
             overlay_image = image.copy()
-            overlay_image[pred_masks[i] > 0.5] = [255, 0, 0]
+            color = colors[idx % len(colors)]
+            overlay_image[pred_mask > 0.5] = color
             results.append(overlay_image)
 
-        # Process with LLM only if available
-        if llm_model is not None and llm_tokenizer is not None:
-            print("LLM model and tokenizer are available")
-            try:
-                pil_image = Image.fromarray(image)
-                question = 'Give the modality, organ, analysis, abnormalities (if any), treatment (if abnormalities are present)?'
-                msgs = [{'role': 'user', 'content': [pil_image, question]}]
-
-                print("Starting LLM inference...")
-                llm_response = ""
-                for new_text in llm_model.chat(
-                    image=pil_image,
-                    msgs=msgs,
-                    tokenizer=llm_tokenizer,
-                    sampling=True,
-                    temperature=0.95,
-                    stream=True
-                ):
-                    llm_response += new_text
-                print(f"LLM generated response: {llm_response}")
-
-                # Make the combined analysis more visible
-                combined_analysis = "\n\n" + "="*50 + "\n"
-                combined_analysis += "BiomedParse Analysis:\n"
-                combined_analysis += "\n".join(analysis_results)
-                combined_analysis += "\n\n" + "="*50 + "\n"
-                combined_analysis += "LLM Analysis:\n"
-                combined_analysis += llm_response
-                combined_analysis += "\n" + "="*50
-
-            except Exception as e:
-                print(f"LLM analysis failed with error: {str(e)}")
-                combined_analysis = "\n".join(analysis_results)
-        else:
-            print("LLM model or tokenizer is not available")
-            combined_analysis = "\n".join(analysis_results)
+        # Update LLM response with color references
+        enhanced_response = llm_response + "\n\nSegmentation Results:\n"
+        for idx, finding in enumerate(clinical_findings):
+            color_name = ["red", "green", "blue", "yellow", "magenta"][idx % len(colors)]
+            enhanced_response += f"- {finding} (shown in {color_name})\n"
 
-        return results, combined_analysis
+        combined_analysis = "\n\n" + "="*50 + "\n"
+        combined_analysis += "BiomedParse Analysis:\n"
+        combined_analysis += "\n".join(analysis_results)
+        combined_analysis += "\n\n" + "="*50 + "\n"
+        combined_analysis += "Enhanced LLM Analysis:\n"
+        combined_analysis += enhanced_response
+        combined_analysis += "\n" + "="*50
+
+        return results, combined_analysis, detected_modality
 
     except Exception as e:
         error_msg = f"⚠️ An error occurred: {str(e)}"
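
The rewritten process_image chains four steps: accumulate the streamed llm_model.chat response, infer the modality from that text, harvest findings, then segment each finding. A stubbed trace of the control flow (the stub text is invented; model, llm_model and llm_tokenizer are module-level globals set up elsewhere in app.py):

    def fake_chat():  # stands in for the llm_model.chat(..., stream=True) generator
        yield "This appears to be a chest x-ray. "
        yield "There is an opacity suggestive of pneumonia."

    llm_response = "".join(fake_chat())                           # Step 1
    detected = extract_modality_from_llm(llm_response)            # Step 2 -> 'X-Ray-Chest'
    findings = extract_clinical_findings(llm_response, detected)  # Step 3
    # Step 4 runs interactive_infer_image once per finding, so latency grows
    # linearly with the number of matched findings; each mask gets its own color.

Two behavioral changes are worth noting: the old "if llm_model is not None" guard is gone, so a failed LLM load now surfaces as the ⚠️ error message instead of falling back to BiomedParse-only output, and the function returns a third value, detected_modality, to feed the new read-only textbox in the UI below.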
@@ -309,33 +351,45 @@ with gr.Blocks() as demo:
     with gr.Row():
         with gr.Column():
             image_input = gr.Image(type="filepath", label="Input Image")
-            prompts_input = gr.Textbox(
-                lines=2,
-                placeholder="Enter prompts separated by commas...",
-                label="Prompts"
+            prompt_input = gr.Textbox(
+                lines=4,
+                placeholder="Ask any question about the medical image...",
+                label="Question/Prompt"
             )
-            modality_dropdown = gr.Dropdown(
-                choices=list(BIOMEDPARSE_MODES.keys()),
-                value=list(BIOMEDPARSE_MODES.keys())[0],
-                label="Modality"
+            detected_modality = gr.Textbox(
+                label="Detected Modality",
+                interactive=False,
+                visible=True
             )
-            submit_btn = gr.Button("Submit")
+            submit_btn = gr.Button("Analyze")
+
         with gr.Column():
-            output_gallery = gr.Gallery(label="Findings")
-            pvalue_output = gr.Textbox(
-                label="Results",
+            output_gallery = gr.Gallery(
+                label="Segmentation Results",
+                show_label=True,
+                columns=[2],
+                height="auto"
+            )
+            analysis_output = gr.Textbox(
+                label="Analysis",
                 interactive=False,
-                show_label=True
+                show_label=True,
+                lines=10
             )
-    with gr.Accordion("Example Prompts by Modality", open=False):
-        for modality, prompts in MODALITY_PROMPTS.items():
-            prompt_str = ", ".join(prompts)
-            gr.Markdown(f"**{modality}**: {prompt_str}")
-    # Add error handling for the submit button
+
+    # Examples section
+    gr.Examples(
+        examples=IMAGE_PROCESSING_EXAMPLES,
+        inputs=[image_input, prompt_input],
+        outputs=[output_gallery, analysis_output, detected_modality],
+        cache_examples=True,
+    )
+
+    # Connect the submit button to the process_image function
     submit_btn.click(
        fn=process_image,
-        inputs=[image_input, prompts_input, modality_dropdown],
-        outputs=[output_gallery, pvalue_output],
+        inputs=[image_input, prompt_input],
+        outputs=[output_gallery, analysis_output, detected_modality],
        api_name="process"
    )
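IMAGE_PROCESSING_EXAMPLES is wired into gr.Examples above but defined outside this diff; for these inputs it must be a list of [image, prompt] rows, along these lines (paths and prompts hypothetical):

    # Hypothetical shape only; the real constant lives elsewhere in app.py:
    IMAGE_PROCESSING_EXAMPLES = [
        ["examples/chest_xray.png", "Is there any sign of pneumonia?"],
        ["examples/brain_flair.png", "Describe any visible abnormalities."],
    ]

Worth double-checking: cache_examples=True needs the function that produces the cached outputs (fn=process_image here), and as committed gr.Examples receives inputs and outputs but no fn, which Gradio typically rejects when caching is enabled. If fn is supplied, Gradio executes it on every example row at startup, so a broken example path or a failed model load surfaces at build time rather than on first click.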