# MAIRA-2 Gradio demo: radiology report generation and phrase grounding.
import os
import tempfile
from io import BytesIO

import gradio as gr
import requests
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
# Run everything on CPU; MAIRA-2 inference works without a GPU (slowly).
device = torch.device("cpu")
# trust_remote_code is required: MAIRA-2 ships custom model/processor code on the Hub.
model = AutoModelForCausalLM.from_pretrained("microsoft/maira-2", trust_remote_code=True)
processor = AutoProcessor.from_pretrained("microsoft/maira-2", trust_remote_code=True)
# Inference only: switch off dropout etc. and move weights to the target device.
model = model.eval().to(device)
def get_sample_data():
    """Download a sample chest X-ray study (frontal + lateral) and report metadata.

    Returns:
        dict with keys "frontal"/"lateral" (PIL.Image.Image) and the strings
        "indication", "technique", "comparison", "phrase".

    Raises:
        requests.HTTPError: if either image download returns an error status.
        requests.Timeout: if a download exceeds the timeout.
    """
    frontal_url = "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-1001.png"
    lateral_url = "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-2001.png"

    def download_image(url):
        # Fail loudly on HTTP errors instead of letting PIL choke on an error
        # page, and buffer the payload so PIL gets a seekable stream
        # (response.raw is not reliably seekable).
        response = requests.get(url, headers={"User-Agent": "MAIRA-2"}, timeout=30)
        response.raise_for_status()
        return Image.open(BytesIO(response.content))

    return {
        "frontal": download_image(frontal_url),
        "lateral": download_image(lateral_url),
        "indication": "Dyspnea.",
        "technique": "PA and lateral views of the chest.",
        "comparison": "None.",
        "phrase": "Pleural effusion.",
    }
def save_temp_image(img):
    """Persist a PIL image to a temporary PNG file and return the file path.

    The caller is responsible for eventually deleting the file.
    """
    # mkstemp + close avoids the NamedTemporaryFile handle leak (the original
    # never closed the file object) and the Windows restriction on re-opening
    # a file that is still held open by another handle.
    fd, path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    img.save(path)
    return path
def load_sample_findings():
    """Populate the report-generation tab with the bundled sample study."""
    data = get_sample_data()
    frontal_path = save_temp_image(data["frontal"])
    lateral_path = save_temp_image(data["lateral"])
    # Order must mirror the sample_btn.click outputs list; the prior-study
    # fields are cleared and the grounding checkbox is reset.
    return [
        frontal_path,
        lateral_path,
        data["indication"],
        data["technique"],
        data["comparison"],
        None,
        None,
        None,
        False,
    ]
def load_sample_phrase():
    """Populate the phrase-grounding tab with the bundled sample image and phrase."""
    data = get_sample_data()
    image_path = save_temp_image(data["frontal"])
    return [image_path, data["phrase"]]
def generate_report(frontal_path, lateral_path, indication, technique, comparison,
                    prior_frontal_path, prior_lateral_path, prior_report, grounding):
    """Run MAIRA-2 report generation for a chest X-ray study.

    Current frontal/lateral images are required; the prior study (images and
    report) is optional. Returns the generated report — plain text, or a
    grounded sequence when *grounding* is set — or an "Error: ..." string on
    any failure so the Gradio UI can display it.
    """
    try:
        def open_optional(path):
            # Prior-study images may be absent; map a missing path to None.
            return Image.open(path) if path else None

        batch = processor.format_and_preprocess_reporting_input(
            current_frontal=Image.open(frontal_path),
            current_lateral=Image.open(lateral_path),
            prior_frontal=open_optional(prior_frontal_path),
            prior_lateral=open_optional(prior_lateral_path),
            indication=indication,
            technique=technique,
            comparison=comparison,
            prior_report=prior_report or None,
            return_tensors="pt",
            get_grounding=grounding,
        ).to(device)

        # Grounded reports carry extra box tokens, so allow a longer generation.
        token_budget = 450 if grounding else 300
        generated = model.generate(**batch, max_new_tokens=token_budget, use_cache=True)

        # Decode only the continuation, skipping the prompt tokens.
        prompt_len = batch["input_ids"].shape[-1]
        decoded = processor.decode(generated[0][prompt_len:], skip_special_tokens=True)
        return processor.convert_output_to_plaintext_or_grounded_sequence(decoded.lstrip())
    except Exception as e:
        return f"Error: {str(e)}"
def ground_phrase(frontal_path, phrase):
    """Localize *phrase* in the frontal chest X-ray via MAIRA-2 phrase grounding.

    Returns the grounded sequence, or an "Error: ..." string on any failure so
    the Gradio UI can display it.
    """
    try:
        batch = processor.format_and_preprocess_phrase_grounding_input(
            frontal_image=Image.open(frontal_path),
            phrase=phrase,
            return_tensors="pt",
        ).to(device)
        generated = model.generate(**batch, max_new_tokens=150, use_cache=True)
        # Decode only the continuation, skipping the prompt tokens.
        prompt_len = batch["input_ids"].shape[-1]
        decoded = processor.decode(generated[0][prompt_len:], skip_special_tokens=True)
        return processor.convert_output_to_plaintext_or_grounded_sequence(decoded)
    except Exception as e:
        return f"Error: {str(e)}"
# Gradio UI: one tab for full report generation, one for phrase grounding.
with gr.Blocks(title="MAIRA-2 Medical Imaging Assistant") as demo:
    gr.Markdown("# MAIRA-2 Medical Imaging Assistant\nAI-powered radiology report generation and phrase grounding")
    with gr.Tab("Report Generation"):
        with gr.Row():
            with gr.Column():
                gr.Markdown("## Current Study")
                frontal = gr.Image(label="Frontal View", type="filepath")
                lateral = gr.Image(label="Lateral View", type="filepath")
                indication = gr.Textbox(label="Clinical Indication")
                technique = gr.Textbox(label="Imaging Technique")
                comparison = gr.Textbox(label="Comparison")
                gr.Markdown("## Prior Study (Optional)")
                prior_frontal = gr.Image(label="Prior Frontal View", type="filepath")
                prior_lateral = gr.Image(label="Prior Lateral View", type="filepath")
                prior_report = gr.Textbox(label="Prior Report")
                grounding = gr.Checkbox(label="Include Grounding")
                sample_btn = gr.Button("Load Sample Data")
            with gr.Column():
                report_output = gr.Textbox(label="Generated Report", lines=10)
                generate_btn = gr.Button("Generate Report")
        # Output order must match the list returned by load_sample_findings.
        sample_btn.click(load_sample_findings,
                         outputs=[frontal, lateral, indication, technique, comparison,
                                  prior_frontal, prior_lateral, prior_report, grounding])
        # Input order must match generate_report's parameter order.
        generate_btn.click(generate_report,
                           inputs=[frontal, lateral, indication, technique, comparison,
                                   prior_frontal, prior_lateral, prior_report, grounding],
                           outputs=report_output)
    with gr.Tab("Phrase Grounding"):
        with gr.Row():
            with gr.Column():
                pg_frontal = gr.Image(label="Frontal View", type="filepath")
                phrase = gr.Textbox(label="Phrase to Ground")
                pg_sample_btn = gr.Button("Load Sample Data")
            with gr.Column():
                pg_output = gr.Textbox(label="Grounding Result", lines=3)
                pg_btn = gr.Button("Find Phrase")
        pg_sample_btn.click(load_sample_phrase,
                            outputs=[pg_frontal, phrase])
        pg_btn.click(ground_phrase,
                     inputs=[pg_frontal, phrase],
                     outputs=pg_output)

# Fixed: the original line ended with a stray "|" which made the script a SyntaxError.
demo.launch()