from transformers import AutoModelForCausalLM, AutoProcessor
from PIL import Image
import torch
import gradio as gr
import requests
import tempfile
import io

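# Module-level cache for the authenticated model and processor;
# login() populates it and the generation handlers read from it.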
MODEL_STATE = {
    "model": None,
    "processor": None,
    "authenticated": False
}

def login(hf_token):
    """Authenticate and load the model"""
    try:
        MODEL_STATE.update({"model": None, "processor": None, "authenticated": False})
        
        MODEL_STATE["model"] = AutoModelForCausalLM.from_pretrained(
            "microsoft/maira-2",
            trust_remote_code=True,
            use_auth_token=hf_token
        )
        MODEL_STATE["processor"] = AutoProcessor.from_pretrained(
            "microsoft/maira-2",
            trust_remote_code=True,
            use_auth_token=hf_token
        )
        
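        # CPU inference in eval mode; change "cpu" to "cuda" here and in the
        # handlers below if a GPU is available.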
        MODEL_STATE["model"] = MODEL_STATE["model"].eval().to("cpu")
        MODEL_STATE["authenticated"] = True
        
        return "🔓 Login successful! You can now use the model."
    except Exception as e:
        MODEL_STATE.update({"model": None, "processor": None, "authenticated": False})
        return f"❌ Login failed: {str(e)}"

def get_sample_data():
    """Download sample medical images and data"""
    frontal_url = "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-1001.png"
    lateral_url = "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-2001.png"
    
    def download_image(url):
        # Read the full response before decoding; PIL's lazy loading from
        # response.raw can fail once the connection closes.
        response = requests.get(url, headers={"User-Agent": "MAIRA-2"})
        response.raise_for_status()
        return Image.open(io.BytesIO(response.content))
    
    return {
        "frontal": download_image(frontal_url),
        "lateral": download_image(lateral_url),
        "indication": "Dyspnea.",
        "technique": "PA and lateral views of the chest.",
        "comparison": "None.",
        "phrase": "Pleural effusion."
    }

def save_temp_image(img):
    """Save a PIL image to a temporary PNG file and return its path."""
    temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    temp_file.close()  # release the open handle before writing (needed on Windows)
    img.save(temp_file.name)
    return temp_file.name

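# Values are returned in the same order as the outputs list wired to
# sample_btn.click below: images, text fields, cleared prior-study
# inputs, and grounding unchecked.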
def load_sample_findings():
    sample = get_sample_data()
    return [
        save_temp_image(sample["frontal"]),
        save_temp_image(sample["lateral"]),
        sample["indication"],
        sample["technique"],
        sample["comparison"],
        None, None, None, False
    ]

def load_sample_phrase():
    sample = get_sample_data()
    return [save_temp_image(sample["frontal"]), sample["phrase"]]

def generate_report(frontal_path, lateral_path, indication, technique, comparison,
                   prior_frontal_path, prior_lateral_path, prior_report, grounding):
    """Generate radiology report with authentication check"""
    if not MODEL_STATE["authenticated"]:
        return "⚠️ Please authenticate with your Hugging Face token first!"
    
    try:
        current_frontal = Image.open(frontal_path)
        current_lateral = Image.open(lateral_path)
        prior_frontal = Image.open(prior_frontal_path) if prior_frontal_path else None
        prior_lateral = Image.open(prior_lateral_path) if prior_lateral_path else None

        processed = MODEL_STATE["processor"].format_and_preprocess_reporting_input(
            current_frontal=current_frontal,
            current_lateral=current_lateral,
            prior_frontal=prior_frontal,
            prior_lateral=prior_lateral,
            indication=indication,
            technique=technique,
            comparison=comparison,
            prior_report=prior_report or None,
            return_tensors="pt",
            get_grounding=grounding
        ).to("cpu")

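        # Grounded reports include box-coordinate tokens, so allow a longer generation.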
        outputs = MODEL_STATE["model"].generate(
            **processed,
            max_new_tokens=450 if grounding else 300,
            use_cache=True
        )
        
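        # generate() returns the prompt followed by the new tokens; slice off
        # the prompt before decoding.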
        prompt_length = processed["input_ids"].shape[-1]
        decoded = MODEL_STATE["processor"].decode(outputs[0][prompt_length:], skip_special_tokens=True)
        return MODEL_STATE["processor"].convert_output_to_plaintext_or_grounded_sequence(decoded.lstrip())
    
    except Exception as e:
        return f"❌ Generation error: {str(e)}"

def ground_phrase(frontal_path, phrase):
    """Perform phrase grounding with authentication check"""
    if not MODEL_STATE["authenticated"]:
        return "⚠️ Please authenticate with your Hugging Face token first!"
    
    try:
        frontal = Image.open(frontal_path)
        processed = MODEL_STATE["processor"].format_and_preprocess_phrase_grounding_input(
            frontal_image=frontal,
            phrase=phrase,
            return_tensors="pt"
        ).to("cpu")
        
        outputs = MODEL_STATE["model"].generate(
            **processed,
            max_new_tokens=150,
            use_cache=True
        )
        
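        # As in generate_report, strip the prompt tokens before decoding.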
        prompt_length = processed["input_ids"].shape[-1]
        decoded = MODEL_STATE["processor"].decode(outputs[0][prompt_length:], skip_special_tokens=True)
        return MODEL_STATE["processor"].convert_output_to_plaintext_or_grounded_sequence(decoded)
    
    except Exception as e:
        return f"❌ Grounding error: {str(e)}"

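# Gradio UI: an authentication row followed by two tabs, one for full
# report generation and one for phrase grounding.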
with gr.Blocks(title="MAIRA-2 Medical Assistant") as demo:
    gr.Markdown("""# MAIRA-2 Medical Assistant
    **Authentication required** - You need a Hugging Face account and access token to use this model.
    1. Get your access token from [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)
    2. Request model access at [https://huggingface.co/microsoft/maira-2](https://huggingface.co/microsoft/maira-2)
    3. Paste your token below to begin
    """)
    
    with gr.Row():
        hf_token = gr.Textbox(
            label="Hugging Face Token",
            placeholder="hf_xxxxxxxxxxxxxxxxxxxx",
            type="password"
        )
        login_btn = gr.Button("Authenticate")
        login_status = gr.Textbox(label="Authentication Status", interactive=False)
    
    login_btn.click(
        login,
        inputs=hf_token,
        outputs=login_status
    )
    
    with gr.Tabs():
        with gr.Tab("Report Generation"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("## Current Study")
                    frontal = gr.Image(label="Frontal View", type="filepath")
                    lateral = gr.Image(label="Lateral View", type="filepath")
                    indication = gr.Textbox(label="Clinical Indication")
                    technique = gr.Textbox(label="Imaging Technique")
                    comparison = gr.Textbox(label="Comparison")
                    
                    gr.Markdown("## Prior Study (Optional)")
                    prior_frontal = gr.Image(label="Prior Frontal View", type="filepath")
                    prior_lateral = gr.Image(label="Prior Lateral View", type="filepath")
                    prior_report = gr.Textbox(label="Prior Report")
                    
                    grounding = gr.Checkbox(label="Include Grounding")
                    sample_btn = gr.Button("Load Sample Data")
                
                with gr.Column():
                    report_output = gr.Textbox(label="Generated Report", lines=10)
                    generate_btn = gr.Button("Generate Report")
            
            sample_btn.click(
                load_sample_findings,
                outputs=[frontal, lateral, indication, technique, comparison,
                       prior_frontal, prior_lateral, prior_report, grounding]
            )
            generate_btn.click(
                generate_report,
                inputs=[frontal, lateral, indication, technique, comparison,
                      prior_frontal, prior_lateral, prior_report, grounding],
                outputs=report_output
            )
        
        with gr.Tab("Phrase Grounding"):
            with gr.Row():
                with gr.Column():
                    pg_frontal = gr.Image(label="Frontal View", type="filepath")
                    phrase = gr.Textbox(label="Phrase to Ground")
                    pg_sample_btn = gr.Button("Load Sample Data")
                with gr.Column():
                    pg_output = gr.Textbox(label="Grounding Result", lines=3)
                    pg_btn = gr.Button("Find Phrase")
            
            pg_sample_btn.click(
                load_sample_phrase,
                outputs=[pg_frontal, phrase]
            )
            pg_btn.click(
                ground_phrase,
                inputs=[pg_frontal, phrase],
                outputs=pg_output
            )

demo.launch()