ayyuce committed
Commit e2d3fe3 · verified · Parent: 3424243

Update app.py

Files changed (1):
  app.py +22 -8
app.py CHANGED
@@ -17,13 +17,13 @@ def load_model_and_processor(hf_token: str):
         return _model_cache[hf_token]
     device = torch.device("cpu")
     model = AutoModelForCausalLM.from_pretrained(
-        "microsoft/maira-2",
-        trust_remote_code=True,
+        "microsoft/maira-2",
+        trust_remote_code=True,
         use_auth_token=hf_token
     )
     processor = AutoProcessor.from_pretrained(
-        "microsoft/maira-2",
-        trust_remote_code=True,
+        "microsoft/maira-2",
+        trust_remote_code=True,
         use_auth_token=hf_token
     )
     model.eval()
@@ -33,7 +33,7 @@ def load_model_and_processor(hf_token: str):
 
 def get_sample_data() -> dict:
     """
-    Download sample chest X-ray images and associated data.
+    Downloads sample chest X-ray images and associated data.
     """
     frontal_image_url = "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-1001.png"
     lateral_image_url = "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-2001.png"
@@ -86,7 +86,14 @@ def generate_report(hf_token, frontal, lateral, indication, technique, compariso
         return_tensors="pt",
         get_grounding=use_grounding,
     )
+    # Move all tensors to the CPU
     processed_inputs = {k: v.to(device) for k, v in processed_inputs.items()}
+    # Remove keys containing "image_sizes" to prevent unexpected keyword errors.
+    processed_inputs = dict(processed_inputs)
+    keys_to_remove = [k for k in processed_inputs if "image_sizes" in k]
+    for key in keys_to_remove:
+        processed_inputs.pop(key, None)
+
     max_tokens = 450 if use_grounding else 300
     with torch.no_grad():
         output_decoding = model.generate(
@@ -121,6 +128,12 @@ def run_phrase_grounding(hf_token, frontal, phrase):
         return_tensors="pt",
     )
     processed_inputs = {k: v.to(device) for k, v in processed_inputs.items()}
+    # Remove keys containing "image_sizes" to prevent unexpected keyword errors.
+    processed_inputs = dict(processed_inputs)
+    keys_to_remove = [k for k in processed_inputs if "image_sizes" in k]
+    for key in keys_to_remove:
+        processed_inputs.pop(key, None)
+
     with torch.no_grad():
         output_decoding = model.generate(
             **processed_inputs,
@@ -132,6 +145,7 @@ def run_phrase_grounding(hf_token, frontal, phrase):
     prediction = processor.convert_output_to_plaintext_or_grounded_sequence(decoded_text)
     return prediction
 
+
 def login_ui(hf_token):
     """Authenticate the user by loading the model."""
     try:
@@ -177,14 +191,14 @@ def load_sample_findings():
     sample = get_sample_data()
     return [
         save_temp_image(sample["frontal"]),  # frontal image file path
-        save_temp_image(sample["lateral"]),  # lateral image file path
+        save_temp_image(sample["lateral"]),  # lateral image file path
         sample["indication"],
         sample["technique"],
         sample["comparison"],
         None,  # prior frontal (not used)
         None,  # prior lateral (not used)
         None,  # prior report (not used)
-        False
+        False  # grounding checkbox default
     ]
 
 def load_sample_phrase():
@@ -276,4 +290,4 @@ with gr.Blocks(title="MAIRA-2 Medical Assistant") as demo:
         outputs=pg_output
     )
 
-demo.launch()
+demo.launch()
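Review note: the "drop keys containing image_sizes" block is now duplicated verbatim in generate_report and run_phrase_grounding. A minimal sketch of how it could be factored into one helper (the name prepare_generate_inputs is hypothetical, not part of this commit):

    def prepare_generate_inputs(processed_inputs, device):
        """Move processor outputs to the target device and drop entries,
        such as image_sizes, that model.generate() does not accept as
        keyword arguments (passing them raises a TypeError)."""
        moved = {k: v.to(device) for k, v in processed_inputs.items()}
        return {k: v for k, v in moved.items() if "image_sizes" not in k}

Both call sites would then reduce to processed_inputs = prepare_generate_inputs(processed_inputs, device).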
 
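Note on the loading code: recent transformers releases deprecate the use_auth_token argument of from_pretrained in favor of token. A sketch of a variant for newer versions (the helper name load_maira2 is illustrative, and this is untested against the transformers version pinned in this Space):

    from transformers import AutoModelForCausalLM, AutoProcessor

    def load_maira2(hf_token):
        # token replaces the deprecated use_auth_token in newer transformers
        model = AutoModelForCausalLM.from_pretrained(
            "microsoft/maira-2", trust_remote_code=True, token=hf_token
        )
        processor = AutoProcessor.from_pretrained(
            "microsoft/maira-2", trust_remote_code=True, token=hf_token
        )
        model.eval()
        return model, processor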