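# MAIRA-2 Medical Assistant -- Gradio demo for Microsoft's MAIRA-2 radiology model.
# Generates draft chest X-ray reports and grounds findings phrases in the image.
# Requires a Hugging Face token with access to the gated microsoft/maira-2 repo.
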
from transformers import AutoModelForCausalLM, AutoProcessor
from PIL import Image
import torch
import gradio as gr
import requests
import tempfile
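
# Module-level state shared across Gradio callbacks; populated by login().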
MODEL_STATE = {
    "model": None,
    "processor": None,
    "authenticated": False
}

def login(hf_token):
    """Authenticate with Hugging Face and load the MAIRA-2 model and processor."""
    try:
        MODEL_STATE.update({"model": None, "processor": None, "authenticated": False})
        MODEL_STATE["model"] = AutoModelForCausalLM.from_pretrained(
            "microsoft/maira-2",
            trust_remote_code=True,
            token=hf_token  # `use_auth_token` is deprecated in recent transformers releases
        )
        MODEL_STATE["processor"] = AutoProcessor.from_pretrained(
            "microsoft/maira-2",
            trust_remote_code=True,
            token=hf_token
        )
        MODEL_STATE["model"] = MODEL_STATE["model"].eval().to("cpu")
        MODEL_STATE["authenticated"] = True
        return "🔓 Login successful! You can now use the model."
    except Exception as e:
        MODEL_STATE.update({"model": None, "processor": None, "authenticated": False})
        return f"❌ Login failed: {str(e)}"

def get_sample_data():
    """Download sample chest X-ray images and report metadata from Open-i (NIH)."""
    frontal_url = "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-1001.png"
    lateral_url = "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-2001.png"

    def download_image(url):
        response = requests.get(url, headers={"User-Agent": "MAIRA-2"}, stream=True)
        response.raise_for_status()  # fail early on HTTP errors instead of passing bad bytes to PIL
        return Image.open(response.raw)

    return {
        "frontal": download_image(frontal_url),
        "lateral": download_image(lateral_url),
        "indication": "Dyspnea.",
        "technique": "PA and lateral views of the chest.",
        "comparison": "None.",
        "phrase": "Pleural effusion."
    }

def save_temp_image(img):
    """Save a PIL image to a temporary PNG file and return its path."""
    temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    img.save(temp_file.name)
    temp_file.close()  # release the handle; the file persists because delete=False
    return temp_file.name

def load_sample_findings():
    """Populate the report-generation tab with sample inputs; clear the prior-study fields."""
    sample = get_sample_data()
    return [
        save_temp_image(sample["frontal"]),
        save_temp_image(sample["lateral"]),
        sample["indication"],
        sample["technique"],
        sample["comparison"],
        None, None, None, False
    ]

def load_sample_phrase():
    """Populate the phrase-grounding tab with a sample image and phrase."""
    sample = get_sample_data()
    return [save_temp_image(sample["frontal"]), sample["phrase"]]

def generate_report(frontal_path, lateral_path, indication, technique, comparison,
                    prior_frontal_path, prior_lateral_path, prior_report, grounding):
    """Generate a radiology report, with authentication and input checks."""
    if not MODEL_STATE["authenticated"]:
        return "⚠️ Please authenticate with your Hugging Face token first!"
    if not frontal_path or not lateral_path:
        return "⚠️ Both frontal and lateral views of the current study are required."
    try:
        current_frontal = Image.open(frontal_path)
        current_lateral = Image.open(lateral_path)
        prior_frontal = Image.open(prior_frontal_path) if prior_frontal_path else None
        prior_lateral = Image.open(prior_lateral_path) if prior_lateral_path else None

        processed = MODEL_STATE["processor"].format_and_preprocess_reporting_input(
            current_frontal=current_frontal,
            current_lateral=current_lateral,
            prior_frontal=prior_frontal,
            prior_lateral=prior_lateral,
            indication=indication,
            technique=technique,
            comparison=comparison,
            prior_report=prior_report or None,
            return_tensors="pt",
            get_grounding=grounding
        ).to("cpu")

        outputs = MODEL_STATE["model"].generate(
            **processed,
            max_new_tokens=450 if grounding else 300,  # grounded reports need a longer budget
            use_cache=True
        )

        # Strip the prompt tokens so only the newly generated report is decoded.
        prompt_length = processed["input_ids"].shape[-1]
        decoded = MODEL_STATE["processor"].decode(outputs[0][prompt_length:], skip_special_tokens=True)
        return MODEL_STATE["processor"].convert_output_to_plaintext_or_grounded_sequence(decoded.lstrip())
    except Exception as e:
        return f"❌ Generation error: {str(e)}"

def ground_phrase(frontal_path, phrase):
    """Locate a findings phrase in the frontal image, with authentication and input checks."""
    if not MODEL_STATE["authenticated"]:
        return "⚠️ Please authenticate with your Hugging Face token first!"
    if not frontal_path:
        return "⚠️ A frontal view image is required."
    try:
        frontal = Image.open(frontal_path)
        processed = MODEL_STATE["processor"].format_and_preprocess_phrase_grounding_input(
            frontal_image=frontal,
            phrase=phrase,
            return_tensors="pt"
        ).to("cpu")
        outputs = MODEL_STATE["model"].generate(
            **processed,
            max_new_tokens=150,
            use_cache=True
        )
        prompt_length = processed["input_ids"].shape[-1]
        decoded = MODEL_STATE["processor"].decode(outputs[0][prompt_length:], skip_special_tokens=True)
        return MODEL_STATE["processor"].convert_output_to_plaintext_or_grounded_sequence(decoded)
    except Exception as e:
        return f"❌ Grounding error: {str(e)}"

with gr.Blocks(title="MAIRA-2 Medical Assistant") as demo:
    gr.Markdown("""# MAIRA-2 Medical Assistant
    **Authentication required** - You need a Hugging Face account and access token to use this model.
    1. Get your access token from [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)
    2. Request model access at [https://huggingface.co/microsoft/maira-2](https://huggingface.co/microsoft/maira-2)
    3. Paste your token below to begin
    """)

    with gr.Row():
        hf_token = gr.Textbox(
            label="Hugging Face Token",
            placeholder="hf_xxxxxxxxxxxxxxxxxxxx",
            type="password"
        )
        login_btn = gr.Button("Authenticate")
    login_status = gr.Textbox(label="Authentication Status", interactive=False)
    login_btn.click(
        login,
        inputs=hf_token,
        outputs=login_status
    )

    with gr.Tabs():
        with gr.Tab("Report Generation"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("## Current Study")
                    frontal = gr.Image(label="Frontal View", type="filepath")
                    lateral = gr.Image(label="Lateral View", type="filepath")
                    indication = gr.Textbox(label="Clinical Indication")
                    technique = gr.Textbox(label="Imaging Technique")
                    comparison = gr.Textbox(label="Comparison")
                    gr.Markdown("## Prior Study (Optional)")
                    prior_frontal = gr.Image(label="Prior Frontal View", type="filepath")
                    prior_lateral = gr.Image(label="Prior Lateral View", type="filepath")
                    prior_report = gr.Textbox(label="Prior Report")
                    grounding = gr.Checkbox(label="Include Grounding")
                    sample_btn = gr.Button("Load Sample Data")
                with gr.Column():
                    report_output = gr.Textbox(label="Generated Report", lines=10)
                    generate_btn = gr.Button("Generate Report")

            sample_btn.click(
                load_sample_findings,
                outputs=[frontal, lateral, indication, technique, comparison,
                         prior_frontal, prior_lateral, prior_report, grounding]
            )
            generate_btn.click(
                generate_report,
                inputs=[frontal, lateral, indication, technique, comparison,
                        prior_frontal, prior_lateral, prior_report, grounding],
                outputs=report_output
            )

        with gr.Tab("Phrase Grounding"):
            with gr.Row():
                with gr.Column():
                    pg_frontal = gr.Image(label="Frontal View", type="filepath")
                    phrase = gr.Textbox(label="Phrase to Ground")
                    pg_sample_btn = gr.Button("Load Sample Data")
                with gr.Column():
                    pg_output = gr.Textbox(label="Grounding Result", lines=3)
                    pg_btn = gr.Button("Find Phrase")

            pg_sample_btn.click(
                load_sample_phrase,
                outputs=[pg_frontal, phrase]
            )
            pg_btn.click(
                ground_phrase,
                inputs=[pg_frontal, phrase],
                outputs=pg_output
            )

demo.launch()
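# Note: demo.launch(share=True) would additionally expose a temporary public
# URL via Gradio's tunneling service, which can be convenient for demos.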