import gradio as gr import torch import requests import tempfile from pathlib import Path from PIL import Image from transformers import AutoModelForCausalLM, AutoProcessor _model_cache = {} def load_model_and_processor(hf_token: str): """ Loads the MAIRA-2 model and processor from Hugging Face using the provided token. The loaded objects are cached keyed by the token. """ if hf_token in _model_cache: return _model_cache[hf_token] device = torch.device("cpu") model = AutoModelForCausalLM.from_pretrained( "microsoft/maira-2", trust_remote_code=True, use_auth_token=hf_token ) processor = AutoProcessor.from_pretrained( "microsoft/maira-2", trust_remote_code=True, use_auth_token=hf_token ) model.eval() model.to(device) _model_cache[hf_token] = (model, processor) return model, processor def get_sample_data() -> dict: """ Downloads sample chest X-ray images and associated data. """ frontal_image_url = "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-1001.png" lateral_image_url = "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-2001.png" def download_and_open(url: str) -> Image.Image: response = requests.get(url, headers={"User-Agent": "MAIRA-2"}, stream=True) return Image.open(response.raw).convert("RGB") frontal = download_and_open(frontal_image_url) lateral = download_and_open(lateral_image_url) return { "frontal": frontal, "lateral": lateral, "indication": "Dyspnea.", "technique": "PA and lateral views of the chest.", "comparison": "None.", "phrase": "Pleural effusion." } def generate_report(hf_token, frontal, lateral, indication, technique, comparison, use_grounding): """ Generates a radiology report using the MAIRA-2 model. If any image/text input is missing, sample data is used. """ try: model, processor = load_model_and_processor(hf_token) except Exception as e: return f"Error loading model: {str(e)}" device = torch.device("cpu") sample = get_sample_data() if frontal is None: frontal = sample["frontal"] if lateral is None: lateral = sample["lateral"] if not indication: indication = sample["indication"] if not technique: technique = sample["technique"] if not comparison: comparison = sample["comparison"] processed_inputs = processor.format_and_preprocess_reporting_input( current_frontal=frontal, current_lateral=lateral, prior_frontal=None, # No prior study is used in this demo. indication=indication, technique=technique, comparison=comparison, prior_report=None, return_tensors="pt", get_grounding=use_grounding, ) # Move all tensors to the CPU processed_inputs = {k: v.to(device) for k, v in processed_inputs.items()} # Remove keys containing "image_sizes" to prevent unexpected keyword errors. processed_inputs = dict(processed_inputs) keys_to_remove = [k for k in processed_inputs if "image_sizes" in k] for key in keys_to_remove: processed_inputs.pop(key, None) max_tokens = 450 if use_grounding else 300 with torch.no_grad(): output_decoding = model.generate( **processed_inputs, max_new_tokens=max_tokens, use_cache=True, ) prompt_length = processed_inputs["input_ids"].shape[-1] decoded_text = processor.decode(output_decoding[0][prompt_length:], skip_special_tokens=True) decoded_text = decoded_text.lstrip() # Remove any leading whitespace prediction = processor.convert_output_to_plaintext_or_grounded_sequence(decoded_text) return prediction def run_phrase_grounding(hf_token, frontal, phrase): """ Runs phrase grounding using the MAIRA-2 model. If image or phrase is missing, sample data is used. """ try: model, processor = load_model_and_processor(hf_token) except Exception as e: return f"Error loading model: {str(e)}" device = torch.device("cpu") sample = get_sample_data() if frontal is None: frontal = sample["frontal"] if not phrase: phrase = sample["phrase"] processed_inputs = processor.format_and_preprocess_phrase_grounding_input( frontal_image=frontal, phrase=phrase, return_tensors="pt", ) processed_inputs = {k: v.to(device) for k, v in processed_inputs.items()} # Remove keys containing "image_sizes" to prevent unexpected keyword errors. processed_inputs = dict(processed_inputs) keys_to_remove = [k for k in processed_inputs if "image_sizes" in k] for key in keys_to_remove: processed_inputs.pop(key, None) with torch.no_grad(): output_decoding = model.generate( **processed_inputs, max_new_tokens=150, use_cache=True, ) prompt_length = processed_inputs["input_ids"].shape[-1] decoded_text = processor.decode(output_decoding[0][prompt_length:], skip_special_tokens=True) prediction = processor.convert_output_to_plaintext_or_grounded_sequence(decoded_text) return prediction def login_ui(hf_token): """Authenticate the user by loading the model.""" try: load_model_and_processor(hf_token) return "🔓 Login successful! You can now use the model." except Exception as e: return f"❌ Login failed: {str(e)}" def generate_report_ui(hf_token, frontal_path, lateral_path, indication, technique, comparison, prior_frontal_path, prior_lateral_path, prior_report, grounding): """ Wrapper for generate_report that accepts file paths (from the UI) for images. Prior study fields are ignored. """ try: frontal = Image.open(frontal_path) if frontal_path else None lateral = Image.open(lateral_path) if lateral_path else None except Exception as e: return f"❌ Error loading images: {str(e)}" return generate_report(hf_token, frontal, lateral, indication, technique, comparison, grounding) def run_phrase_grounding_ui(hf_token, frontal_path, phrase): """ Wrapper for run_phrase_grounding that accepts a file path for the frontal image. """ try: frontal = Image.open(frontal_path) if frontal_path else None except Exception as e: return f"❌ Error loading image: {str(e)}" return run_phrase_grounding(hf_token, frontal, phrase) def save_temp_image(img: Image.Image) -> str: """Save a PIL image to a temporary file and return the file path.""" temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False) img.save(temp_file.name) return temp_file.name def load_sample_findings(): """ Loads sample data for the report generation tab. Returns file paths for current study images, sample text fields, and dummy values for prior study. """ sample = get_sample_data() return [ save_temp_image(sample["frontal"]), # frontal image file path save_temp_image(sample["lateral"]), # lateral image file path sample["indication"], sample["technique"], sample["comparison"], None, # prior frontal (not used) None, # prior lateral (not used) None, # prior report (not used) False # grounding checkbox default ] def load_sample_phrase(): """ Loads sample data for the phrase grounding tab. Returns file path for the frontal image and a sample phrase. """ sample = get_sample_data() return [save_temp_image(sample["frontal"]), sample["phrase"]] with gr.Blocks(title="MAIRA-2 Medical Assistant") as demo: gr.Markdown( """ # MAIRA-2 Medical Assistant **Authentication required** - You need a Hugging Face account and access token to use this model. 1. Get your access token from [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens) 2. Request model access at [https://huggingface.co/microsoft/maira-2](https://huggingface.co/microsoft/maira-2) 3. Paste your token below to begin """ ) with gr.Row(): hf_token = gr.Textbox( label="Hugging Face Token", placeholder="hf_xxxxxxxxxxxxxxxxxxxx", type="password" ) login_btn = gr.Button("Authenticate") login_status = gr.Textbox(label="Authentication Status", interactive=False) login_btn.click( login_ui, inputs=hf_token, outputs=login_status ) with gr.Tabs(): with gr.Tab("Report Generation"): with gr.Row(): with gr.Column(): gr.Markdown("## Current Study") frontal = gr.Image(label="Frontal View", type="filepath") lateral = gr.Image(label="Lateral View", type="filepath") indication = gr.Textbox(label="Clinical Indication") technique = gr.Textbox(label="Imaging Technique") comparison = gr.Textbox(label="Comparison") gr.Markdown("## Prior Study (Optional)") prior_frontal = gr.Image(label="Prior Frontal View", type="filepath") prior_lateral = gr.Image(label="Prior Lateral View", type="filepath") prior_report = gr.Textbox(label="Prior Report") grounding = gr.Checkbox(label="Include Grounding") sample_btn = gr.Button("Load Sample Data") with gr.Column(): report_output = gr.Textbox(label="Generated Report", lines=10) generate_btn = gr.Button("Generate Report") sample_btn.click( load_sample_findings, outputs=[frontal, lateral, indication, technique, comparison, prior_frontal, prior_lateral, prior_report, grounding] ) generate_btn.click( generate_report_ui, inputs=[hf_token, frontal, lateral, indication, technique, comparison, prior_frontal, prior_lateral, prior_report, grounding], outputs=report_output ) with gr.Tab("Phrase Grounding"): with gr.Row(): with gr.Column(): pg_frontal = gr.Image(label="Frontal View", type="filepath") phrase = gr.Textbox(label="Phrase to Ground") pg_sample_btn = gr.Button("Load Sample Data") with gr.Column(): pg_output = gr.Textbox(label="Grounding Result", lines=3) pg_btn = gr.Button("Find Phrase") pg_sample_btn.click( load_sample_phrase, outputs=[pg_frontal, phrase] ) pg_btn.click( run_phrase_grounding_ui, inputs=[hf_token, pg_frontal, phrase], outputs=pg_output ) demo.launch()