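"""Gradio demo for Microsoft's MAIRA-2 chest X-ray model.

Two tabs are exposed: radiology report generation (optionally grounded) from
frontal/lateral chest X-rays, and phrase grounding on a frontal view. The
gated microsoft/maira-2 checkpoint is downloaded with the user's Hugging Face
token, inference runs on CPU, and sample data from the NIH Open-i service
fills in any inputs the user leaves empty.
"""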

import gradio as gr
import torch
import requests
import tempfile

from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

_model_cache = {}


def load_model_and_processor(hf_token: str):
    """
    Load the MAIRA-2 model and processor from Hugging Face using the provided
    token. The loaded objects are cached keyed by the token, so repeated calls
    with the same token do not re-download the weights.
    """
    if hf_token in _model_cache:
        return _model_cache[hf_token]
    device = torch.device("cpu")
    model = AutoModelForCausalLM.from_pretrained(
        "microsoft/maira-2",
        trust_remote_code=True,
        token=hf_token,  # `use_auth_token` is deprecated in recent transformers
    )
    processor = AutoProcessor.from_pretrained(
        "microsoft/maira-2",
        trust_remote_code=True,
        token=hf_token,
    )
    model.eval()
    model.to(device)
    _model_cache[hf_token] = (model, processor)
    return model, processor
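
# Note: inference is pinned to CPU above for portability. On a CUDA machine it
# should also be possible to pick the device dynamically, e.g.
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# as long as the model and the processed inputs end up on the same device,
# which generate_report and run_phrase_grounding below already arrange.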


def get_sample_data() -> dict:
    """
    Download a sample chest X-ray study (frontal and lateral views) along with
    sample text inputs for report generation and phrase grounding.
    """
    frontal_image_url = "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-1001.png"
    lateral_image_url = "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-2001.png"

    def download_and_open(url: str) -> Image.Image:
        response = requests.get(url, headers={"User-Agent": "MAIRA-2"}, stream=True)
        response.raise_for_status()  # fail early on a bad download
        return Image.open(response.raw).convert("RGB")

    frontal = download_and_open(frontal_image_url)
    lateral = download_and_open(lateral_image_url)
    return {
        "frontal": frontal,
        "lateral": lateral,
        "indication": "Dyspnea.",
        "technique": "PA and lateral views of the chest.",
        "comparison": "None.",
        "phrase": "Pleural effusion.",
    }
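
# The sample study above (Open-i case CXR145) doubles as the fallback input:
# generate_report and run_phrase_grounding substitute these images and texts
# for anything the user leaves empty.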


def generate_report(hf_token, frontal, lateral, indication, technique, comparison, use_grounding):
    """
    Generate a radiology report using the MAIRA-2 model.
    Any missing image or text input is filled in from the sample data.
    """
    try:
        model, processor = load_model_and_processor(hf_token)
    except Exception as e:
        return f"Error loading model: {str(e)}"
    device = torch.device("cpu")
    sample = get_sample_data()
    if frontal is None:
        frontal = sample["frontal"]
    if lateral is None:
        lateral = sample["lateral"]
    if not indication:
        indication = sample["indication"]
    if not technique:
        technique = sample["technique"]
    if not comparison:
        comparison = sample["comparison"]

    processed_inputs = processor.format_and_preprocess_reporting_input(
        current_frontal=frontal,
        current_lateral=lateral,
        prior_frontal=None,
        indication=indication,
        technique=technique,
        comparison=comparison,
        prior_report=None,
        return_tensors="pt",
        get_grounding=use_grounding,
    )
    processed_inputs = {k: v.to(device) for k, v in processed_inputs.items()}

    # Drop any "image_sizes" entries: the processor emits them for its own
    # bookkeeping, but model.generate() does not accept them as kwargs.
    keys_to_remove = [k for k in processed_inputs if "image_sizes" in k]
    for key in keys_to_remove:
        processed_inputs.pop(key, None)

    # Grounded reports are longer (they interleave box tokens with the text),
    # so allow more new tokens in that case.
    max_tokens = 450 if use_grounding else 300
    with torch.no_grad():
        output_decoding = model.generate(
            **processed_inputs,
            max_new_tokens=max_tokens,
            use_cache=True,
        )
    # Strip the prompt tokens so only the newly generated report is decoded.
    prompt_length = processed_inputs["input_ids"].shape[-1]
    decoded_text = processor.decode(output_decoding[0][prompt_length:], skip_special_tokens=True)
    decoded_text = decoded_text.lstrip()
    prediction = processor.convert_output_to_plaintext_or_grounded_sequence(decoded_text)
    return prediction
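
# convert_output_to_plaintext_or_grounded_sequence returns either a plain-text
# report or, when grounding was requested, a grounded sequence (per the
# MAIRA-2 model card, phrases paired with their bounding boxes). The Gradio
# Textbox bound to this function displays the str() of whichever it gets.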


def run_phrase_grounding(hf_token, frontal, phrase):
    """
    Run phrase grounding using the MAIRA-2 model.
    A missing image or phrase is filled in from the sample data.
    """
    try:
        model, processor = load_model_and_processor(hf_token)
    except Exception as e:
        return f"Error loading model: {str(e)}"
    device = torch.device("cpu")
    sample = get_sample_data()
    if frontal is None:
        frontal = sample["frontal"]
    if not phrase:
        phrase = sample["phrase"]
    processed_inputs = processor.format_and_preprocess_phrase_grounding_input(
        frontal_image=frontal,
        phrase=phrase,
        return_tensors="pt",
    )
    processed_inputs = {k: v.to(device) for k, v in processed_inputs.items()}

    # As in generate_report, remove "image_sizes" entries that
    # model.generate() would reject.
    keys_to_remove = [k for k in processed_inputs if "image_sizes" in k]
    for key in keys_to_remove:
        processed_inputs.pop(key, None)

    with torch.no_grad():
        output_decoding = model.generate(
            **processed_inputs,
            max_new_tokens=150,
            use_cache=True,
        )
    prompt_length = processed_inputs["input_ids"].shape[-1]
    decoded_text = processor.decode(output_decoding[0][prompt_length:], skip_special_tokens=True)
    decoded_text = decoded_text.lstrip()
    prediction = processor.convert_output_to_plaintext_or_grounded_sequence(decoded_text)
    return prediction


def login_ui(hf_token):
    """Authenticate the user by loading the model."""
    try:
        load_model_and_processor(hf_token)
        return "✅ Login successful! You can now use the model."
    except Exception as e:
        return f"❌ Login failed: {str(e)}"


def generate_report_ui(hf_token, frontal_path, lateral_path, indication, technique, comparison,
                       prior_frontal_path, prior_lateral_path, prior_report, grounding):
    """
    Wrapper around generate_report that accepts file paths (from the UI) for
    the images. The prior-study fields are accepted but currently ignored.
    """
    try:
        frontal = Image.open(frontal_path).convert("RGB") if frontal_path else None
        lateral = Image.open(lateral_path).convert("RGB") if lateral_path else None
    except Exception as e:
        return f"❌ Error loading images: {str(e)}"
    return generate_report(hf_token, frontal, lateral, indication, technique, comparison, grounding)


def run_phrase_grounding_ui(hf_token, frontal_path, phrase):
    """
    Wrapper around run_phrase_grounding that accepts a file path for the
    frontal image.
    """
    try:
        frontal = Image.open(frontal_path).convert("RGB") if frontal_path else None
    except Exception as e:
        return f"❌ Error loading image: {str(e)}"
    return run_phrase_grounding(hf_token, frontal, phrase)


def save_temp_image(img: Image.Image) -> str:
    """Save a PIL image to a temporary PNG file and return the file path."""
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
        img.save(temp_file, format="PNG")
    return temp_file.name


def load_sample_findings():
    """
    Load sample data for the report-generation tab: file paths for the current
    study images, sample text fields, and empty values for the prior study.
    """
    sample = get_sample_data()
    return [
        save_temp_image(sample["frontal"]),
        save_temp_image(sample["lateral"]),
        sample["indication"],
        sample["technique"],
        sample["comparison"],
        None,   # prior frontal image (unused)
        None,   # prior lateral image (unused)
        None,   # prior report (unused)
        False,  # grounding checkbox off by default
    ]


def load_sample_phrase():
    """
    Load sample data for the phrase-grounding tab: the frontal image path and
    a sample phrase.
    """
    sample = get_sample_data()
    return [save_temp_image(sample["frontal"]), sample["phrase"]]
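
# ---------------------------------------------------------------------------
# Gradio UI: token authentication, then one tab per task defined above.
# ---------------------------------------------------------------------------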


with gr.Blocks(title="MAIRA-2 Medical Assistant") as demo:
    gr.Markdown(
        """
        # MAIRA-2 Medical Assistant
        **Authentication required** - You need a Hugging Face account and an access token to use this model.
        1. Get your access token from [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)
        2. Request model access at [https://huggingface.co/microsoft/maira-2](https://huggingface.co/microsoft/maira-2)
        3. Paste your token below to begin
        """
    )

    with gr.Row():
        hf_token = gr.Textbox(
            label="Hugging Face Token",
            placeholder="hf_xxxxxxxxxxxxxxxxxxxx",
            type="password",
        )
        login_btn = gr.Button("Authenticate")
        login_status = gr.Textbox(label="Authentication Status", interactive=False)

    login_btn.click(
        login_ui,
        inputs=hf_token,
        outputs=login_status,
    )

    with gr.Tabs():
        with gr.Tab("Report Generation"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("## Current Study")
                    frontal = gr.Image(label="Frontal View", type="filepath")
                    lateral = gr.Image(label="Lateral View", type="filepath")
                    indication = gr.Textbox(label="Clinical Indication")
                    technique = gr.Textbox(label="Imaging Technique")
                    comparison = gr.Textbox(label="Comparison")

                    gr.Markdown("## Prior Study (Optional)")
                    prior_frontal = gr.Image(label="Prior Frontal View", type="filepath")
                    prior_lateral = gr.Image(label="Prior Lateral View", type="filepath")
                    prior_report = gr.Textbox(label="Prior Report")

                    grounding = gr.Checkbox(label="Include Grounding")
                    sample_btn = gr.Button("Load Sample Data")
                with gr.Column():
                    report_output = gr.Textbox(label="Generated Report", lines=10)
                    generate_btn = gr.Button("Generate Report")

            sample_btn.click(
                load_sample_findings,
                outputs=[frontal, lateral, indication, technique, comparison,
                         prior_frontal, prior_lateral, prior_report, grounding],
            )
            generate_btn.click(
                generate_report_ui,
                inputs=[hf_token, frontal, lateral, indication, technique, comparison,
                        prior_frontal, prior_lateral, prior_report, grounding],
                outputs=report_output,
            )

        with gr.Tab("Phrase Grounding"):
            with gr.Row():
                with gr.Column():
                    pg_frontal = gr.Image(label="Frontal View", type="filepath")
                    phrase = gr.Textbox(label="Phrase to Ground")
                    pg_sample_btn = gr.Button("Load Sample Data")
                with gr.Column():
                    pg_output = gr.Textbox(label="Grounding Result", lines=3)
                    pg_btn = gr.Button("Find Phrase")

            pg_sample_btn.click(
                load_sample_phrase,
                outputs=[pg_frontal, phrase],
            )
            pg_btn.click(
                run_phrase_grounding_ui,
                inputs=[hf_token, pg_frontal, phrase],
                outputs=pg_output,
            )


if __name__ == "__main__":
    demo.launch()