ayyuce committed on
Commit f5c4a8e · verified · 1 Parent(s): 2056352

Update app.py

Files changed (1)
  1. app.py +132 -73
app.py CHANGED
@@ -5,10 +5,35 @@ import gradio as gr
 import requests
 import tempfile
 
-device = torch.device("cpu")
-model = AutoModelForCausalLM.from_pretrained("microsoft/maira-2", trust_remote_code=True)
-processor = AutoProcessor.from_pretrained("microsoft/maira-2", trust_remote_code=True)
-model = model.eval().to(device)
+MODEL_STATE = {
+    "model": None,
+    "processor": None,
+    "authenticated": False
+}
+
+def login(hf_token):
+    """Authenticate and load the model"""
+    try:
+        MODEL_STATE.update({"model": None, "processor": None, "authenticated": False})
+
+        MODEL_STATE["model"] = AutoModelForCausalLM.from_pretrained(
+            "microsoft/maira-2",
+            trust_remote_code=True,
+            use_auth_token=hf_token
+        )
+        MODEL_STATE["processor"] = AutoProcessor.from_pretrained(
+            "microsoft/maira-2",
+            trust_remote_code=True,
+            use_auth_token=hf_token
+        )
+
+        MODEL_STATE["model"] = MODEL_STATE["model"].eval().to("cpu")
+        MODEL_STATE["authenticated"] = True
+
+        return "🔓 Login successful! You can now use the model."
+    except Exception as e:
+        MODEL_STATE.update({"model": None, "processor": None, "authenticated": False})
+        return f"❌ Login failed: {str(e)}"
 
 def get_sample_data():
     """Download sample medical images and data"""
@@ -35,7 +60,6 @@ def save_temp_image(img):
     return temp_file.name
 
 def load_sample_findings():
-    """Load sample data for findings generation"""
     sample = get_sample_data()
     return [
         save_temp_image(sample["frontal"]),
@@ -47,22 +71,22 @@ def load_sample_findings():
     ]
 
 def load_sample_phrase():
-    """Load sample data for phrase grounding"""
     sample = get_sample_data()
     return [save_temp_image(sample["frontal"]), sample["phrase"]]
 
-def generate_report(frontal_path, lateral_path, indication, technique, comparison,
+def generate_report(frontal_path, lateral_path, indication, technique, comparison,
                     prior_frontal_path, prior_lateral_path, prior_report, grounding):
-    """Generate radiology report with optional grounding"""
+    """Generate radiology report with authentication check"""
+    if not MODEL_STATE["authenticated"]:
+        return "⚠️ Please authenticate with your Hugging Face token first!"
+
     try:
-        # Load images
         current_frontal = Image.open(frontal_path)
         current_lateral = Image.open(lateral_path)
         prior_frontal = Image.open(prior_frontal_path) if prior_frontal_path else None
         prior_lateral = Image.open(prior_lateral_path) if prior_lateral_path else None
 
-        # Process inputs
-        processed = processor.format_and_preprocess_reporting_input(
+        processed = MODEL_STATE["processor"].format_and_preprocess_reporting_input(
             current_frontal=current_frontal,
             current_lateral=current_lateral,
             prior_frontal=prior_frontal,
@@ -73,88 +97,123 @@ def generate_report(frontal_path, lateral_path, indication, technique, compariso
             prior_report=prior_report or None,
             return_tensors="pt",
             get_grounding=grounding
-        ).to(device)
+        ).to("cpu")
 
-        # Generate report
-        outputs = model.generate(**processed,
-            max_new_tokens=450 if grounding else 300,
-            use_cache=True)
+        outputs = MODEL_STATE["model"].generate(
+            **processed,
+            max_new_tokens=450 if grounding else 300,
+            use_cache=True
+        )
 
-        # Decode and format
         prompt_length = processed["input_ids"].shape[-1]
-        decoded = processor.decode(outputs[0][prompt_length:], skip_special_tokens=True)
-        return processor.convert_output_to_plaintext_or_grounded_sequence(decoded.lstrip())
+        decoded = MODEL_STATE["processor"].decode(outputs[0][prompt_length:], skip_special_tokens=True)
+        return MODEL_STATE["processor"].convert_output_to_plaintext_or_grounded_sequence(decoded.lstrip())
 
     except Exception as e:
-        return f"Error: {str(e)}"
+        return f"❌ Generation error: {str(e)}"
 
 def ground_phrase(frontal_path, phrase):
-    """Perform phrase grounding on image"""
+    """Perform phrase grounding with authentication check"""
+    if not MODEL_STATE["authenticated"]:
+        return "⚠️ Please authenticate with your Hugging Face token first!"
+
     try:
         frontal = Image.open(frontal_path)
-        processed = processor.format_and_preprocess_phrase_grounding_input(
+        processed = MODEL_STATE["processor"].format_and_preprocess_phrase_grounding_input(
             frontal_image=frontal,
             phrase=phrase,
             return_tensors="pt"
-        ).to(device)
+        ).to("cpu")
 
-        outputs = model.generate(**processed, max_new_tokens=150, use_cache=True)
+        outputs = MODEL_STATE["model"].generate(
+            **processed,
+            max_new_tokens=150,
+            use_cache=True
+        )
 
         prompt_length = processed["input_ids"].shape[-1]
-        decoded = processor.decode(outputs[0][prompt_length:], skip_special_tokens=True)
-        return processor.convert_output_to_plaintext_or_grounded_sequence(decoded)
+        decoded = MODEL_STATE["processor"].decode(outputs[0][prompt_length:], skip_special_tokens=True)
+        return MODEL_STATE["processor"].convert_output_to_plaintext_or_grounded_sequence(decoded)
 
     except Exception as e:
-        return f"Error: {str(e)}"
+        return f"❌ Grounding error: {str(e)}"
 
-# Gradio UI
-with gr.Blocks(title="MAIRA-2 Medical Imaging Assistant") as demo:
-    gr.Markdown("# MAIRA-2 Medical Imaging Assistant\nAI-powered radiology report generation and phrase grounding")
+with gr.Blocks(title="MAIRA-2 Medical Assistant") as demo:
+    gr.Markdown("""# MAIRA-2 Medical Assistant
+    **Authentication required** - You need a Hugging Face account and access token to use this model.
+    1. Get your access token from [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)
+    2. Request model access at [https://huggingface.co/microsoft/maira-2](https://huggingface.co/microsoft/maira-2)
+    3. Paste your token below to begin
+    """)
 
-    with gr.Tab("Report Generation"):
-        with gr.Row():
-            with gr.Column():
-                gr.Markdown("## Current Study")
-                frontal = gr.Image(label="Frontal View", type="filepath")
-                lateral = gr.Image(label="Lateral View", type="filepath")
-                indication = gr.Textbox(label="Clinical Indication")
-                technique = gr.Textbox(label="Imaging Technique")
-                comparison = gr.Textbox(label="Comparison")
-
-                gr.Markdown("## Prior Study (Optional)")
-                prior_frontal = gr.Image(label="Prior Frontal View", type="filepath")
-                prior_lateral = gr.Image(label="Prior Lateral View", type="filepath")
-                prior_report = gr.Textbox(label="Prior Report")
-
-                grounding = gr.Checkbox(label="Include Grounding")
-                sample_btn = gr.Button("Load Sample Data")
-
-            with gr.Column():
-                report_output = gr.Textbox(label="Generated Report", lines=10)
-                generate_btn = gr.Button("Generate Report")
-
-        sample_btn.click(load_sample_findings,
-            outputs=[frontal, lateral, indication, technique, comparison,
-                prior_frontal, prior_lateral, prior_report, grounding])
-        generate_btn.click(generate_report,
-            inputs=[frontal, lateral, indication, technique, comparison,
-                prior_frontal, prior_lateral, prior_report, grounding],
-            outputs=report_output)
+    with gr.Row():
+        hf_token = gr.Textbox(
+            label="Hugging Face Token",
+            placeholder="hf_xxxxxxxxxxxxxxxxxxxx",
+            type="password"
+        )
+        login_btn = gr.Button("Authenticate")
+        login_status = gr.Textbox(label="Authentication Status", interactive=False)
+
+    login_btn.click(
+        login,
+        inputs=hf_token,
+        outputs=login_status
+    )
 
-    with gr.Tab("Phrase Grounding"):
-        with gr.Row():
-            with gr.Column():
-                pg_frontal = gr.Image(label="Frontal View", type="filepath")
-                phrase = gr.Textbox(label="Phrase to Ground")
-                pg_sample_btn = gr.Button("Load Sample Data")
-            with gr.Column():
-                pg_output = gr.Textbox(label="Grounding Result", lines=3)
-                pg_btn = gr.Button("Find Phrase")
+    with gr.Tabs():
+        with gr.Tab("Report Generation"):
+            with gr.Row():
+                with gr.Column():
+                    gr.Markdown("## Current Study")
+                    frontal = gr.Image(label="Frontal View", type="filepath")
+                    lateral = gr.Image(label="Lateral View", type="filepath")
+                    indication = gr.Textbox(label="Clinical Indication")
+                    technique = gr.Textbox(label="Imaging Technique")
+                    comparison = gr.Textbox(label="Comparison")
+
+                    gr.Markdown("## Prior Study (Optional)")
+                    prior_frontal = gr.Image(label="Prior Frontal View", type="filepath")
+                    prior_lateral = gr.Image(label="Prior Lateral View", type="filepath")
+                    prior_report = gr.Textbox(label="Prior Report")
+
+                    grounding = gr.Checkbox(label="Include Grounding")
+                    sample_btn = gr.Button("Load Sample Data")
+
+                with gr.Column():
+                    report_output = gr.Textbox(label="Generated Report", lines=10)
+                    generate_btn = gr.Button("Generate Report")
+
+            sample_btn.click(
+                load_sample_findings,
+                outputs=[frontal, lateral, indication, technique, comparison,
+                         prior_frontal, prior_lateral, prior_report, grounding]
+            )
+            generate_btn.click(
+                generate_report,
+                inputs=[frontal, lateral, indication, technique, comparison,
+                        prior_frontal, prior_lateral, prior_report, grounding],
+                outputs=report_output
+            )
 
-        pg_sample_btn.click(load_sample_phrase,
-            outputs=[pg_frontal, phrase])
-        pg_btn.click(ground_phrase,
-            inputs=[pg_frontal, phrase],
-            outputs=pg_output)
+        with gr.Tab("Phrase Grounding"):
+            with gr.Row():
+                with gr.Column():
+                    pg_frontal = gr.Image(label="Frontal View", type="filepath")
+                    phrase = gr.Textbox(label="Phrase to Ground")
+                    pg_sample_btn = gr.Button("Load Sample Data")
+                with gr.Column():
+                    pg_output = gr.Textbox(label="Grounding Result", lines=3)
+                    pg_btn = gr.Button("Find Phrase")
+
+            pg_sample_btn.click(
+                load_sample_phrase,
+                outputs=[pg_frontal, phrase]
+            )
+            pg_btn.click(
+                ground_phrase,
+                inputs=[pg_frontal, phrase],
+                outputs=pg_output
+            )
 
 demo.launch()
 
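A note on the new login() helper: it passes use_auth_token= to from_pretrained(), which recent transformers releases flag as deprecated in favor of token=. A minimal sketch of the same gated-model loading with the newer argument, assuming a recent transformers version (the token value below is a placeholder, not a real credential):

    # Sketch only: mirrors the loading logic in login(), but uses token= instead of
    # the deprecated use_auth_token=. Requires granted access to microsoft/maira-2.
    from transformers import AutoModelForCausalLM, AutoProcessor

    hf_token = "hf_xxxxxxxxxxxxxxxxxxxx"  # placeholder access token

    model = AutoModelForCausalLM.from_pretrained(
        "microsoft/maira-2", trust_remote_code=True, token=hf_token
    )
    processor = AutoProcessor.from_pretrained(
        "microsoft/maira-2", trust_remote_code=True, token=hf_token
    )
    model = model.eval().to("cpu")  # same CPU placement as the committed code

Either spelling loads the same weights; the newer argument simply avoids the deprecation warning on current transformers versions.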