ayyuce committed
Commit 3424243 · verified · 1 Parent(s): b012996

Update app.py

Files changed (1)
  1. app.py +183 -143
app.py CHANGED
@@ -1,170 +1,211 @@
- from transformers import AutoModelForCausalLM, AutoProcessor
- from PIL import Image
- import torch
  import gradio as gr
  import requests
  import tempfile
- import os

- MODEL_STATE = {
-     "model": None,
-     "processor": None,
-     "authenticated": False
- }

- def login(hf_token):
-     """Authenticate and load the model"""
-     try:
-         MODEL_STATE.update({"model": None, "processor": None, "authenticated": False})
-
-         MODEL_STATE["model"] = AutoModelForCausalLM.from_pretrained(
-             "microsoft/maira-2",
-             trust_remote_code=True,
-             use_auth_token=hf_token
-         )
-         MODEL_STATE["processor"] = AutoProcessor.from_pretrained(
-             "microsoft/maira-2",
-             trust_remote_code=True,
-             use_auth_token=hf_token
-         )
-
-         MODEL_STATE["model"] = MODEL_STATE["model"].eval().to("cpu")
-         MODEL_STATE["authenticated"] = True
-
-         return "🔓 Login successful! You can now use the model."
-     except Exception as e:
-         MODEL_STATE.update({"model": None, "processor": None, "authenticated": False})
-         return f"❌ Login failed: {str(e)}"

- def get_sample_data():
-     """Download sample medical images and data"""
-     frontal_url = "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-1001.png"
-     lateral_url = "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-2001.png"
-
-     def download_image(url):
          response = requests.get(url, headers={"User-Agent": "MAIRA-2"}, stream=True)
-         return Image.open(response.raw)
-
      return {
-         "frontal": download_image(frontal_url),
-         "lateral": download_image(lateral_url),
          "indication": "Dyspnea.",
          "technique": "PA and lateral views of the chest.",
          "comparison": "None.",
          "phrase": "Pleural effusion."
      }

- def save_temp_image(img):
-     """Save PIL image to temporary file"""
      temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
      img.save(temp_file.name)
      return temp_file.name

  def load_sample_findings():
      sample = get_sample_data()
      return [
-         save_temp_image(sample["frontal"]),
-         save_temp_image(sample["lateral"]),
          sample["indication"],
          sample["technique"],
          sample["comparison"],
-         None, None, None, False
      ]

  def load_sample_phrase():
      sample = get_sample_data()
      return [save_temp_image(sample["frontal"]), sample["phrase"]]

- def generate_report(frontal_path, lateral_path, indication, technique, comparison,
-                     prior_frontal_path, prior_lateral_path, prior_report, grounding):
-     """Generate radiology report with authentication check"""
-     if not MODEL_STATE["authenticated"]:
-         return "⚠️ Please authenticate with your Hugging Face token first!"
-
-     try:
-         current_frontal = Image.open(frontal_path) if frontal_path else None
-         current_lateral = Image.open(lateral_path) if lateral_path else None
-         prior_frontal = Image.open(prior_frontal_path) if prior_frontal_path else None
-         prior_lateral = Image.open(prior_lateral_path) if prior_lateral_path else None
-
-         if not current_frontal or not current_lateral:
-             return "❌ Missing required current study images"
-
-         prior_report = prior_report or ""
-
-         processed = MODEL_STATE["processor"].format_and_preprocess_reporting_input(
-             current_frontal=current_frontal,
-             current_lateral=current_lateral,
-             prior_frontal=prior_frontal,
-             prior_lateral=prior_lateral,
-             indication=indication,
-             technique=technique,
-             comparison=comparison,
-             prior_report=prior_report,
-             return_tensors="pt",
-             get_grounding=grounding
-         ).to("cpu")
-
-         processed = dict(processed)
-         image_size_keys = [k for k in processed.keys() if "image_sizes" in k]
-         for k in image_size_keys:
-             processed.pop(k, None)
-
-         outputs = MODEL_STATE["model"].generate(
-             **processed,
-             max_new_tokens=450 if grounding else 300,
-             use_cache=True
-         )
-
-         prompt_length = processed["input_ids"].shape[-1]
-         decoded = MODEL_STATE["processor"].decode(outputs[0][prompt_length:], skip_special_tokens=True)
-         return MODEL_STATE["processor"].convert_output_to_plaintext_or_grounded_sequence(decoded.lstrip())
-
-     except Exception as e:
-         return f"❌ Generation error: {str(e)}"
-
- def ground_phrase(frontal_path, phrase):
-     """Perform phrase grounding with authentication check"""
-     if not MODEL_STATE["authenticated"]:
-         return "⚠️ Please authenticate with your Hugging Face token first!"
-
-     try:
-         if not frontal_path:
-             return "❌ Missing frontal view image"
-
-         frontal = Image.open(frontal_path)
-         processed = MODEL_STATE["processor"].format_and_preprocess_phrase_grounding_input(
-             frontal_image=frontal,
-             phrase=phrase,
-             return_tensors="pt"
-         ).to("cpu")
-
-         # Convert to regular dict and remove image size related keys
-         processed = dict(processed)
-         image_size_keys = [k for k in processed.keys() if "image_sizes" in k]
-         for k in image_size_keys:
-             processed.pop(k, None)
-
-         outputs = MODEL_STATE["model"].generate(
-             **processed,
-             max_new_tokens=150,
-             use_cache=True
-         )
-
-         prompt_length = processed["input_ids"].shape[-1]
-         decoded = MODEL_STATE["processor"].decode(outputs[0][prompt_length:], skip_special_tokens=True)
-         return MODEL_STATE["processor"].convert_output_to_plaintext_or_grounded_sequence(decoded)
-
-     except Exception as e:
-         return f"❌ Grounding error: {str(e)}"

  with gr.Blocks(title="MAIRA-2 Medical Assistant") as demo:
-     gr.Markdown("""# MAIRA-2 Medical Assistant
-     **Authentication required** - You need a Hugging Face account and access token to use this model.
-     1. Get your access token from [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)
-     2. Request model access at [https://huggingface.co/microsoft/maira-2](https://huggingface.co/microsoft/maira-2)
-     3. Paste your token below to begin
-     """)

      with gr.Row():
          hf_token = gr.Textbox(
@@ -176,7 +217,7 @@ with gr.Blocks(title="MAIRA-2 Medical Assistant") as demo:
      login_status = gr.Textbox(label="Authentication Status", interactive=False)

      login_btn.click(
-         login,
          inputs=hf_token,
          outputs=login_status
      )
@@ -199,7 +240,6 @@ with gr.Blocks(title="MAIRA-2 Medical Assistant") as demo:

          grounding = gr.Checkbox(label="Include Grounding")
          sample_btn = gr.Button("Load Sample Data")
-
      with gr.Column():
          report_output = gr.Textbox(label="Generated Report", lines=10)
          generate_btn = gr.Button("Generate Report")
@@ -207,12 +247,12 @@ with gr.Blocks(title="MAIRA-2 Medical Assistant") as demo:
      sample_btn.click(
          load_sample_findings,
          outputs=[frontal, lateral, indication, technique, comparison,
-                  prior_frontal, prior_lateral, prior_report, grounding]
      )
      generate_btn.click(
-         generate_report,
-         inputs=[frontal, lateral, indication, technique, comparison,
-                 prior_frontal, prior_lateral, prior_report, grounding],
          outputs=report_output
      )

@@ -231,8 +271,8 @@ with gr.Blocks(title="MAIRA-2 Medical Assistant") as demo:
          outputs=[pg_frontal, phrase]
      )
      pg_btn.click(
-         ground_phrase,
-         inputs=[pg_frontal, phrase],
          outputs=pg_output
      )

  import gradio as gr
+ import torch
  import requests
  import tempfile
+ from pathlib import Path
+ from PIL import Image
+ from transformers import AutoModelForCausalLM, AutoProcessor

+ _model_cache = {}

+ def load_model_and_processor(hf_token: str):
+     """
+     Loads the MAIRA-2 model and processor from Hugging Face using the provided token.
+     The loaded objects are cached keyed by the token.
+     """
+     if hf_token in _model_cache:
+         return _model_cache[hf_token]
+     device = torch.device("cpu")
+     model = AutoModelForCausalLM.from_pretrained(
+         "microsoft/maira-2",
+         trust_remote_code=True,
+         use_auth_token=hf_token
+     )
+     processor = AutoProcessor.from_pretrained(
+         "microsoft/maira-2",
+         trust_remote_code=True,
+         use_auth_token=hf_token
+     )
+     model.eval()
+     model.to(device)
+     _model_cache[hf_token] = (model, processor)
+     return model, processor

+ def get_sample_data() -> dict:
+     """
+     Download sample chest X-ray images and associated data.
+     """
+     frontal_image_url = "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-1001.png"
+     lateral_image_url = "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-2001.png"
+
+     def download_and_open(url: str) -> Image.Image:
          response = requests.get(url, headers={"User-Agent": "MAIRA-2"}, stream=True)
+         return Image.open(response.raw).convert("RGB")
+
+     frontal = download_and_open(frontal_image_url)
+     lateral = download_and_open(lateral_image_url)
      return {
+         "frontal": frontal,
+         "lateral": lateral,
          "indication": "Dyspnea.",
          "technique": "PA and lateral views of the chest.",
          "comparison": "None.",
          "phrase": "Pleural effusion."
      }

+ def generate_report(hf_token, frontal, lateral, indication, technique, comparison, use_grounding):
+     """
+     Generates a radiology report using the MAIRA-2 model.
+     If any image/text input is missing, sample data is used.
+     """
+     try:
+         model, processor = load_model_and_processor(hf_token)
+     except Exception as e:
+         return f"Error loading model: {str(e)}"
+     device = torch.device("cpu")
+     sample = get_sample_data()
+     if frontal is None:
+         frontal = sample["frontal"]
+     if lateral is None:
+         lateral = sample["lateral"]
+     if not indication:
+         indication = sample["indication"]
+     if not technique:
+         technique = sample["technique"]
+     if not comparison:
+         comparison = sample["comparison"]
+
+     processed_inputs = processor.format_and_preprocess_reporting_input(
+         current_frontal=frontal,
+         current_lateral=lateral,
+         prior_frontal=None,  # No prior study is used in this demo.
+         indication=indication,
+         technique=technique,
+         comparison=comparison,
+         prior_report=None,
+         return_tensors="pt",
+         get_grounding=use_grounding,
+     )
+     processed_inputs = {k: v.to(device) for k, v in processed_inputs.items()}
+     max_tokens = 450 if use_grounding else 300
+     with torch.no_grad():
+         output_decoding = model.generate(
+             **processed_inputs,
+             max_new_tokens=max_tokens,
+             use_cache=True,
+         )
+     prompt_length = processed_inputs["input_ids"].shape[-1]
+     decoded_text = processor.decode(output_decoding[0][prompt_length:], skip_special_tokens=True)
+     decoded_text = decoded_text.lstrip()  # Remove any leading whitespace
+     prediction = processor.convert_output_to_plaintext_or_grounded_sequence(decoded_text)
+     return prediction
+
+ def run_phrase_grounding(hf_token, frontal, phrase):
+     """
+     Runs phrase grounding using the MAIRA-2 model.
+     If image or phrase is missing, sample data is used.
+     """
+     try:
+         model, processor = load_model_and_processor(hf_token)
+     except Exception as e:
+         return f"Error loading model: {str(e)}"
+     device = torch.device("cpu")
+     sample = get_sample_data()
+     if frontal is None:
+         frontal = sample["frontal"]
+     if not phrase:
+         phrase = sample["phrase"]
+     processed_inputs = processor.format_and_preprocess_phrase_grounding_input(
+         frontal_image=frontal,
+         phrase=phrase,
+         return_tensors="pt",
+     )
+     processed_inputs = {k: v.to(device) for k, v in processed_inputs.items()}
+     with torch.no_grad():
+         output_decoding = model.generate(
+             **processed_inputs,
+             max_new_tokens=150,
+             use_cache=True,
+         )
+     prompt_length = processed_inputs["input_ids"].shape[-1]
+     decoded_text = processor.decode(output_decoding[0][prompt_length:], skip_special_tokens=True)
+     prediction = processor.convert_output_to_plaintext_or_grounded_sequence(decoded_text)
+     return prediction
+
+ def login_ui(hf_token):
+     """Authenticate the user by loading the model."""
+     try:
+         load_model_and_processor(hf_token)
+         return "🔓 Login successful! You can now use the model."
+     except Exception as e:
+         return f"❌ Login failed: {str(e)}"
+
+ def generate_report_ui(hf_token, frontal_path, lateral_path, indication, technique, comparison,
+                        prior_frontal_path, prior_lateral_path, prior_report, grounding):
+     """
+     Wrapper for generate_report that accepts file paths (from the UI) for images.
+     Prior study fields are ignored.
+     """
+     try:
+         frontal = Image.open(frontal_path) if frontal_path else None
+         lateral = Image.open(lateral_path) if lateral_path else None
+     except Exception as e:
+         return f"❌ Error loading images: {str(e)}"
+     return generate_report(hf_token, frontal, lateral, indication, technique, comparison, grounding)
+
+ def run_phrase_grounding_ui(hf_token, frontal_path, phrase):
+     """
+     Wrapper for run_phrase_grounding that accepts a file path for the frontal image.
+     """
+     try:
+         frontal = Image.open(frontal_path) if frontal_path else None
+     except Exception as e:
+         return f"❌ Error loading image: {str(e)}"
+     return run_phrase_grounding(hf_token, frontal, phrase)
+
+ def save_temp_image(img: Image.Image) -> str:
+     """Save a PIL image to a temporary file and return the file path."""
      temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
      img.save(temp_file.name)
      return temp_file.name

  def load_sample_findings():
+     """
+     Loads sample data for the report generation tab.
+     Returns file paths for current study images, sample text fields, and dummy values for prior study.
+     """
      sample = get_sample_data()
      return [
+         save_temp_image(sample["frontal"]),  # frontal image file path
+         save_temp_image(sample["lateral"]),  # lateral image file path
          sample["indication"],
          sample["technique"],
          sample["comparison"],
+         None,   # prior frontal (not used)
+         None,   # prior lateral (not used)
+         None,   # prior report (not used)
+         False
      ]

  def load_sample_phrase():
+     """
+     Loads sample data for the phrase grounding tab.
+     Returns file path for the frontal image and a sample phrase.
+     """
      sample = get_sample_data()
      return [save_temp_image(sample["frontal"]), sample["phrase"]]

  with gr.Blocks(title="MAIRA-2 Medical Assistant") as demo:
+     gr.Markdown(
+         """
+         # MAIRA-2 Medical Assistant
+         **Authentication required** - You need a Hugging Face account and access token to use this model.
+         1. Get your access token from [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)
+         2. Request model access at [https://huggingface.co/microsoft/maira-2](https://huggingface.co/microsoft/maira-2)
+         3. Paste your token below to begin
+         """
+     )

      with gr.Row():
          hf_token = gr.Textbox(

      login_status = gr.Textbox(label="Authentication Status", interactive=False)

      login_btn.click(
+         login_ui,
          inputs=hf_token,
          outputs=login_status
      )

          grounding = gr.Checkbox(label="Include Grounding")
          sample_btn = gr.Button("Load Sample Data")
      with gr.Column():
          report_output = gr.Textbox(label="Generated Report", lines=10)
          generate_btn = gr.Button("Generate Report")

      sample_btn.click(
          load_sample_findings,
          outputs=[frontal, lateral, indication, technique, comparison,
+                  prior_frontal, prior_lateral, prior_report, grounding]
      )
      generate_btn.click(
+         generate_report_ui,
+         inputs=[hf_token, frontal, lateral, indication, technique, comparison,
+                 prior_frontal, prior_lateral, prior_report, grounding],
          outputs=report_output
      )

          outputs=[pg_frontal, phrase]
      )
      pg_btn.click(
+         run_phrase_grounding_ui,
+         inputs=[hf_token, pg_frontal, phrase],
          outputs=pg_output
      )
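
The commit replaces the old global MODEL_STATE with a per-token cache (_model_cache) and threads the Hugging Face token from the login box into every UI callback. A minimal sketch of exercising the new functions outside the Gradio UI, assuming they are available in the current namespace (for example pasted into a notebook) and that HF_TOKEN is a hypothetical environment variable holding a token with access to the gated microsoft/maira-2 repo:

    import os

    token = os.environ["HF_TOKEN"]   # assumption: token supplied via an environment variable
    sample = get_sample_data()       # downloads the two sample chest X-ray views

    # Report generation; any missing input falls back to the sample study.
    report = generate_report(token, sample["frontal"], sample["lateral"],
                             sample["indication"], sample["technique"],
                             sample["comparison"], use_grounding=False)
    print(report)

    # Phrase grounding on the frontal view only.
    print(run_phrase_grounding(token, sample["frontal"], sample["phrase"]))

Note that the new code pins the model to CPU, so generation in this sketch would be slow; the first call also downloads the model weights, while later calls with the same token reuse the cached model and processor.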
278