Update app.py
app.py CHANGED

@@ -4,6 +4,7 @@ import torch
 import gradio as gr
 import requests
 import tempfile
+import os
 
 MODEL_STATE = {
     "model": None,
@@ -75,20 +76,22 @@ def load_sample_phrase():
     return [save_temp_image(sample["frontal"]), sample["phrase"]]
 
 def generate_report(frontal_path, lateral_path, indication, technique, comparison,
-
+                    prior_frontal_path, prior_lateral_path, prior_report, grounding):
     """Generate radiology report with authentication check"""
     if not MODEL_STATE["authenticated"]:
         return "⚠️ Please authenticate with your Hugging Face token first!"
 
-    if not frontal_path or not lateral_path:
-        return "❌ Please upload both the frontal and lateral images for the current study."
-
     try:
-        current_frontal = Image.open(frontal_path)
-        current_lateral = Image.open(lateral_path)
+        current_frontal = Image.open(frontal_path) if frontal_path else None
+        current_lateral = Image.open(lateral_path) if lateral_path else None
         prior_frontal = Image.open(prior_frontal_path) if prior_frontal_path else None
         prior_lateral = Image.open(prior_lateral_path) if prior_lateral_path else None
 
+        if not current_frontal or not current_lateral:
+            return "❌ Missing required current study images"
+
+        prior_report = prior_report or ""
+
         processed = MODEL_STATE["processor"].format_and_preprocess_reporting_input(
             current_frontal=current_frontal,
             current_lateral=current_lateral,
@@ -97,20 +100,20 @@ def generate_report(frontal_path, lateral_path, indication, technique, compariso
             indication=indication,
             technique=technique,
             comparison=comparison,
-            prior_report=prior_report
+            prior_report=prior_report,
             return_tensors="pt",
             get_grounding=grounding
         ).to("cpu")
 
-        processed.
+        processed_inputs = {k: v for k, v in processed.items() if k != 'image_sizes'}
 
         outputs = MODEL_STATE["model"].generate(
-            **
+            **processed_inputs,
             max_new_tokens=450 if grounding else 300,
             use_cache=True
         )
 
-        prompt_length =
+        prompt_length = processed_inputs["input_ids"].shape[-1]
         decoded = MODEL_STATE["processor"].decode(outputs[0][prompt_length:], skip_special_tokens=True)
         return MODEL_STATE["processor"].convert_output_to_plaintext_or_grounded_sequence(decoded.lstrip())
 
@@ -122,28 +125,26 @@ def ground_phrase(frontal_path, phrase):
     if not MODEL_STATE["authenticated"]:
         return "⚠️ Please authenticate with your Hugging Face token first!"
 
-    # Check that the required image is provided.
-    if not frontal_path:
-        return "❌ Please upload the frontal image for phrase grounding."
-
     try:
+        if not frontal_path:
+            return "❌ Missing frontal view image"
+
         frontal = Image.open(frontal_path)
         processed = MODEL_STATE["processor"].format_and_preprocess_phrase_grounding_input(
             frontal_image=frontal,
             phrase=phrase,
             return_tensors="pt"
         ).to("cpu")
-
-
-        processed.pop("image_sizes", None)
+
+        processed_inputs = {k: v for k, v in processed.items() if k != 'image_sizes'}
 
         outputs = MODEL_STATE["model"].generate(
-            **
+            **processed_inputs,
             max_new_tokens=150,
             use_cache=True
         )
 
-        prompt_length =
+        prompt_length = processed_inputs["input_ids"].shape[-1]
         decoded = MODEL_STATE["processor"].decode(outputs[0][prompt_length:], skip_special_tokens=True)
         return MODEL_STATE["processor"].convert_output_to_plaintext_or_grounded_sequence(decoded)
 
@@ -199,12 +200,12 @@ with gr.Blocks(title="MAIRA-2 Medical Assistant") as demo:
     sample_btn.click(
         load_sample_findings,
         outputs=[frontal, lateral, indication, technique, comparison,
-
+                 prior_frontal, prior_lateral, prior_report, grounding]
     )
     generate_btn.click(
         generate_report,
         inputs=[frontal, lateral, indication, technique, comparison,
-
+                prior_frontal, prior_lateral, prior_report, grounding],
         outputs=report_output
     )
 
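The removed fragments (`processed.`, `**`, `prompt_length =`) show the previous revision was left mid-edit; this commit completes each one with the same two-step pattern in both handlers: filter `image_sizes` out of the processor output before unpacking it into `generate()`, then use the prompt length to decode only the newly generated tokens. A minimal sketch of that pattern, with `model` and `processor` standing in for `MODEL_STATE["model"]` and `MODEL_STATE["processor"]` (the `run_generation` helper is illustrative, not part of app.py):

def run_generation(model, processor, processed, max_new_tokens=300):
    # Drop keys generate() will not accept; "image_sizes" appears to be
    # returned by the processor for box post-processing, not for generation.
    processed_inputs = {k: v for k, v in processed.items() if k != "image_sizes"}

    outputs = model.generate(
        **processed_inputs,
        max_new_tokens=max_new_tokens,
        use_cache=True,
    )

    # A decoder-only generate() returns the prompt followed by the
    # continuation; slice at the prompt length so only the newly generated
    # tokens are decoded.
    prompt_length = processed_inputs["input_ids"].shape[-1]
    return processor.decode(outputs[0][prompt_length:], skip_special_tokens=True)

Stripping `image_sizes` rather than forwarding everything matters because `transformers` validates the keyword arguments passed to `generate()` and raises on keys the model does not use.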
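The other recurring fix completes the component lists wired into the `click()` events; the truncated lists in the previous revision were not even valid Python. Gradio forwards `inputs` to the handler positionally, in list order, so the nine components must line up one-to-one with `generate_report`'s nine parameters. A tiny self-contained illustration of that contract (the components and handler here are hypothetical, not the app's):

import gradio as gr

def greet(name, punctuation):
    return f"Hello, {name}{punctuation}"

with gr.Blocks() as demo:
    name_box = gr.Textbox(label="Name")
    punct_box = gr.Textbox(label="Punctuation", value="!")
    out_box = gr.Textbox(label="Greeting")
    btn = gr.Button("Greet")
    # inputs are forwarded to greet() positionally, in list order
    btn.click(greet, inputs=[name_box, punct_box], outputs=out_box)

demo.launch()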