integrated biomedllama
app.py
CHANGED
@@ -16,31 +16,15 @@ from tqdm import tqdm
 import sys
 from pathlib import Path
 from huggingface_hub import login
-# from dotenv import load_dotenv
 from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
 
-
+
 token = os.getenv("HF_TOKEN")
 if token:
     login(token=token)
-# Clear Hugging Face cache
-# cache_dirs = [
-# "/home/user/.cache/huggingface/",
-# "/home/user/.cache/torch/",
-# "/home/user/.cache/pip/"
-# ]
-
-# for cache_dir in cache_dirs:
-# if os.path.exists(cache_dir):
-# print(f"Clearing cache: {cache_dir}")
-# shutil.rmtree(cache_dir, ignore_errors=True)
-# Add the current directory to Python path
 current_dir = Path(__file__).parent
 sys.path.append(str(current_dir))
-
-# BIOMEDPARSE_PATH = Path(__file__).parent / "BiomedParse"
-# sys.path.append(str(BIOMEDPARSE_PATH))
-# sys.path.append(str(BIOMEDPARSE_PATH / "BiomedParse"))  # Add the inner BiomedParse directory
+
 from modeling.BaseModel import BaseModel
 from modeling import build_model
 from utilities.arguments import load_opt_from_config_files
@@ -51,7 +35,7 @@ from inference_utils.processing_utils import read_rgb
 
 import spaces
 
-
+
 MARKDOWN = """
 <div align="center" style="padding: 20px 0;">
 <h1 style="font-size: 3em; margin: 0;">
@@ -154,6 +138,68 @@ MODALITY_PROMPTS = {
     "OCT": ["edema"] }
 
 
+def extract_modality_from_llm(llm_output):
+    """Extract modality from LLM output and map it to BIOMEDPARSE_MODES"""
+    llm_output = llm_output.lower()
+
+    # Direct modality mapping
+    modality_keywords = {
+        'ct': {
+            'abdomen': 'CT-Abdomen',
+            'chest': 'CT-Chest',
+            'liver': 'CT-Liver'
+        },
+        'mri': {
+            'abdomen': 'MRI-Abdomen',
+            'cardiac': 'MRI-Cardiac',
+            'heart': 'MRI-Cardiac',
+            'flair': 'MRI-FLAIR-Brain',
+            't1': 'MRI-T1-Gd-Brain',
+            'contrast': 'MRI-T1-Gd-Brain',
+            'brain': 'MRI-FLAIR-Brain'  # default to FLAIR if just "brain" is mentioned
+        },
+        'x-ray': {'chest': 'X-Ray-Chest'},
+        'ultrasound': {'cardiac': 'Ultrasound-Cardiac', 'heart': 'Ultrasound-Cardiac'},
+        'endoscopy': {'': 'Endoscopy'},
+        'fundus': {'': 'Fundus'},
+        'dermoscopy': {'': 'Dermoscopy'},
+        'oct': {'': 'OCT'},
+        'pathology': {'': 'Pathology'}
+    }
+
+    for modality, subtypes in modality_keywords.items():
+        if modality in llm_output:
+            # For modalities with subtypes, try to find the specific subtype
+            if subtypes:
+                for keyword, specific_modality in subtypes.items():
+                    if not keyword or keyword in llm_output:
+                        return specific_modality
+            # For modalities without subtypes, return the direct mapping
+            return next(iter(subtypes.values()))
+
+    return None
+
+def extract_clinical_findings(llm_output, modality):
+    """Extract relevant clinical findings that match available anatomical sites in BIOMEDPARSE_MODES"""
+    available_sites = BIOMEDPARSE_MODES.get(modality, [])
+    findings = []
+
+    # Convert sites to lowercase for case-insensitive matching
+    sites_lower = {site.lower(): site for site in available_sites}
+
+    # Look for each available site in the LLM output
+    for site_lower, original_site in sites_lower.items():
+        if site_lower in llm_output.lower():
+            findings.append(original_site)
+
+    # Add additional findings from MODALITY_PROMPTS if available
+    if modality in MODALITY_PROMPTS:
+        for prompt in MODALITY_PROMPTS[modality]:
+            if prompt.lower() in llm_output.lower() and prompt not in findings:
+                findings.append(prompt)
+
+    return findings
+
 def on_mode_dropdown_change(selected_mode):
     if selected_mode in IMAGE_INFERENCE_MODES:
         # Show modality dropdown and hide other inputs initially
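Note: both helpers introduced in this hunk are plain substring matching over the LLM text, so they can be exercised without any model weights loaded. A minimal sketch (the sample response string is invented for illustration, and importing from app assumes the module initializes in your environment):

    from app import extract_modality_from_llm, extract_clinical_findings

    # Hypothetical LLM output; not produced by the model in this commit.
    sample = "This is a CT of the abdomen. The liver is enlarged with a focal lesion."

    modality = extract_modality_from_llm(sample)            # 'CT-Abdomen' ('ct' + 'abdomen' keywords)
    findings = extract_clinical_findings(sample, modality)  # sites/prompts listed for CT-Abdomen that appear in the text
    print(modality, findings)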
@@ -231,72 +277,68 @@ def update_example_prompts(modality):
 @spaces.GPU
 @torch.inference_mode()
 @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
-def process_image(image_path,
+def process_image(image_path, user_prompt, modality=None):
     try:
-        # Input validation
         if not image_path:
             raise ValueError("Please upload an image")
-
-            raise ValueError("Please enter prompts for analysis")
-        if not modality:
-            raise ValueError("Please select a modality")
-
-        # Original BiomedParse processing
+
         image = read_rgb(image_path)
-
-
+        pil_image = Image.fromarray(image)
+
+        # Step 1: Get LLM analysis
+        question = f"Analyze this medical image considering the following context: {user_prompt}. Include modality, anatomical structures, and any abnormalities."
+        msgs = [{'role': 'user', 'content': [pil_image, question]}]
 
-
+        llm_response = ""
+        for new_text in llm_model.chat(
+            image=pil_image,
+            msgs=msgs,
+            tokenizer=llm_tokenizer,
+            sampling=True,
+            temperature=0.95,
+            stream=True
+        ):
+            llm_response += new_text
+
+        # Step 2: Extract modality from LLM output
+        detected_modality = extract_modality_from_llm(llm_response)
+        if not detected_modality:
+            raise ValueError("Could not determine image modality from LLM output")
+
+        # Step 3: Extract relevant clinical findings
+        clinical_findings = extract_clinical_findings(llm_response, detected_modality)
+
+        # Step 4: Generate masks for each finding
         results = []
         analysis_results = []
+        colors = [(255,0,0), (0,255,0), (0,0,255), (255,255,0), (255,0,255)]  # Different colors for different findings
 
-
-
-            p_value = check_mask_stats(image,
-            analysis_results.append(f"P-value for '{
+        for idx, finding in enumerate(clinical_findings):
+            pred_mask = interactive_infer_image(model, pil_image, [finding])[0]
+            p_value = check_mask_stats(image, pred_mask * 255, detected_modality, finding)
+            analysis_results.append(f"P-value for '{finding}' ({detected_modality}): {p_value:.4f}")
 
+            # Create colored overlay
             overlay_image = image.copy()
-
+            color = colors[idx % len(colors)]
+            overlay_image[pred_mask > 0.5] = color
             results.append(overlay_image)
 
-        #
-
-
-
-
-                question = 'Give the modality, organ, analysis, abnormalities (if any), treatment (if abnormalities are present)?'
-                msgs = [{'role': 'user', 'content': [pil_image, question]}]
-
-                print("Starting LLM inference...")
-                llm_response = ""
-                for new_text in llm_model.chat(
-                    image=pil_image,
-                    msgs=msgs,
-                    tokenizer=llm_tokenizer,
-                    sampling=True,
-                    temperature=0.95,
-                    stream=True
-                ):
-                    llm_response += new_text
-                print(f"LLM generated response: {llm_response}")
-
-                # Make the combined analysis more visible
-                combined_analysis = "\n\n" + "="*50 + "\n"
-                combined_analysis += "BiomedParse Analysis:\n"
-                combined_analysis += "\n".join(analysis_results)
-                combined_analysis += "\n\n" + "="*50 + "\n"
-                combined_analysis += "LLM Analysis:\n"
-                combined_analysis += llm_response
-                combined_analysis += "\n" + "="*50
-
-            except Exception as e:
-                print(f"LLM analysis failed with error: {str(e)}")
-                combined_analysis = "\n".join(analysis_results)
-        else:
-            print("LLM model or tokenizer is not available")
-            combined_analysis = "\n".join(analysis_results)
+        # Update LLM response with color references
+        enhanced_response = llm_response + "\n\nSegmentation Results:\n"
+        for idx, finding in enumerate(clinical_findings):
+            color_name = ["red", "green", "blue", "yellow", "magenta"][idx % len(colors)]
+            enhanced_response += f"- {finding} (shown in {color_name})\n"
 
-
+        combined_analysis = "\n\n" + "="*50 + "\n"
+        combined_analysis += "BiomedParse Analysis:\n"
+        combined_analysis += "\n".join(analysis_results)
+        combined_analysis += "\n\n" + "="*50 + "\n"
+        combined_analysis += "Enhanced LLM Analysis:\n"
+        combined_analysis += enhanced_response
+        combined_analysis += "\n" + "="*50
+
+        return results, combined_analysis, detected_modality
 
     except Exception as e:
         error_msg = f"⚠️ An error occurred: {str(e)}"
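Note: the per-finding overlay added in this hunk is plain NumPy boolean masking. A standalone sketch of just that step, using dummy shapes and values in place of the real image and predicted mask:

    import numpy as np

    image = np.zeros((4, 4, 3), dtype=np.uint8)      # stand-in for read_rgb() output (H, W, 3)
    pred_mask = np.zeros((4, 4), dtype=np.float32)   # stand-in for one finding's probability map
    pred_mask[1:3, 1:3] = 0.9

    color = (255, 0, 0)                              # first entry of the colors list
    overlay_image = image.copy()
    overlay_image[pred_mask > 0.5] = color           # paint every pixel where the mask fires
    print(overlay_image[1, 1], overlay_image[0, 0])  # [255 0 0] [0 0 0]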
@@ -309,33 +351,45 @@ with gr.Blocks() as demo:
     with gr.Row():
         with gr.Column():
             image_input = gr.Image(type="filepath", label="Input Image")
-
-                lines=
-                placeholder="
-                label="
+            prompt_input = gr.Textbox(
+                lines=4,
+                placeholder="Ask any question about the medical image...",
+                label="Question/Prompt"
             )
-
-
-
-
+            detected_modality = gr.Textbox(
+                label="Detected Modality",
+                interactive=False,
+                visible=True
            )
-            submit_btn = gr.Button("
+            submit_btn = gr.Button("Analyze")
+
         with gr.Column():
-            output_gallery = gr.Gallery(
-
-
+            output_gallery = gr.Gallery(
+                label="Segmentation Results",
+                show_label=True,
+                columns=[2],
+                height="auto"
+            )
+            analysis_output = gr.Textbox(
+                label="Analysis",
                 interactive=False,
-                show_label=True
+                show_label=True,
+                lines=10
             )
-
-
-
-
-
+
+    # Examples section
+    gr.Examples(
+        examples=IMAGE_PROCESSING_EXAMPLES,
+        inputs=[image_input, prompt_input],
+        outputs=[output_gallery, analysis_output, detected_modality],
+        cache_examples=True,
+    )
+
+    # Connect the submit button to the process_image function
     submit_btn.click(
         fn=process_image,
-        inputs=[image_input,
-        outputs=[output_gallery,
+        inputs=[image_input, prompt_input],
+        outputs=[output_gallery, analysis_output, detected_modality],
         api_name="process"
     )
 