cvips committed
Commit e13af0d · 1 Parent(s): e7165f7

bomedllamv2 integrated

Files changed (1)
  1. app.py +37 -50
app.py CHANGED
@@ -9,16 +9,13 @@ import cv2
 import gradio as gr
 import numpy as np
 import spaces
-# import supervision as sv
 import torch
 from PIL import Image
 from tqdm import tqdm
-import sys
 from pathlib import Path
 from huggingface_hub import login
 from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
 
-
 token = os.getenv("HF_TOKEN")
 if token:
     login(token=token)
@@ -135,14 +132,14 @@ MODALITY_PROMPTS = {
     "Endoscopy": ["neoplastic polyp", "polyp", "non-neoplastic polyp"],
     "Fundus": ["optic cup", "optic disc"],
     "Dermoscopy": ["lesion", "melanoma"],
-    "OCT": ["edema"] }
+    "OCT": ["edema"]
+}
 
 
 def extract_modality_from_llm(llm_output):
     """Extract modality from LLM output and map it to BIOMEDPARSE_MODES"""
     llm_output = llm_output.lower()
 
-    # Direct modality mapping
     modality_keywords = {
         'ct': {
             'abdomen': 'CT-Abdomen',
@@ -156,7 +153,7 @@ def extract_modality_from_llm(llm_output):
             'flair': 'MRI-FLAIR-Brain',
             't1': 'MRI-T1-Gd-Brain',
             'contrast': 'MRI-T1-Gd-Brain',
-            'brain': 'MRI-FLAIR-Brain' # default to FLAIR if just "brain" is mentioned
+            'brain': 'MRI-FLAIR-Brain'
         },
         'x-ray': {'chest': 'X-Ray-Chest'},
         'ultrasound': {'cardiac': 'Ultrasound-Cardiac', 'heart': 'Ultrasound-Cardiac'},
@@ -169,12 +166,9 @@ def extract_modality_from_llm(llm_output):
 
     for modality, subtypes in modality_keywords.items():
         if modality in llm_output:
-            # For modalities with subtypes, try to find the specific subtype
-            if subtypes:
-                for keyword, specific_modality in subtypes.items():
-                    if not keyword or keyword in llm_output:
-                        return specific_modality
-            # For modalities without subtypes, return the direct mapping
+            for keyword, specific_modality in subtypes.items():
+                if not keyword or keyword in llm_output:
+                    return specific_modality
             return next(iter(subtypes.values()))
 
     return None
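For reference, a minimal sketch of how the reworked lookup above resolves an LLM description to a BiomedParse mode; lookup is an illustrative stand-in for extract_modality_from_llm's core logic, and the mapping is a subset of the modality_keywords shown in the hunk.

# Minimal sketch of the reworked modality lookup (subset of the real mapping);
# "lookup" is an illustrative name, not the function defined in app.py.
modality_keywords = {
    'mri': {
        'flair': 'MRI-FLAIR-Brain',
        't1': 'MRI-T1-Gd-Brain',
        'contrast': 'MRI-T1-Gd-Brain',
        'brain': 'MRI-FLAIR-Brain',
    },
    'x-ray': {'chest': 'X-Ray-Chest'},
}

def lookup(llm_output):
    llm_output = llm_output.lower()
    for modality, subtypes in modality_keywords.items():
        if modality in llm_output:
            for keyword, specific_modality in subtypes.items():
                if not keyword or keyword in llm_output:
                    return specific_modality
            # No subtype keyword matched: fall back to the first mapping.
            return next(iter(subtypes.values()))
    return None

print(lookup("Axial MRI FLAIR of the brain"))  # MRI-FLAIR-Brain
print(lookup("MRI of the brain"))              # MRI-FLAIR-Brain (via 'brain')
print(lookup("Chest X-ray, PA view"))          # X-Ray-Chest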
@@ -202,7 +196,6 @@ def extract_clinical_findings(llm_output, modality):
 
 def on_mode_dropdown_change(selected_mode):
     if selected_mode in IMAGE_INFERENCE_MODES:
-        # Show modality dropdown and hide other inputs initially
         return [
             gr.Dropdown(visible=True, choices=list(BIOMEDPARSE_MODES.keys()), label="Modality"),
             gr.Dropdown(visible=True, label="Anatomical Site"),
@@ -210,7 +203,6 @@ def on_mode_dropdown_change(selected_mode):
             gr.Textbox(visible=False)
         ]
     else:
-        # Original behavior for other modes
         return [
             gr.Dropdown(visible=False),
             gr.Dropdown(visible=False),
@@ -223,7 +215,6 @@ def on_modality_change(modality):
         return gr.Dropdown(choices=BIOMEDPARSE_MODES[modality], visible=True)
     return gr.Dropdown(visible=False)
 
-
 def initialize_model():
     opt = load_opt_from_config_files(["configs/biomedparse_inference.yaml"])
     pretrained_pth = 'hf_hub:microsoft/BiomedParse'
@@ -238,7 +229,6 @@ def initialize_model():
 def initialize_llm():
     try:
         print("Starting LLM initialization...")
-        # Add quantization config
         quantization_config = BitsAndBytesConfig(
             load_in_4bit=True,
             bnb_4bit_compute_dtype=torch.float16
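The 4-bit config above is the only part of initialize_llm visible in this hunk. Below is a minimal sketch of how such a BitsAndBytesConfig is typically passed to transformers; the checkpoint id is a placeholder, not the model this Space actually loads.

# Sketch only: 4-bit loading with the config shown above; the model id below
# is a placeholder, not the checkpoint used by this app.
import torch
from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16
)
model_id = "some-org/some-multimodal-chat-model"  # placeholder checkpoint
llm_model = AutoModel.from_pretrained(
    model_id,
    trust_remote_code=True,
    quantization_config=quantization_config,
    device_map="auto"
)
llm_tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)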
@@ -273,7 +263,6 @@ def update_example_prompts(modality):
         return f"Example prompts for {modality}:\n" + ", ".join(examples)
     return ""
 
-# Utility functions
 @spaces.GPU
 @torch.inference_mode()
 @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
@@ -285,46 +274,55 @@ def process_image(image_path, user_prompt, modality=None):
         image = read_rgb(image_path)
         pil_image = Image.fromarray(image)
 
-        # Step 1: Get LLM analysis
-        question = f"Analyze this medical image considering the following context: {user_prompt}. Include modality, anatomical structures, and any abnormalities."
+        question = (
+            f"Analyze this medical image considering the following context: {user_prompt}. "
+            "Include modality, anatomical structures, and any abnormalities."
+        )
         msgs = [{'role': 'user', 'content': [pil_image, question]}]
 
         llm_response = ""
-        for new_text in llm_model.chat(
-            image=pil_image,
-            msgs=msgs,
-            tokenizer=llm_tokenizer,
-            sampling=True,
-            temperature=0.95,
-            stream=True
-        ):
-            llm_response += new_text
-
-        # Step 2: Extract modality from LLM output
+        if llm_model and llm_tokenizer:
+            for new_text in llm_model.chat(
+                image=pil_image,
+                msgs=msgs,
+                tokenizer=llm_tokenizer,
+                sampling=True,
+                temperature=0.95,
+                stream=True
+            ):
+                llm_response += new_text
+        else:
+            llm_response = "LLM not available. Please check LLM initialization logs."
+
         detected_modality = extract_modality_from_llm(llm_response)
         if not detected_modality:
-            raise ValueError("Could not determine image modality from LLM output")
-
-        # Step 3: Extract relevant clinical findings
+            # Fallback if modality wasn't detected
+            detected_modality = "X-Ray-Chest"
+
         clinical_findings = extract_clinical_findings(llm_response, detected_modality)
+        if not clinical_findings:
+            # Fallback if no findings are detected
+            clinical_findings = [detected_modality.split("-")[-1]]
 
-        # Step 4: Generate masks for each finding
         results = []
        analysis_results = []
-        colors = [(255,0,0), (0,255,0), (0,0,255), (255,255,0), (255,0,255)] # Different colors for different findings
+        colors = [(255,0,0), (0,255,0), (0,0,255), (255,255,0), (255,0,255)]
 
         for idx, finding in enumerate(clinical_findings):
-            pred_mask = interactive_infer_image(model, pil_image, [finding])[0]
-            p_value = check_mask_stats(image, pred_mask * 255, detected_modality, finding)
+            mask_list = interactive_infer_image(model, pil_image, [finding])
+            if not mask_list:
+                analysis_results.append(f"No mask found for '{finding}'.")
+                continue
+
+            pred_mask = mask_list[0]
+            p_value = check_mask_stats(image, pred_mask*255, detected_modality, finding)
             analysis_results.append(f"P-value for '{finding}' ({detected_modality}): {p_value:.4f}")
 
-            # Create colored overlay
             overlay_image = image.copy()
             color = colors[idx % len(colors)]
             overlay_image[pred_mask > 0.5] = color
             results.append(overlay_image)
 
-        # Update LLM response with color references
         enhanced_response = llm_response + "\n\nSegmentation Results:\n"
         for idx, finding in enumerate(clinical_findings):
             color_name = ["red", "green", "blue", "yellow", "magenta"][idx % len(colors)]
@@ -345,7 +343,6 @@ def process_image(image_path, user_prompt, modality=None):
         print(f"Error details: {str(e)}", flush=True)
         return None, error_msg
 
-# Define Gradio interface
 with gr.Blocks() as demo:
     gr.HTML(MARKDOWN)
     with gr.Row():
@@ -376,17 +373,7 @@ with gr.Blocks() as demo:
                 show_label=True,
                 lines=10
             )
-
-            # Examples section - Fixed version
-            # gr.Examples(
-            #     examples=IMAGE_PROCESSING_EXAMPLES,
-            #     inputs=[image_input, prompt_input],
-            #     outputs=[output_gallery, analysis_output, detected_modality],
-            #     fn=process_image,
-            #     cache_examples=True,
-            # )
 
-    # Connect the submit button to the process_image function
     submit_btn.click(
         fn=process_image,
         inputs=[image_input, prompt_input],
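For context, a minimal sketch of the Blocks wiring pattern this hunk ends on; the component layout, stand-in process_image, and outputs list are illustrative, not the app's exact definitions:

import gradio as gr

def process_image(image_path, user_prompt):
    # Stand-in for the real segmentation pipeline.
    return [], f"Received prompt: {user_prompt}"

with gr.Blocks() as demo:
    with gr.Row():
        image_input = gr.Image(type="filepath", label="Image")
        prompt_input = gr.Textbox(label="Prompt")
    submit_btn = gr.Button("Submit")
    output_gallery = gr.Gallery(label="Segmentation Results")
    analysis_output = gr.Textbox(label="Analysis", lines=10)

    submit_btn.click(
        fn=process_image,
        inputs=[image_input, prompt_input],
        outputs=[output_gallery, analysis_output]
    )

if __name__ == "__main__":
    demo.launch()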
 