ayyuce committed (verified)
Commit 8c00a90 · 1 Parent(s): aad1f20

Update app.py

Files changed (1): app.py (+22 -21)
app.py CHANGED

@@ -4,6 +4,7 @@ import torch
 import gradio as gr
 import requests
 import tempfile
+import os
 
 MODEL_STATE = {
     "model": None,
@@ -75,20 +76,22 @@ def load_sample_phrase():
     return [save_temp_image(sample["frontal"]), sample["phrase"]]
 
 def generate_report(frontal_path, lateral_path, indication, technique, comparison,
-                   prior_frontal_path, prior_lateral_path, prior_report, grounding):
+                    prior_frontal_path, prior_lateral_path, prior_report, grounding):
     """Generate radiology report with authentication check"""
     if not MODEL_STATE["authenticated"]:
         return "⚠️ Please authenticate with your Hugging Face token first!"
 
-    if not frontal_path or not lateral_path:
-        return "❌ Please upload both the frontal and lateral images for the current study."
-
     try:
-        current_frontal = Image.open(frontal_path)
-        current_lateral = Image.open(lateral_path)
+        current_frontal = Image.open(frontal_path) if frontal_path else None
+        current_lateral = Image.open(lateral_path) if lateral_path else None
         prior_frontal = Image.open(prior_frontal_path) if prior_frontal_path else None
         prior_lateral = Image.open(prior_lateral_path) if prior_lateral_path else None
 
+        if not current_frontal or not current_lateral:
+            return "❌ Missing required current study images"
+
+        prior_report = prior_report or ""
+
         processed = MODEL_STATE["processor"].format_and_preprocess_reporting_input(
             current_frontal=current_frontal,
             current_lateral=current_lateral,
@@ -97,20 +100,20 @@ def generate_report(frontal_path, lateral_path, indication, technique, comparison,
             indication=indication,
             technique=technique,
             comparison=comparison,
-            prior_report=prior_report or None,
+            prior_report=prior_report,
             return_tensors="pt",
             get_grounding=grounding
         ).to("cpu")
 
-        processed.pop("image_sizes", None)
+        processed_inputs = {k: v for k, v in processed.items() if k != 'image_sizes'}
 
         outputs = MODEL_STATE["model"].generate(
-            **processed,
+            **processed_inputs,
             max_new_tokens=450 if grounding else 300,
             use_cache=True
         )
 
-        prompt_length = processed["input_ids"].shape[-1]
+        prompt_length = processed_inputs["input_ids"].shape[-1]
         decoded = MODEL_STATE["processor"].decode(outputs[0][prompt_length:], skip_special_tokens=True)
         return MODEL_STATE["processor"].convert_output_to_plaintext_or_grounded_sequence(decoded.lstrip())
 
@@ -122,28 +125,26 @@ def ground_phrase(frontal_path, phrase):
     if not MODEL_STATE["authenticated"]:
         return "⚠️ Please authenticate with your Hugging Face token first!"
 
-    # Check that the required image is provided.
-    if not frontal_path:
-        return "❌ Please upload the frontal image for phrase grounding."
-
     try:
+        if not frontal_path:
+            return "❌ Missing frontal view image"
+
         frontal = Image.open(frontal_path)
         processed = MODEL_STATE["processor"].format_and_preprocess_phrase_grounding_input(
            frontal_image=frontal,
            phrase=phrase,
            return_tensors="pt"
        ).to("cpu")
-
-        # Remove the unexpected key if present.
-        processed.pop("image_sizes", None)
+
+        processed_inputs = {k: v for k, v in processed.items() if k != 'image_sizes'}
 
        outputs = MODEL_STATE["model"].generate(
-            **processed,
+            **processed_inputs,
            max_new_tokens=150,
            use_cache=True
        )
 
-        prompt_length = processed["input_ids"].shape[-1]
+        prompt_length = processed_inputs["input_ids"].shape[-1]
        decoded = MODEL_STATE["processor"].decode(outputs[0][prompt_length:], skip_special_tokens=True)
        return MODEL_STATE["processor"].convert_output_to_plaintext_or_grounded_sequence(decoded)
 
@@ -199,12 +200,12 @@ with gr.Blocks(title="MAIRA-2 Medical Assistant") as demo:
     sample_btn.click(
         load_sample_findings,
         outputs=[frontal, lateral, indication, technique, comparison,
-                prior_frontal, prior_lateral, prior_report, grounding]
+                 prior_frontal, prior_lateral, prior_report, grounding]
     )
     generate_btn.click(
         generate_report,
         inputs=[frontal, lateral, indication, technique, comparison,
-               prior_frontal, prior_lateral, prior_report, grounding],
+                prior_frontal, prior_lateral, prior_report, grounding],
         outputs=report_output
     )
 
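Both inference paths now apply the same pattern: rather than mutating the processor output in place with `pop("image_sizes", None)`, the code builds a filtered copy (`processed_inputs`) before unpacking it into `model.generate(...)`. The sketch below illustrates why the filtering is needed at all; the processor output and `generate` function here are hypothetical stand-ins, not the MAIRA-2 API, and serve only to show that `**`-unpacking forwards every key, so any key the callee does not accept must be dropped first.

# Minimal sketch of the input-filtering pattern, with stand-in objects.
def fake_processor_output():
    # A processor can return extra metadata alongside the model inputs.
    return {
        "input_ids": [[1, 2, 3]],
        "attention_mask": [[1, 1, 1]],
        "image_sizes": [(518, 518)],  # metadata the callee does not accept
    }

def generate(input_ids, attention_mask, max_new_tokens=10, use_cache=True):
    # Stand-in for model.generate(); note it has no image_sizes parameter.
    return input_ids

processed = fake_processor_output()

try:
    generate(**processed)  # forwards image_sizes too
except TypeError as err:
    print(err)  # generate() got an unexpected keyword argument 'image_sizes'

# The commit's approach: rebuild the dict without the offending key, leaving
# the original mapping untouched and reusable.
processed_inputs = {k: v for k, v in processed.items() if k != "image_sizes"}
outputs = generate(**processed_inputs)
prompt_length = len(processed_inputs["input_ids"][0])

One design note: a dict comprehension over a mapping always yields a plain dict, which still unpacks cleanly into `generate`, and it avoids the side effect of `pop`, so `processed` stays complete in case later code needs the removed key.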