Update app.py
app.py CHANGED
@@ -17,13 +17,13 @@ def load_model_and_processor(hf_token: str):
         return _model_cache[hf_token]
     device = torch.device("cpu")
     model = AutoModelForCausalLM.from_pretrained(
-        "microsoft/maira-2",
-        trust_remote_code=True,
+        "microsoft/maira-2",
+        trust_remote_code=True,
         use_auth_token=hf_token
     )
     processor = AutoProcessor.from_pretrained(
-        "microsoft/maira-2",
-        trust_remote_code=True,
+        "microsoft/maira-2",
+        trust_remote_code=True,
         use_auth_token=hf_token
     )
     model.eval()
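For context, the loading path this hunk touches can be pieced together from the visible lines. A minimal sketch, assuming the cache stores a (model, processor, device) tuple and that the model is moved to the CPU device created above; the cache layout is an assumption, and note that recent transformers releases deprecate use_auth_token in favor of token:

import torch
from transformers import AutoModelForCausalLM, AutoProcessor

_model_cache = {}  # assumed: keyed by the user's HF token, per the lookup above

def load_model_and_processor(hf_token: str):
    # Reuse a previously loaded model/processor for this token.
    if hf_token in _model_cache:
        return _model_cache[hf_token]
    device = torch.device("cpu")
    model = AutoModelForCausalLM.from_pretrained(
        "microsoft/maira-2",
        trust_remote_code=True,
        use_auth_token=hf_token,
    )
    processor = AutoProcessor.from_pretrained(
        "microsoft/maira-2",
        trust_remote_code=True,
        use_auth_token=hf_token,
    )
    model.eval()
    model.to(device)
    _model_cache[hf_token] = (model, processor, device)
    return _model_cache[hf_token]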
@@ -33,7 +33,7 @@ def load_model_and_processor(hf_token: str):
 
 def get_sample_data() -> dict:
     """
-
+    Downloads sample chest X-ray images and associated data.
     """
     frontal_image_url = "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-1001.png"
     lateral_image_url = "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-2001.png"
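The URLs above point at Open-i PNGs. A minimal fetch helper for them could look like the sketch below; the requests/PIL approach and the function name are assumptions, only the URLs come from the diff:

from io import BytesIO

import requests
from PIL import Image

def download_image(url: str) -> Image.Image:
    # Fetch the PNG and normalize to RGB for the processor.
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")

frontal = download_image(
    "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-1001.png"
)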
@@ -86,7 +86,14 @@ def generate_report(hf_token, frontal, lateral, indication, technique, compariso
         return_tensors="pt",
         get_grounding=use_grounding,
     )
+    # Move all tensors to the CPU
     processed_inputs = {k: v.to(device) for k, v in processed_inputs.items()}
+    # Remove keys containing "image_sizes" to prevent unexpected keyword errors.
+    processed_inputs = dict(processed_inputs)
+    keys_to_remove = [k for k in processed_inputs if "image_sizes" in k]
+    for key in keys_to_remove:
+        processed_inputs.pop(key, None)
+
     max_tokens = 450 if use_grounding else 300
     with torch.no_grad():
         output_decoding = model.generate(
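The added lines work around processors that emit extra keys (here anything containing "image_sizes") which model.generate() rejects as unexpected keyword arguments. The same fix could be expressed as one reusable helper; this is a sketch of the pattern, not code from the app:

def strip_keys(inputs: dict, substring: str = "image_sizes") -> dict:
    # Keep only entries generate() will accept; drop anything whose key
    # contains the offending substring.
    return {k: v for k, v in inputs.items() if substring not in k}

# processed_inputs = strip_keys(processed_inputs)

The identical block is added again in run_phrase_grounding in the next hunk; factoring it into a helper like this would avoid the duplication.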
@@ -121,6 +128,12 @@ def run_phrase_grounding(hf_token, frontal, phrase):
         return_tensors="pt",
     )
     processed_inputs = {k: v.to(device) for k, v in processed_inputs.items()}
+    # Remove keys containing "image_sizes" to prevent unexpected keyword errors.
+    processed_inputs = dict(processed_inputs)
+    keys_to_remove = [k for k in processed_inputs if "image_sizes" in k]
+    for key in keys_to_remove:
+        processed_inputs.pop(key, None)
+
     with torch.no_grad():
         output_decoding = model.generate(
             **processed_inputs,
@@ -132,6 +145,7 @@ def run_phrase_grounding(hf_token, frontal, phrase):
     prediction = processor.convert_output_to_plaintext_or_grounded_sequence(decoded_text)
     return prediction
 
+
 def login_ui(hf_token):
     """Authenticate the user by loading the model."""
     try:
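The decoded_text fed into convert_output_to_plaintext_or_grounded_sequence above is produced a few lines earlier, outside this hunk. The MAIRA-2 model card decodes by trimming the prompt tokens first; a sketch of that step, with details assumed from the model card rather than from this file:

# Continues from the generate() call above; trim the prompt, then decode.
prompt_length = processed_inputs["input_ids"].shape[-1]
decoded_text = processor.decode(
    output_decoding[0][prompt_length:], skip_special_tokens=True
)
prediction = processor.convert_output_to_plaintext_or_grounded_sequence(
    decoded_text.lstrip()
)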
@@ -177,14 +191,14 @@ def load_sample_findings():
     sample = get_sample_data()
     return [
         save_temp_image(sample["frontal"]), # frontal image file path
-        save_temp_image(sample["lateral"]),
+        save_temp_image(sample["lateral"]), # lateral image file path
         sample["indication"],
         sample["technique"],
         sample["comparison"],
         None, # prior frontal (not used)
         None, # prior lateral (not used)
         None, # prior report (not used)
-        False
+        False # grounding checkbox default
     ]
 
 def load_sample_phrase():
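save_temp_image is referenced here but not shown in this diff. A plausible minimal implementation, entirely an assumption, would write the PIL image to a temp file so Gradio image components can take a path:

import tempfile

from PIL import Image

def save_temp_image(image: Image.Image) -> str:
    # Persist the image to disk and hand back the file path.
    tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    image.save(tmp.name)
    return tmp.name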
@@ -276,4 +290,4 @@ with gr.Blocks(title="MAIRA-2 Medical Assistant") as demo:
         outputs=pg_output
     )
 
-demo.launch()
+demo.launch()
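demo.launch() with no arguments is the usual entry point on Hugging Face Spaces. For running the same app locally, binding the host and port explicitly is a common convention, not something this commit does:

if __name__ == "__main__":
    # 0.0.0.0 exposes the app inside a container; 7860 is Gradio's default port.
    demo.launch(server_name="0.0.0.0", server_port=7860)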