import gradio as gr
import torch
import requests
import tempfile
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
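# Cache of loaded (model, processor) pairs, keyed by token, so each token only loads the model once.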
_model_cache = {}
def load_model_and_processor(hf_token: str):
"""
Loads the MAIRA-2 model and processor from Hugging Face using the provided token.
The loaded objects are cached keyed by the token.
"""
if hf_token in _model_cache:
return _model_cache[hf_token]
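    # This demo pins inference to CPU; switch to torch.device("cuda") if a GPU is available.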
device = torch.device("cpu")
    # trust_remote_code is required because MAIRA-2 ships custom modeling code.
    model = AutoModelForCausalLM.from_pretrained(
        "microsoft/maira-2",
        trust_remote_code=True,
        token=hf_token,
    )
    processor = AutoProcessor.from_pretrained(
        "microsoft/maira-2",
        trust_remote_code=True,
        token=hf_token,
    )
model.eval()
model.to(device)
_model_cache[hf_token] = (model, processor)
return model, processor
def get_sample_data() -> dict:
"""
Downloads sample chest X-ray images and associated data.
"""
frontal_image_url = "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-1001.png"
lateral_image_url = "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-2001.png"
    def download_and_open(url: str) -> Image.Image:
        response = requests.get(url, headers={"User-Agent": "MAIRA-2"}, stream=True)
        response.raise_for_status()  # Fail fast on HTTP errors rather than passing bad bytes to PIL.
        return Image.open(response.raw).convert("RGB")
frontal = download_and_open(frontal_image_url)
lateral = download_and_open(lateral_image_url)
return {
"frontal": frontal,
"lateral": lateral,
"indication": "Dyspnea.",
"technique": "PA and lateral views of the chest.",
"comparison": "None.",
"phrase": "Pleural effusion."
}
def generate_report(hf_token, frontal, lateral, indication, technique, comparison, use_grounding):
"""
Generates a radiology report using the MAIRA-2 model.
If any image/text input is missing, sample data is used.
"""
try:
model, processor = load_model_and_processor(hf_token)
except Exception as e:
return f"Error loading model: {str(e)}"
device = torch.device("cpu")
    # Only fetch the sample study when something is missing, to avoid needless downloads.
    if frontal is None or lateral is None or not (indication and technique and comparison):
        sample = get_sample_data()
        if frontal is None:
            frontal = sample["frontal"]
        if lateral is None:
            lateral = sample["lateral"]
        if not indication:
            indication = sample["indication"]
        if not technique:
            technique = sample["technique"]
        if not comparison:
            comparison = sample["comparison"]
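    # The MAIRA-2 processor assembles the multimodal reporting prompt from the images and text sections.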
processed_inputs = processor.format_and_preprocess_reporting_input(
current_frontal=frontal,
current_lateral=lateral,
prior_frontal=None, # No prior study is used in this demo.
indication=indication,
technique=technique,
comparison=comparison,
prior_report=None,
return_tensors="pt",
get_grounding=use_grounding,
)
    # Move every tensor in the processed inputs onto the target device.
    processed_inputs = {k: v.to(device) for k, v in processed_inputs.items()}
    # Drop any "image_sizes" entries, which model.generate() rejects as unexpected kwargs.
    for key in [k for k in processed_inputs if "image_sizes" in k]:
        processed_inputs.pop(key, None)
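    # Grounded reports include bounding-box tokens, so allow a larger generation budget.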
max_tokens = 450 if use_grounding else 300
with torch.no_grad():
output_decoding = model.generate(
**processed_inputs,
max_new_tokens=max_tokens,
use_cache=True,
)
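    # generate() returns the prompt followed by the new tokens; strip the prompt before decoding.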
prompt_length = processed_inputs["input_ids"].shape[-1]
decoded_text = processor.decode(output_decoding[0][prompt_length:], skip_special_tokens=True)
decoded_text = decoded_text.lstrip() # Remove any leading whitespace
prediction = processor.convert_output_to_plaintext_or_grounded_sequence(decoded_text)
return prediction
def run_phrase_grounding(hf_token, frontal, phrase):
"""
Runs phrase grounding using the MAIRA-2 model.
If image or phrase is missing, sample data is used.
"""
try:
model, processor = load_model_and_processor(hf_token)
except Exception as e:
return f"Error loading model: {str(e)}"
device = torch.device("cpu")
    # Only fetch the sample study when something is missing, to avoid needless downloads.
    if frontal is None or not phrase:
        sample = get_sample_data()
        if frontal is None:
            frontal = sample["frontal"]
        if not phrase:
            phrase = sample["phrase"]
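    # The processor formats the frontal image and the phrase into a grounding prompt.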
processed_inputs = processor.format_and_preprocess_phrase_grounding_input(
frontal_image=frontal,
phrase=phrase,
return_tensors="pt",
)
    # Move every tensor in the processed inputs onto the target device.
    processed_inputs = {k: v.to(device) for k, v in processed_inputs.items()}
    # Drop any "image_sizes" entries, which model.generate() rejects as unexpected kwargs.
    for key in [k for k in processed_inputs if "image_sizes" in k]:
        processed_inputs.pop(key, None)
with torch.no_grad():
output_decoding = model.generate(
**processed_inputs,
max_new_tokens=150,
use_cache=True,
)
prompt_length = processed_inputs["input_ids"].shape[-1]
decoded_text = processor.decode(output_decoding[0][prompt_length:], skip_special_tokens=True)
prediction = processor.convert_output_to_plaintext_or_grounded_sequence(decoded_text)
return prediction
def login_ui(hf_token):
"""Authenticate the user by loading the model."""
try:
load_model_and_processor(hf_token)
return "πŸ”“ Login successful! You can now use the model."
except Exception as e:
return f"❌ Login failed: {str(e)}"
def generate_report_ui(hf_token, frontal_path, lateral_path, indication, technique, comparison,
prior_frontal_path, prior_lateral_path, prior_report, grounding):
"""
Wrapper for generate_report that accepts file paths (from the UI) for images.
Prior study fields are ignored.
"""
try:
frontal = Image.open(frontal_path) if frontal_path else None
lateral = Image.open(lateral_path) if lateral_path else None
except Exception as e:
return f"❌ Error loading images: {str(e)}"
return generate_report(hf_token, frontal, lateral, indication, technique, comparison, grounding)
def run_phrase_grounding_ui(hf_token, frontal_path, phrase):
"""
Wrapper for run_phrase_grounding that accepts a file path for the frontal image.
"""
try:
frontal = Image.open(frontal_path) if frontal_path else None
except Exception as e:
return f"❌ Error loading image: {str(e)}"
return run_phrase_grounding(hf_token, frontal, phrase)
def save_temp_image(img: Image.Image) -> str:
    """Save a PIL image to a temporary file and return the file path."""
    # delete=False keeps the file on disk after the handle closes so Gradio can read it by path.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
        img.save(temp_file, format="PNG")
    return temp_file.name
def load_sample_findings():
"""
Loads sample data for the report generation tab.
Returns file paths for current study images, sample text fields, and dummy values for prior study.
"""
sample = get_sample_data()
return [
save_temp_image(sample["frontal"]), # frontal image file path
save_temp_image(sample["lateral"]), # lateral image file path
sample["indication"],
sample["technique"],
sample["comparison"],
None, # prior frontal (not used)
None, # prior lateral (not used)
None, # prior report (not used)
False # grounding checkbox default
]
def load_sample_phrase():
"""
Loads sample data for the phrase grounding tab.
Returns file path for the frontal image and a sample phrase.
"""
sample = get_sample_data()
return [save_temp_image(sample["frontal"]), sample["phrase"]]
with gr.Blocks(title="MAIRA-2 Medical Assistant") as demo:
gr.Markdown(
"""
# MAIRA-2 Medical Assistant
**Authentication required** - You need a Hugging Face account and access token to use this model.
1. Get your access token from [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)
2. Request model access at [https://huggingface.co/microsoft/maira-2](https://huggingface.co/microsoft/maira-2)
3. Paste your token below to begin
"""
)
with gr.Row():
hf_token = gr.Textbox(
label="Hugging Face Token",
placeholder="hf_xxxxxxxxxxxxxxxxxxxx",
type="password"
)
login_btn = gr.Button("Authenticate")
login_status = gr.Textbox(label="Authentication Status", interactive=False)
login_btn.click(
login_ui,
inputs=hf_token,
outputs=login_status
)
with gr.Tabs():
with gr.Tab("Report Generation"):
with gr.Row():
with gr.Column():
gr.Markdown("## Current Study")
frontal = gr.Image(label="Frontal View", type="filepath")
lateral = gr.Image(label="Lateral View", type="filepath")
indication = gr.Textbox(label="Clinical Indication")
technique = gr.Textbox(label="Imaging Technique")
comparison = gr.Textbox(label="Comparison")
gr.Markdown("## Prior Study (Optional)")
prior_frontal = gr.Image(label="Prior Frontal View", type="filepath")
prior_lateral = gr.Image(label="Prior Lateral View", type="filepath")
prior_report = gr.Textbox(label="Prior Report")
grounding = gr.Checkbox(label="Include Grounding")
sample_btn = gr.Button("Load Sample Data")
with gr.Column():
report_output = gr.Textbox(label="Generated Report", lines=10)
generate_btn = gr.Button("Generate Report")
sample_btn.click(
load_sample_findings,
outputs=[frontal, lateral, indication, technique, comparison,
prior_frontal, prior_lateral, prior_report, grounding]
)
generate_btn.click(
generate_report_ui,
inputs=[hf_token, frontal, lateral, indication, technique, comparison,
prior_frontal, prior_lateral, prior_report, grounding],
outputs=report_output
)
with gr.Tab("Phrase Grounding"):
with gr.Row():
with gr.Column():
pg_frontal = gr.Image(label="Frontal View", type="filepath")
phrase = gr.Textbox(label="Phrase to Ground")
pg_sample_btn = gr.Button("Load Sample Data")
with gr.Column():
pg_output = gr.Textbox(label="Grounding Result", lines=3)
pg_btn = gr.Button("Find Phrase")
pg_sample_btn.click(
load_sample_phrase,
outputs=[pg_frontal, phrase]
)
pg_btn.click(
run_phrase_grounding_ui,
inputs=[hf_token, pg_frontal, phrase],
outputs=pg_output
)
if __name__ == "__main__":
    demo.launch()