File size: 6,606 Bytes
2056352
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import io
import logging
import tempfile

import gradio as gr
import requests
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

# Inference device: hard-wired to CPU. The model is moved here once, and every
# processed batch is sent here with `.to(device)` before generation.
device = torch.device("cpu")
# trust_remote_code=True lets the Hub repository's own model/processor code run
# locally; only enable it for repositories you trust.
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/maira-2", trust_remote_code=True
).eval().to(device)
processor = AutoProcessor.from_pretrained("microsoft/maira-2", trust_remote_code=True)

def get_sample_data():
    """Download the sample chest X-ray study used by the "Load Sample Data" buttons.

    Fetches a frontal and a lateral PNG from the Open-i image service and pairs
    them with fixed report-section text.

    Returns:
        dict with keys ``"frontal"`` and ``"lateral"`` (fully loaded PIL images)
        and ``"indication"``, ``"technique"``, ``"comparison"``, ``"phrase"`` (str).

    Raises:
        requests.RequestException: on connection failure, timeout, or an HTTP
            error status for either image.
    """
    frontal_url = "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-1001.png"
    lateral_url = "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-2001.png"

    def download_image(url, timeout=30):
        # Bound the wait so a stalled server cannot hang the request handler forever.
        response = requests.get(url, headers={"User-Agent": "MAIRA-2"}, timeout=timeout)
        # Report HTTP failures directly rather than as "cannot identify image file".
        response.raise_for_status()
        # Decode from an in-memory copy of the (content-decoded) body. Reading
        # `response.raw` skips Content-Encoding decoding, and Image.open() is lazy,
        # so pixel data would otherwise be read from the live connection later.
        image = Image.open(io.BytesIO(response.content))
        image.load()
        return image

    return {
        "frontal": download_image(frontal_url),
        "lateral": download_image(lateral_url),
        "indication": "Dyspnea.",
        "technique": "PA and lateral views of the chest.",
        "comparison": "None.",
        "phrase": "Pleural effusion."
    }

def save_temp_image(img):
    """Write *img* to a new temporary ``.png`` file and return its path.

    The file is created with ``delete=False`` so it survives this call; the
    caller owns it and is responsible for removing it.

    Args:
        img: a PIL image (anything with a compatible ``save(fp, format=...)``).

    Returns:
        str: filesystem path of the written PNG.
    """
    # The context manager closes the descriptor once writing is done. Writing
    # through the open handle (rather than reopening the path by name while the
    # handle is still open) also works on Windows.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        img.save(tmp, format="PNG")
    return tmp.name

def load_sample_findings():
    """Populate the Report Generation tab with the downloaded sample study.

    Returns a 9-element list in output-component order: frontal image path,
    lateral image path, indication, technique, comparison, then the prior
    frontal / prior lateral / prior report fields cleared to None, and the
    grounding checkbox set to False.
    """
    data = get_sample_data()
    frontal_path = save_temp_image(data["frontal"])
    lateral_path = save_temp_image(data["lateral"])
    sections = [data[key] for key in ("indication", "technique", "comparison")]
    cleared_prior = [None] * 3
    return [frontal_path, lateral_path, *sections, *cleared_prior, False]

def load_sample_phrase():
    """Populate the Phrase Grounding tab with the sample frontal image and phrase.

    Returns:
        list: ``[path_to_frontal_png, phrase]`` matching the tab's two outputs.
    """
    data = get_sample_data()
    frontal_path = save_temp_image(data["frontal"])
    return [frontal_path, data["phrase"]]

def generate_report(frontal_path, lateral_path, indication, technique, comparison, 
                   prior_frontal_path, prior_lateral_path, prior_report, grounding):
    """Generate a radiology findings report with MAIRA-2, optionally grounded.

    Args:
        frontal_path: path to the current frontal image (required).
        lateral_path: path to the current lateral image (required).
        indication: clinical indication text; blank is passed as None.
        technique: imaging technique text; blank is passed as None.
        comparison: comparison text; blank is passed as None.
        prior_frontal_path: optional path to the prior study's frontal image.
        prior_lateral_path: accepted so the UI wiring is unchanged, but not
            forwarded to the processor (see the note in the body).
        prior_report: optional prior report text; blank is passed as None.
        grounding: if true, request a grounded report with a larger token budget.

    Returns:
        The processor's converted output on success. On failure, a string
        starting with ``"Error:"``. Errors are returned rather than raised so
        they appear in the output textbox.
    """
    # Fail fast with a readable message instead of Image.open(None)'s opaque error.
    if not frontal_path or not lateral_path:
        return "Error: please provide both a current frontal and a current lateral image."

    try:
        current_frontal = Image.open(frontal_path)
        current_lateral = Image.open(lateral_path)
        prior_frontal = Image.open(prior_frontal_path) if prior_frontal_path else None
        # NOTE(review): per the MAIRA-2 model card, the reporting input has a
        # prior *frontal* slot but no prior-lateral slot, so a prior lateral
        # image is not forwarded. Confirm against the processor's signature if
        # the remote code changes.

        # Blank textboxes arrive as "". Normalise them to None, matching how
        # prior_report is treated.
        processed = processor.format_and_preprocess_reporting_input(
            current_frontal=current_frontal,
            current_lateral=current_lateral,
            prior_frontal=prior_frontal,
            indication=indication or None,
            technique=technique or None,
            comparison=comparison or None,
            prior_report=prior_report or None,
            return_tensors="pt",
            get_grounding=grounding,
        ).to(device)

        # Grounded reports carry box tokens, so they get a larger budget.
        outputs = model.generate(
            **processed,
            max_new_tokens=450 if grounding else 300,
            use_cache=True,
        )

        # Decode only the newly generated tokens, not the echoed prompt.
        prompt_length = processed["input_ids"].shape[-1]
        decoded = processor.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        # The model card notes findings completions start with a leading space.
        return processor.convert_output_to_plaintext_or_grounded_sequence(decoded.lstrip())

    except Exception as e:
        # UI boundary: keep the traceback in the server log, and show the message to the user.
        logging.exception("Report generation failed")
        return f"Error: {str(e)}"

def ground_phrase(frontal_path, phrase):
    """Locate *phrase* on a frontal chest X-ray using MAIRA-2 phrase grounding.

    Args:
        frontal_path: path to the frontal image (required).
        phrase: finding text to ground (must be non-blank).

    Returns:
        The processor's converted output on success. On failure, a string
        starting with ``"Error:"``. Errors are returned rather than raised so
        they appear in the output textbox.
    """
    # Validate inputs up front so the user sees a clear message rather than a
    # library error from Image.open(None) or an empty prompt.
    if not frontal_path:
        return "Error: please provide a frontal image."
    if not phrase or not phrase.strip():
        return "Error: please enter a phrase to ground."

    try:
        frontal = Image.open(frontal_path)
        processed = processor.format_and_preprocess_phrase_grounding_input(
            frontal_image=frontal,
            phrase=phrase,
            return_tensors="pt"
        ).to(device)

        outputs = model.generate(**processed, max_new_tokens=150, use_cache=True)

        # Decode only the newly generated tokens, not the echoed prompt.
        prompt_length = processed["input_ids"].shape[-1]
        decoded = processor.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        return processor.convert_output_to_plaintext_or_grounded_sequence(decoded)

    except Exception as e:
        # UI boundary: keep the traceback in the server log, and show the message to the user.
        logging.exception("Phrase grounding failed")
        return f"Error: {str(e)}"

# Gradio UI: two tabs wired to the handler functions defined in this module.
#   - "Report Generation": current study (plus optional prior study) -> generate_report
#   - "Phrase Grounding": one frontal image plus a phrase -> ground_phrase
# Components render in creation order inside each `with` context, so the order
# of statements below determines the on-screen layout.
with gr.Blocks(title="MAIRA-2 Medical Imaging Assistant") as demo:
    gr.Markdown("# MAIRA-2 Medical Imaging Assistant\nAI-powered radiology report generation and phrase grounding")

    with gr.Tab("Report Generation"):
        with gr.Row():
            # Left column: all inputs for report generation.
            with gr.Column():
                gr.Markdown("## Current Study")
                # type="filepath": handlers receive a file path string, which they Image.open().
                frontal = gr.Image(label="Frontal View", type="filepath")
                lateral = gr.Image(label="Lateral View", type="filepath")
                indication = gr.Textbox(label="Clinical Indication")
                technique = gr.Textbox(label="Imaging Technique")
                comparison = gr.Textbox(label="Comparison")

                gr.Markdown("## Prior Study (Optional)")
                prior_frontal = gr.Image(label="Prior Frontal View", type="filepath")
                # NOTE(review): MAIRA-2's reporting input appears to have no prior-lateral
                # slot (per its model card); confirm whether this field has any effect.
                prior_lateral = gr.Image(label="Prior Lateral View", type="filepath")
                prior_report = gr.Textbox(label="Prior Report")

                # Passed as `grounding` to generate_report (also selects its token budget).
                grounding = gr.Checkbox(label="Include Grounding")
                sample_btn = gr.Button("Load Sample Data")

            # Right column: output and the action button.
            with gr.Column():
                report_output = gr.Textbox(label="Generated Report", lines=10)
                generate_btn = gr.Button("Generate Report")

        # The outputs order must match the 9-element list load_sample_findings returns.
        sample_btn.click(load_sample_findings,
                        outputs=[frontal, lateral, indication, technique, comparison,
                               prior_frontal, prior_lateral, prior_report, grounding])
        # The inputs order must match generate_report's positional parameters.
        generate_btn.click(generate_report,
                         inputs=[frontal, lateral, indication, technique, comparison,
                               prior_frontal, prior_lateral, prior_report, grounding],
                         outputs=report_output)

    with gr.Tab("Phrase Grounding"):
        with gr.Row():
            # Left column: image, phrase and the sample loader.
            with gr.Column():
                pg_frontal = gr.Image(label="Frontal View", type="filepath")
                phrase = gr.Textbox(label="Phrase to Ground")
                pg_sample_btn = gr.Button("Load Sample Data")
            # Right column: grounding result and the action button.
            with gr.Column():
                pg_output = gr.Textbox(label="Grounding Result", lines=3)
                pg_btn = gr.Button("Find Phrase")

        # load_sample_phrase returns [frontal_path, phrase] in this order.
        pg_sample_btn.click(load_sample_phrase,
                           outputs=[pg_frontal, phrase])
        pg_btn.click(ground_phrase,
                    inputs=[pg_frontal, phrase],
                    outputs=pg_output)

# Start the web server only when executed as a script (e.g. `python app.py`),
# so that importing this module does not block on the server loop.
if __name__ == "__main__":
    demo.launch()