Update app.py
Browse files
app.py
CHANGED
@@ -4,6 +4,7 @@ import torch
|
|
4 |
import gradio as gr
|
5 |
import requests
|
6 |
import tempfile
|
|
|
7 |
|
8 |
MODEL_STATE = {
|
9 |
"model": None,
|
@@ -75,20 +76,22 @@ def load_sample_phrase():
|
|
75 |
return [save_temp_image(sample["frontal"]), sample["phrase"]]
|
76 |
|
77 |
def generate_report(frontal_path, lateral_path, indication, technique, comparison,
|
78 |
-
|
79 |
"""Generate radiology report with authentication check"""
|
80 |
if not MODEL_STATE["authenticated"]:
|
81 |
return "⚠️ Please authenticate with your Hugging Face token first!"
|
82 |
|
83 |
-
if not frontal_path or not lateral_path:
|
84 |
-
return "❌ Please upload both the frontal and lateral images for the current study."
|
85 |
-
|
86 |
try:
|
87 |
-
current_frontal = Image.open(frontal_path)
|
88 |
-
current_lateral = Image.open(lateral_path)
|
89 |
prior_frontal = Image.open(prior_frontal_path) if prior_frontal_path else None
|
90 |
prior_lateral = Image.open(prior_lateral_path) if prior_lateral_path else None
|
91 |
|
|
|
|
|
|
|
|
|
|
|
92 |
processed = MODEL_STATE["processor"].format_and_preprocess_reporting_input(
|
93 |
current_frontal=current_frontal,
|
94 |
current_lateral=current_lateral,
|
@@ -97,20 +100,20 @@ def generate_report(frontal_path, lateral_path, indication, technique, compariso
|
|
97 |
indication=indication,
|
98 |
technique=technique,
|
99 |
comparison=comparison,
|
100 |
-
prior_report=prior_report
|
101 |
return_tensors="pt",
|
102 |
get_grounding=grounding
|
103 |
).to("cpu")
|
104 |
|
105 |
-
processed.
|
106 |
|
107 |
outputs = MODEL_STATE["model"].generate(
|
108 |
-
**
|
109 |
max_new_tokens=450 if grounding else 300,
|
110 |
use_cache=True
|
111 |
)
|
112 |
|
113 |
-
prompt_length =
|
114 |
decoded = MODEL_STATE["processor"].decode(outputs[0][prompt_length:], skip_special_tokens=True)
|
115 |
return MODEL_STATE["processor"].convert_output_to_plaintext_or_grounded_sequence(decoded.lstrip())
|
116 |
|
@@ -122,28 +125,26 @@ def ground_phrase(frontal_path, phrase):
|
|
122 |
if not MODEL_STATE["authenticated"]:
|
123 |
return "⚠️ Please authenticate with your Hugging Face token first!"
|
124 |
|
125 |
-
# Check that the required image is provided.
|
126 |
-
if not frontal_path:
|
127 |
-
return "❌ Please upload the frontal image for phrase grounding."
|
128 |
-
|
129 |
try:
|
|
|
|
|
|
|
130 |
frontal = Image.open(frontal_path)
|
131 |
processed = MODEL_STATE["processor"].format_and_preprocess_phrase_grounding_input(
|
132 |
frontal_image=frontal,
|
133 |
phrase=phrase,
|
134 |
return_tensors="pt"
|
135 |
).to("cpu")
|
136 |
-
|
137 |
-
|
138 |
-
processed.pop("image_sizes", None)
|
139 |
|
140 |
outputs = MODEL_STATE["model"].generate(
|
141 |
-
**
|
142 |
max_new_tokens=150,
|
143 |
use_cache=True
|
144 |
)
|
145 |
|
146 |
-
prompt_length =
|
147 |
decoded = MODEL_STATE["processor"].decode(outputs[0][prompt_length:], skip_special_tokens=True)
|
148 |
return MODEL_STATE["processor"].convert_output_to_plaintext_or_grounded_sequence(decoded)
|
149 |
|
@@ -199,12 +200,12 @@ with gr.Blocks(title="MAIRA-2 Medical Assistant") as demo:
|
|
199 |
sample_btn.click(
|
200 |
load_sample_findings,
|
201 |
outputs=[frontal, lateral, indication, technique, comparison,
|
202 |
-
|
203 |
)
|
204 |
generate_btn.click(
|
205 |
generate_report,
|
206 |
inputs=[frontal, lateral, indication, technique, comparison,
|
207 |
-
|
208 |
outputs=report_output
|
209 |
)
|
210 |
|
|
|
4 |
import gradio as gr
|
5 |
import requests
|
6 |
import tempfile
|
7 |
+
import os
|
8 |
|
9 |
MODEL_STATE = {
|
10 |
"model": None,
|
|
|
76 |
return [save_temp_image(sample["frontal"]), sample["phrase"]]
|
77 |
|
78 |
def generate_report(frontal_path, lateral_path, indication, technique, comparison,
|
79 |
+
prior_frontal_path, prior_lateral_path, prior_report, grounding):
|
80 |
"""Generate radiology report with authentication check"""
|
81 |
if not MODEL_STATE["authenticated"]:
|
82 |
return "⚠️ Please authenticate with your Hugging Face token first!"
|
83 |
|
|
|
|
|
|
|
84 |
try:
|
85 |
+
current_frontal = Image.open(frontal_path) if frontal_path else None
|
86 |
+
current_lateral = Image.open(lateral_path) if lateral_path else None
|
87 |
prior_frontal = Image.open(prior_frontal_path) if prior_frontal_path else None
|
88 |
prior_lateral = Image.open(prior_lateral_path) if prior_lateral_path else None
|
89 |
|
90 |
+
if not current_frontal or not current_lateral:
|
91 |
+
return "❌ Missing required current study images"
|
92 |
+
|
93 |
+
prior_report = prior_report or ""
|
94 |
+
|
95 |
processed = MODEL_STATE["processor"].format_and_preprocess_reporting_input(
|
96 |
current_frontal=current_frontal,
|
97 |
current_lateral=current_lateral,
|
|
|
100 |
indication=indication,
|
101 |
technique=technique,
|
102 |
comparison=comparison,
|
103 |
+
prior_report=prior_report,
|
104 |
return_tensors="pt",
|
105 |
get_grounding=grounding
|
106 |
).to("cpu")
|
107 |
|
108 |
+
processed_inputs = {k: v for k, v in processed.items() if k != 'image_sizes'}
|
109 |
|
110 |
outputs = MODEL_STATE["model"].generate(
|
111 |
+
**processed_inputs,
|
112 |
max_new_tokens=450 if grounding else 300,
|
113 |
use_cache=True
|
114 |
)
|
115 |
|
116 |
+
prompt_length = processed_inputs["input_ids"].shape[-1]
|
117 |
decoded = MODEL_STATE["processor"].decode(outputs[0][prompt_length:], skip_special_tokens=True)
|
118 |
return MODEL_STATE["processor"].convert_output_to_plaintext_or_grounded_sequence(decoded.lstrip())
|
119 |
|
|
|
125 |
if not MODEL_STATE["authenticated"]:
|
126 |
return "⚠️ Please authenticate with your Hugging Face token first!"
|
127 |
|
|
|
|
|
|
|
|
|
128 |
try:
|
129 |
+
if not frontal_path:
|
130 |
+
return "❌ Missing frontal view image"
|
131 |
+
|
132 |
frontal = Image.open(frontal_path)
|
133 |
processed = MODEL_STATE["processor"].format_and_preprocess_phrase_grounding_input(
|
134 |
frontal_image=frontal,
|
135 |
phrase=phrase,
|
136 |
return_tensors="pt"
|
137 |
).to("cpu")
|
138 |
+
|
139 |
+
processed_inputs = {k: v for k, v in processed.items() if k != 'image_sizes'}
|
|
|
140 |
|
141 |
outputs = MODEL_STATE["model"].generate(
|
142 |
+
**processed_inputs,
|
143 |
max_new_tokens=150,
|
144 |
use_cache=True
|
145 |
)
|
146 |
|
147 |
+
prompt_length = processed_inputs["input_ids"].shape[-1]
|
148 |
decoded = MODEL_STATE["processor"].decode(outputs[0][prompt_length:], skip_special_tokens=True)
|
149 |
return MODEL_STATE["processor"].convert_output_to_plaintext_or_grounded_sequence(decoded)
|
150 |
|
|
|
200 |
sample_btn.click(
|
201 |
load_sample_findings,
|
202 |
outputs=[frontal, lateral, indication, technique, comparison,
|
203 |
+
prior_frontal, prior_lateral, prior_report, grounding]
|
204 |
)
|
205 |
generate_btn.click(
|
206 |
generate_report,
|
207 |
inputs=[frontal, lateral, indication, technique, comparison,
|
208 |
+
prior_frontal, prior_lateral, prior_report, grounding],
|
209 |
outputs=report_output
|
210 |
)
|
211 |
|