m7mdal7aj committed on
Commit
8eb0316
·
verified ·
1 Parent(s): 03f10c9

Update my_model/tabs/run_inference.py

Browse files
Files changed (1) hide show
  1. my_model/tabs/run_inference.py +63 -49
my_model/tabs/run_inference.py CHANGED
@@ -5,7 +5,8 @@ import accelerate
5
  import scipy
6
  import copy
7
  import time
8
- from typing import Tuple, Dict
 
9
  from PIL import Image
10
  import torch.nn as nn
11
  import pandas as pd
@@ -17,9 +18,9 @@ from my_model.config import inference_config as config
17
 
18
 
19
  class InferenceRunner(StateManager):
20
-
21
  """
22
- Manages the user interface and interactions for a Streamlit-based Knowledge-Based Visual Question Answering (KBVQA) application.
 
23
  This class handles image uploads, displays sample images, and facilitates the question-answering process using the KBVQA model.
24
  Inherits from the StateManager class.
25
  """
@@ -28,10 +29,10 @@ class InferenceRunner(StateManager):
28
  """
29
  Initializes the InferenceRunner instance, setting up the necessary state.
30
  """
31
-
32
- super().__init__()
33
 
 
34
 
 
35
  def answer_question(self, caption: str, detected_objects_str: str, question: str) -> Tuple[str, int]:
36
  """
37
  Generates an answer to a user's question based on the image's caption and detected objects.
@@ -42,8 +43,9 @@ class InferenceRunner(StateManager):
42
  question (str): User's question about the image.
43
 
44
  Returns:
45
- tuple: A tuple containing the answer to the question and the prompt length.
46
  """
 
47
  free_gpu_resources()
48
  answer = st.session_state.kbvqa.generate_answer(question, caption, detected_objects_str)
49
  prompt_length = st.session_state.kbvqa.current_prompt_length
@@ -54,6 +56,9 @@ class InferenceRunner(StateManager):
54
  def display_sample_images(self) -> None:
55
  """
56
  Displays sample images as clickable thumbnails for the user to select.
 
 
 
57
  """
58
 
59
  self.col1.write("Choose from sample images:")
@@ -66,37 +71,50 @@ class InferenceRunner(StateManager):
66
  if st.button(f'Select Sample Image {idx + 1}', key=f'sample_{idx+1}'):
67
  self.process_new_image(sample_image_path, image)
68
 
 
69
  def handle_image_upload(self) -> None:
70
  """
71
  Provides an image uploader widget for the user to upload their own images.
 
 
 
72
  """
 
73
  uploaded_image = self.col1.file_uploader("Or upload an Image", type=["png", "jpg", "jpeg"])
74
  if uploaded_image is not None:
75
  self.process_new_image(uploaded_image.name, Image.open(uploaded_image))
 
76
 
77
- def display_image_and_analysis(self, image_key: str, image_data: dict, nested_col21, nested_col22) -> None:
78
  """
79
  Displays the uploaded or selected image and provides an option to analyze the image.
80
 
81
  Args:
82
  image_key (str): Unique key identifying the image.
83
- image_data (dict): Data associated with the image.
84
- nested_col21 (streamlit column): Column for displaying the image.
85
- nested_col22 (streamlit column): Column for displaying the analysis button.
 
 
 
86
  """
87
 
88
  image_for_display = self.resize_image(image_data['image'], 600)
89
  nested_col21.image(image_for_display, caption=f'Uploaded Image: {image_key[-11:]}')
90
  self.handle_analysis_button(image_key, image_data, nested_col22)
 
91
 
92
- def handle_analysis_button(self, image_key: str, image_data: dict, nested_col22) -> None:
93
  """
94
  Provides an 'Analyze Image' button and processes the image analysis upon click.
95
 
96
  Args:
97
  image_key (str): Unique key identifying the image.
98
- image_data (dict): Data associated with the image.
99
- nested_col22 (streamlit column): Column for displaying the analysis button.
 
 
 
100
  """
101
 
102
  if not image_data['analysis_done'] or self.settings_changed or self.confidance_change:
@@ -109,14 +127,18 @@ class InferenceRunner(StateManager):
109
  self.update_image_data(image_key, caption, detected_objects_str, True)
110
  st.session_state['loading_in_progress'] = False
111
 
112
- def handle_question_answering(self, image_key: str, image_data: dict, nested_col22) -> None:
 
113
  """
114
  Manages the question-answering interface for each image.
115
 
116
  Args:
117
  image_key (str): Unique key identifying the image.
118
- image_data (dict): Data associated with the image.
119
- nested_col22 (streamlit column): Column for displaying the question-answering interface.
 
 
 
120
  """
121
 
122
  if image_data['analysis_done']:
@@ -126,14 +148,17 @@ class InferenceRunner(StateManager):
126
  nested_col22.warning("Confidence level changed, please click 'Analyze Image' each time you change it.")
127
 
128
 
129
- def display_question_answering_interface(self, image_key: str, image_data: Dict, nested_col22: st.columns) -> None:
130
  """
131
  Displays the interface for question answering, including sample questions and a custom question input.
132
-
133
  Args:
134
  image_key (str): Unique key identifying the image.
135
- image_data (dict): Data associated with the image.
136
- nested_col22 (streamlit column): The column where the interface will be displayed.
 
 
 
137
  """
138
 
139
  sample_questions = config.SAMPLE_QUESTIONS.get(image_key, [])
@@ -152,19 +177,20 @@ class InferenceRunner(StateManager):
152
  nested_col22.text(f"Q{num+1}: {q}\nA{num+1}: {a}\nPrompt Length: {p}\n")
153
 
154
 
155
-
156
- def process_question(self, image_key: str, question: str, image_data: Dict, nested_col22: st.columns) -> None:
157
  """
158
  Processes the user's question, generates an answer, and updates the question-answer history.
 
 
159
 
160
  Args:
161
  image_key (str): Unique key identifying the image.
162
  question (str): The question asked by the user.
163
  image_data (Dict): Data associated with the image.
164
- nested_col22 (streamlit column): The column where the answer will be displayed.
165
-
166
- This method checks if the question is new or if settings have changed, and if so, generates an answer using the KBVQA model.
167
- It then updates the question-answer history for the image.
168
  """
169
 
170
  qa_history = image_data.get('qa_history', [])
@@ -172,7 +198,7 @@ class InferenceRunner(StateManager):
172
  if nested_col22.button('Get Answer', key=f'answer_{image_key}', disabled=self.is_widget_disabled):
173
  answer, prompt_length = self.answer_question(image_data['caption'], image_data['detected_objects_str'], question)
174
  self.add_to_qa_history(image_key, question, answer, prompt_length)
175
- # nested_col22.text(f"Q: {question}\nA: {answer}\nPrompt Length: {prompt_length}")
176
 
177
  def image_qa_app(self) -> None:
178
  """
@@ -180,6 +206,9 @@ class InferenceRunner(StateManager):
180
 
181
  This method orchestrates the display of sample images, handles image uploads, and facilitates the question-answering process.
182
  It iterates through each image in the session state, displaying the image and providing interfaces for image analysis and question answering.
 
 
 
183
  """
184
 
185
  self.display_sample_images()
@@ -192,29 +221,27 @@ class InferenceRunner(StateManager):
192
  self.display_image_and_analysis(image_key, image_data, nested_col21, nested_col22)
193
  self.handle_question_answering(image_key, image_data, nested_col22)
194
 
195
-
196
 
197
- def run_inference(self):
198
  """
199
- Sets up widgets and manages the inference process, including model loading and reloading,
200
- based on user interactions.
201
 
202
  This method orchestrates the overall flow of the inference process.
203
- """
204
 
205
- self.set_up_widgets()
 
 
 
 
206
 
207
  load_fine_tuned_model = False
208
  fine_tuned_model_already_loaded = False
209
  reload_detection_model = False
210
  force_reload_full_model = False
211
 
212
-
213
  if self.is_model_loaded and self.settings_changed:
214
  self.col1.warning("Model settings have changed, please reload the model, this will take a second .. ")
215
  self.update_prev_state()
216
-
217
-
218
  st.session_state.button_label = "Reload Model" if self.is_model_loaded and st.session_state.kbvqa.detection_model != st.session_state['detection_model'] else "Load Model"
219
 
220
  with self.col1:
@@ -232,25 +259,20 @@ class InferenceRunner(StateManager):
232
  reload_detection_model = True
233
  if nested_col12.button("Force Reload", on_click=self.disable_widgets, disabled=self.is_widget_disabled):
234
  force_reload_full_model = True
235
-
236
-
237
  if load_fine_tuned_model:
238
  t1=time.time()
239
  free_gpu_resources()
240
  self.load_model()
241
  st.session_state['time_taken_to_load_model'] = int(time.time()-t1)
242
  st.session_state['loading_in_progress'] = False
243
-
244
  elif fine_tuned_model_already_loaded:
245
  free_gpu_resources()
246
  self.col1.text("Model already loaded and no settings were changed:)")
247
  st.session_state['loading_in_progress'] = False
248
-
249
  elif reload_detection_model:
250
  free_gpu_resources()
251
  self.reload_detection_model()
252
  st.session_state['loading_in_progress'] = False
253
-
254
  elif force_reload_full_model:
255
  free_gpu_resources()
256
  t1=time.time()
@@ -258,19 +280,11 @@ class InferenceRunner(StateManager):
258
  st.session_state['time_taken_to_load_model'] = int(time.time()-t1)
259
  st.session_state['loading_in_progress'] = False
260
  st.session_state['model_loaded'] = True
261
-
262
- # elif st.session_state.method == "13b-Fine-Tuned Model":
263
- # self.col1.warning(f'Model using {st.session_state.method} is not deployed yet, will be ready later.')
264
-
265
-
266
  elif st.session_state.method == "Vision-Language Embeddings Alignment":
267
  self.col1.warning(f'Model using {st.session_state.method} is desgined but requires large scale data and multiple high-end GPUs, implementation will be explored in the future.')
268
-
269
-
270
  if self.is_model_loaded:
271
  free_gpu_resources()
272
  st.session_state['loading_in_progress'] = False
273
-
274
  self.image_qa_app() # this is the main Q/A Application
275
 
276
 
 
5
  import scipy
6
  import copy
7
  import time
8
+ from typing import Tuple, Dict, List
9
+ from streamlit.delta_generator import DeltaGenerator
10
  from PIL import Image
11
  import torch.nn as nn
12
  import pandas as pd
 
18
 
19
 
20
  class InferenceRunner(StateManager):
 
21
  """
22
+ Manages the user interface and interactions for running inference using the Streamlit-based Knowledge-Based Visual Question Answering (KBVQA) application.
23
+
24
  This class handles image uploads, displays sample images, and facilitates the question-answering process using the KBVQA model.
25
  Inherits from the StateManager class.
26
  """
 
29
  """
30
  Initializes the InferenceRunner instance, setting up the necessary state.
31
  """
 
 
32
 
33
+ super().__init__()
34
 
35
+
36
  def answer_question(self, caption: str, detected_objects_str: str, question: str) -> Tuple[str, int]:
37
  """
38
  Generates an answer to a user's question based on the image's caption and detected objects.
 
43
  question (str): User's question about the image.
44
 
45
  Returns:
46
+ Tuple[str, int]: A tuple containing the answer to the question and the prompt length.
47
  """
48
+
49
  free_gpu_resources()
50
  answer = st.session_state.kbvqa.generate_answer(question, caption, detected_objects_str)
51
  prompt_length = st.session_state.kbvqa.current_prompt_length
 
56
  def display_sample_images(self) -> None:
57
  """
58
  Displays sample images as clickable thumbnails for the user to select.
59
+
60
+ Returns:
61
+ None
62
  """
63
 
64
  self.col1.write("Choose from sample images:")
 
71
  if st.button(f'Select Sample Image {idx + 1}', key=f'sample_{idx+1}'):
72
  self.process_new_image(sample_image_path, image)
73
 
74
+
75
  def handle_image_upload(self) -> None:
76
  """
77
  Provides an image uploader widget for the user to upload their own images.
78
+
79
+ Returns:
80
+ None
81
  """
82
+
83
  uploaded_image = self.col1.file_uploader("Or upload an Image", type=["png", "jpg", "jpeg"])
84
  if uploaded_image is not None:
85
  self.process_new_image(uploaded_image.name, Image.open(uploaded_image))
86
+
87
 
88
+ def display_image_and_analysis(self, image_key: str, image_data: Dict, nested_col21: DeltaGenerator, nested_col22: DeltaGenerator) -> None:
89
  """
90
  Displays the uploaded or selected image and provides an option to analyze the image.
91
 
92
  Args:
93
  image_key (str): Unique key identifying the image.
94
+ image_data (Dict): Data associated with the image.
95
+ nested_col21 (DeltaGenerator): Column for displaying the image.
96
+ nested_col22 (DeltaGenerator): Column for displaying the analysis button.
97
+
98
+ Returns:
99
+ None
100
  """
101
 
102
  image_for_display = self.resize_image(image_data['image'], 600)
103
  nested_col21.image(image_for_display, caption=f'Uploaded Image: {image_key[-11:]}')
104
  self.handle_analysis_button(image_key, image_data, nested_col22)
105
+
106
 
107
+ def handle_analysis_button(self, image_key: str, image_data: Dict, nested_col22: DeltaGenerator) -> None:
108
  """
109
  Provides an 'Analyze Image' button and processes the image analysis upon click.
110
 
111
  Args:
112
  image_key (str): Unique key identifying the image.
113
+ image_data (Dict): Data associated with the image.
114
+ nested_col22 (DeltaGenerator): Column for displaying the analysis button.
115
+
116
+ Returns:
117
+ None
118
  """
119
 
120
  if not image_data['analysis_done'] or self.settings_changed or self.confidance_change:
 
127
  self.update_image_data(image_key, caption, detected_objects_str, True)
128
  st.session_state['loading_in_progress'] = False
129
 
130
+
131
+ def handle_question_answering(self, image_key: str, image_data: Dict, nested_col22: DeltaGenerator) -> None:
132
  """
133
  Manages the question-answering interface for each image.
134
 
135
  Args:
136
  image_key (str): Unique key identifying the image.
137
+ image_data (Dict): Data associated with the image.
138
+ nested_col22 (DeltaGenerator): Column for displaying the question-answering interface.
139
+
140
+ Returns:
141
+ None
142
  """
143
 
144
  if image_data['analysis_done']:
 
148
  nested_col22.warning("Confidence level changed, please click 'Analyze Image' each time you change it.")
149
 
150
 
151
+ def display_question_answering_interface(self, image_key: str, image_data: Dict, nested_col22: DeltaGenerator) -> None:
152
  """
153
  Displays the interface for question answering, including sample questions and a custom question input.
154
+
155
  Args:
156
  image_key (str): Unique key identifying the image.
157
+ image_data (Dict): Data associated with the image.
158
+ nested_col22 (DeltaGenerator): The column where the interface will be displayed.
159
+
160
+ Returns:
161
+ None
162
  """
163
 
164
  sample_questions = config.SAMPLE_QUESTIONS.get(image_key, [])
 
177
  nested_col22.text(f"Q{num+1}: {q}\nA{num+1}: {a}\nPrompt Length: {p}\n")
178
 
179
 
180
+ def process_question(self, image_key: str, question: str, image_data: Dict, nested_col22: DeltaGenerator) -> None:
 
181
  """
182
  Processes the user's question, generates an answer, and updates the question-answer history.
183
+ This method checks if the question is new or if settings have changed, and if so, generates an answer using the KBVQA model.
184
+ It then updates the question-answer history for the image.
185
 
186
  Args:
187
  image_key (str): Unique key identifying the image.
188
  question (str): The question asked by the user.
189
  image_data (Dict): Data associated with the image.
190
+ nested_col22 (DeltaGenerator): The column where the answer will be displayed.
191
+
192
+ Returns:
193
+ None
194
  """
195
 
196
  qa_history = image_data.get('qa_history', [])
 
198
  if nested_col22.button('Get Answer', key=f'answer_{image_key}', disabled=self.is_widget_disabled):
199
  answer, prompt_length = self.answer_question(image_data['caption'], image_data['detected_objects_str'], question)
200
  self.add_to_qa_history(image_key, question, answer, prompt_length)
201
+
202
 
203
  def image_qa_app(self) -> None:
204
  """
 
206
 
207
  This method orchestrates the display of sample images, handles image uploads, and facilitates the question-answering process.
208
  It iterates through each image in the session state, displaying the image and providing interfaces for image analysis and question answering.
209
+
210
+ Returns:
211
+ None
212
  """
213
 
214
  self.display_sample_images()
 
221
  self.display_image_and_analysis(image_key, image_data, nested_col21, nested_col22)
222
  self.handle_question_answering(image_key, image_data, nested_col22)
223
 
 
224
 
225
+ def run_inference(self) -> None:
226
  """
227
+ Sets up widgets and manages the inference process, including model loading and reloading, based on user interactions.
 
228
 
229
  This method orchestrates the overall flow of the inference process.
 
230
 
231
+ Returns:
232
+ None
233
+ """
234
+
235
+ self.set_up_widgets() # Inherited from the StateManager class
236
 
237
  load_fine_tuned_model = False
238
  fine_tuned_model_already_loaded = False
239
  reload_detection_model = False
240
  force_reload_full_model = False
241
 
 
242
  if self.is_model_loaded and self.settings_changed:
243
  self.col1.warning("Model settings have changed, please reload the model, this will take a second .. ")
244
  self.update_prev_state()
 
 
245
  st.session_state.button_label = "Reload Model" if self.is_model_loaded and st.session_state.kbvqa.detection_model != st.session_state['detection_model'] else "Load Model"
246
 
247
  with self.col1:
 
259
  reload_detection_model = True
260
  if nested_col12.button("Force Reload", on_click=self.disable_widgets, disabled=self.is_widget_disabled):
261
  force_reload_full_model = True
 
 
262
  if load_fine_tuned_model:
263
  t1=time.time()
264
  free_gpu_resources()
265
  self.load_model()
266
  st.session_state['time_taken_to_load_model'] = int(time.time()-t1)
267
  st.session_state['loading_in_progress'] = False
 
268
  elif fine_tuned_model_already_loaded:
269
  free_gpu_resources()
270
  self.col1.text("Model already loaded and no settings were changed:)")
271
  st.session_state['loading_in_progress'] = False
 
272
  elif reload_detection_model:
273
  free_gpu_resources()
274
  self.reload_detection_model()
275
  st.session_state['loading_in_progress'] = False
 
276
  elif force_reload_full_model:
277
  free_gpu_resources()
278
  t1=time.time()
 
280
  st.session_state['time_taken_to_load_model'] = int(time.time()-t1)
281
  st.session_state['loading_in_progress'] = False
282
  st.session_state['model_loaded'] = True
 
 
 
 
 
283
  elif st.session_state.method == "Vision-Language Embeddings Alignment":
284
  self.col1.warning(f'Model using {st.session_state.method} is desgined but requires large scale data and multiple high-end GPUs, implementation will be explored in the future.')
 
 
285
  if self.is_model_loaded:
286
  free_gpu_resources()
287
  st.session_state['loading_in_progress'] = False
 
288
  self.image_qa_app() # this is the main Q/A Application
289
 
290