image / app.py
mgbam's picture
Update app.py
aa563db verified
raw
history blame
5.62 kB
import gradio as gr
import torch
import numpy as np
from transformers import AutoModelForCausalLM
from janus.models import VLChatProcessor
from PIL import Image
import spaces
# Medical Imaging Analysis Configuration: guideline names per modality plus
# per-modality parameter name lists.
MEDICAL_CONFIG = dict(
    echo_guidelines="ASE 2023 Standards",
    histo_guidelines="CAP Protocols 2024",
    cardiac_params=["LVEF", "E/A Ratio", "Wall Motion"],
    histo_params=["Nuclear Atypia", "Mitotic Count", "Stromal Invasion"],
)

# Hugging Face Hub id of the Janus-Pro checkpoint used by this app.
model_path = "deepseek-ai/Janus-Pro-1B"
class MedicalImagingAdapter(torch.nn.Module):
    """Wrap a language model and attach modality-specific projection layers.

    ``forward`` passes every call straight through to ``base_model``; the
    cardiac/histopathology projections are registered but not applied yet.
    Attributes the wrapper itself does not define (e.g. ``generate`` or
    ``get_input_embeddings``) are looked up on ``base_model``. This lets the
    wrapper replace the original model wherever that model's API is used.

    Args:
        base_model: The module being wrapped.
        hidden_size: Width of the projection layers (defaults to 2048).
    """

    def __init__(self, base_model, hidden_size=2048):
        super().__init__()
        self.base_model = base_model
        # Cardiac-specific projections (not used by ``forward`` yet)
        self.cardiac_proj = torch.nn.Linear(hidden_size, hidden_size)
        # Histopathology-specific projections (not used by ``forward`` yet)
        self.histo_proj = torch.nn.Linear(hidden_size, hidden_size)

    def forward(self, *args, **kwargs):
        """Delegate the call unchanged to the wrapped model."""
        outputs = self.base_model(*args, **kwargs)
        return outputs

    def __getattr__(self, name):
        """Resolve unknown attributes on the wrapped model.

        ``torch.nn.Module.__getattr__`` only finds registered parameters,
        buffers and submodules. Without this fallback, wrapping a model hides
        its methods and config, and callers get AttributeError.
        """
        try:
            return super().__getattr__(name)
        except AttributeError:
            # Read _modules via __dict__ to avoid recursion before
            # torch.nn.Module.__init__ has populated it (e.g. during copying).
            base = self.__dict__.get("_modules", {}).get("base_model")
            if base is None:
                raise
            return getattr(base, name)
# Chat processor (tokenizer + image processor) paired with the checkpoint.
vl_chat_processor = VLChatProcessor.from_pretrained(model_path)

# The checkpoint ships its own modeling code, hence trust_remote_code.
vl_gpt = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
# Swap the language model for the MedicalImagingAdapter wrapper around it.
vl_gpt.language_model = MedicalImagingAdapter(vl_gpt.language_model)

# On CUDA hosts, move the whole model to the GPU in bfloat16; otherwise keep it
# as loaded.
if torch.cuda.is_available():
    vl_gpt = vl_gpt.to(device="cuda", dtype=torch.bfloat16)
# Medical Image Processing Pipelines
def preprocess_echo(image):
    """Turn an echocardiography frame into a 512x512 grayscale array.

    The input is wrapped with ``Image.fromarray`` and converted to PIL mode
    ``'L'`` (single channel). It is then resized to 512x512 and returned as a
    NumPy array.
    """
    gray = Image.fromarray(image).convert('L')
    square = gray.resize((512, 512))
    return np.array(square)
def preprocess_histo(image):
    """Resize a histopathology slide to 1024x1024 and return it as an array.

    The input's colour mode is preserved; no channel conversion is done here.
    """
    target_size = (1024, 1024)
    slide = Image.fromarray(image)
    return np.array(slide.resize(target_size))
# spaces.GPU is outermost so inference mode is entered inside the GPU-backed call.
@spaces.GPU(duration=120)
@torch.inference_mode()
def analyze_medical_case(image, clinical_context, modality):
    """Run Janus-Pro on a single medical image and return a Markdown report.

    Args:
        image: Uploaded image as an array accepted by ``PIL.Image.fromarray``.
        clinical_context: Free-text clinical question inserted into the prompt.
        modality: ``"Echo"`` for echocardiography or ``"Histo"`` for
            histopathology. It selects the preprocessing, the guideline text
            and the report layout.

    Returns:
        The Markdown report produced by ``format_medical_report``.

    Raises:
        ValueError: If no image is given or ``modality`` is unsupported. These
            checks run before any model work is done.
    """
    if image is None:
        raise ValueError("No image provided.")
    if modality not in ("Echo", "Histo"):
        raise ValueError(f"Unsupported modality: {modality!r}")
    is_echo = modality == "Echo"

    # Preprocess based on modality
    processed_img = preprocess_echo(image) if is_echo else preprocess_histo(image)
    # The echo pipeline yields a single-channel grayscale array, and uploads may
    # carry an alpha channel. The Janus image processor expects 3-channel RGB.
    pil_image = Image.fromarray(processed_img).convert("RGB")

    guidelines = MEDICAL_CONFIG["echo_guidelines" if is_echo else "histo_guidelines"]
    # <image_placeholder> marks where VLChatProcessor splices in the image
    # tokens. Without it, the image never reaches the model.
    prompt = (
        "<image_placeholder>\n"
        f"Analyze this {modality} image following {guidelines}.\n"
        f"Clinical Context: {clinical_context}"
    )
    # Janus-Pro's chat template uses the <|User|>/<|Assistant|> roles. The
    # trailing empty assistant turn is where the model's answer goes.
    conversation = [
        {"role": "<|User|>", "content": prompt, "images": [pil_image]},
        {"role": "<|Assistant|>", "content": ""},
    ]
    # Cast pixel values to the model's own dtype rather than the processor's
    # bfloat16 default, so the CPU (non-cast) path also has matching dtypes.
    inputs = vl_chat_processor(
        conversations=conversation,
        images=[pil_image],
        force_batchify=True,
    ).to(vl_gpt.device, dtype=vl_gpt.dtype)

    tokenizer = vl_chat_processor.tokenizer
    inputs_embeds = vl_gpt.prepare_inputs_embeds(**inputs)
    # Decode with the language model, as in the Janus reference examples. This
    # relies on vl_gpt.language_model exposing ``generate``.
    outputs = vl_gpt.language_model.generate(
        inputs_embeds=inputs_embeds,
        attention_mask=inputs.attention_mask,
        pad_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        max_new_tokens=512,
        # NOTE(review): temperature/top_p only take effect when sampling is
        # enabled (do_sample=True); confirm the intended decoding mode.
        temperature=0.1,
        top_p=0.9,
        repetition_penalty=1.5,
        use_cache=True,
    )
    report = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
    return format_medical_report(report, modality)
# Report layout per modality: (section header, field name, field name, ...).
_REPORT_SECTIONS = {
    "Echo": [
        ("Chamber Dimensions", "LVEDD", "LVESD"),
        ("Valvular Function", "Aortic Valve", "Mitral Valve"),
        ("Hemodynamics", "E/A Ratio", "LVEF"),
    ],
    "Histo": [
        ("Architecture", "Gland Formation", "Stromal Pattern"),
        ("Cellular Features", "Nuclear Atypia", "Mitotic Count"),
        ("Diagnostic Impression", "Tumor Grade", "Margin Status"),
    ],
}


def format_medical_report(text, modality):
    """Arrange free-text model output into a sectioned Markdown report.

    For each field name listed for *modality*, the first occurrence in *text*
    is located. Its value is the text after the field name plus one separator
    character (typically ``':'``), up to the next blank line or the end of
    *text*, stripped of surrounding whitespace. Section headers are always
    emitted. If no field is found at all, the raw model text is appended so
    the answer is not silently dropped.

    Args:
        text: Decoded model output.
        modality: ``"Echo"`` or ``"Histo"``.

    Returns:
        The Markdown report as a single string.

    Raises:
        KeyError: If *modality* has no report layout.
    """
    sections = _REPORT_SECTIONS[modality]
    parts = [f"**{modality} Analysis Report**\n\n"]
    extracted = 0
    for header, *fields in sections:
        parts.append(f"### {header}\n")
        for field in fields:
            start = text.find(field)
            if start == -1:
                continue
            end = text.find("\n\n", start)
            if end == -1:
                # Last field in the text: take everything to the end instead of
                # slicing to -1, which would drop the final character.
                end = len(text)
            value = text[start + len(field) + 1:end].strip()
            parts.append(f"- **{field}:** {value}\n")
            extracted += 1
    if not extracted and text.strip():
        parts.append("### Unstructured Model Output\n")
        parts.append(f"{text.strip()}\n")
    return "".join(parts)
# Medical Imaging Interface
with gr.Blocks(title="Cardiac & Histopathology AI", theme=gr.themes.Soft()) as demo:
    # Kept unindented so Markdown does not render it as a code block.
    gr.Markdown(
        "## Medical Imaging Analysis Platform\n"
        "*Analyzes echocardiograms and histopathology slides - Research Use Only*"
    )
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="Upload Medical Image")
            modality_select = gr.Radio(
                ["Echo", "Histo"],
                label="Image Modality",
                info="Select 'Echo' for cardiac ultrasound, 'Histo' for biopsy slides",
            )
            clinical_input = gr.Textbox(
                label="Clinical Context",
                placeholder="e.g., 'Assess LV function' or 'Evaluate for malignancy'",
            )
            analyze_btn = gr.Button("Analyze Case", variant="primary")
        with gr.Column():
            report_output = gr.Markdown(label="AI Clinical Report")

    # Preloaded examples. NOTE(review): these relative image paths must resolve
    # to files shipped with the app; confirm they exist.
    gr.Examples(
        examples=[
            ["Evaluate LV systolic function", "case1.png", "Echo"],
            ["Assess mitral valve function", "case2.jpg", "Echo"],
            ["Analyze for malignant features", "case3.png", "Histo"],
            ["Evaluate tumor margins", "case4.png", "Histo"],
        ],
        inputs=[clinical_input, image_input, modality_select],
        label="Example Medical Cases",
    )

    def analyze_and_display(image, clinical_context, modality):
        """Validate the UI inputs, then run the analysis and return its report.

        Raises ``gr.Error`` (shown to the user) before any model work when the
        image or modality is missing.
        """
        if image is None:
            raise gr.Error("Please upload a medical image before analyzing.")
        if modality not in ("Echo", "Histo"):
            raise gr.Error("Please select an image modality (Echo or Histo).")
        return analyze_medical_case(image, clinical_context, modality)

    analyze_btn.click(
        analyze_and_display,
        inputs=[image_input, clinical_input, modality_select],
        outputs=report_output,
    )

if __name__ == "__main__":
    demo.launch(share=True)