Update my_model/tabs/run_inference.py
Browse files- my_model/tabs/run_inference.py +39 -25
my_model/tabs/run_inference.py
CHANGED
|
@@ -16,14 +16,36 @@ from my_model.config import inference_config as config
|
|
| 16 |
|
| 17 |
|
| 18 |
class InferenceRunner(StateManager):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
def __init__(self):
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
super().__init__()
|
| 22 |
self.initialize_state()
|
| 23 |
-
self.sample_images = config.SAMPLE_IMAGES
|
| 24 |
|
| 25 |
|
| 26 |
def answer_question(self, caption, detected_objects_str, question, model):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
free_gpu_resources()
|
| 28 |
answer = model.generate_answer(question, caption, detected_objects_str)
|
| 29 |
free_gpu_resources()
|
|
@@ -31,10 +53,18 @@ class InferenceRunner(StateManager):
|
|
| 31 |
|
| 32 |
|
| 33 |
def image_qa_app(self, kbvqa):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
# Display sample images as clickable thumbnails
|
| 35 |
self.col1.write("Choose from sample images:")
|
| 36 |
-
cols = self.col1.columns(len(
|
| 37 |
-
for idx, sample_image_path in enumerate(
|
| 38 |
with cols[idx]:
|
| 39 |
image = Image.open(sample_image_path)
|
| 40 |
image_for_display = self.resize_image(sample_image_path, 80, 80)
|
|
@@ -42,13 +72,8 @@ class InferenceRunner(StateManager):
|
|
| 42 |
if st.button(f'Select Sample Image {idx + 1}', key=f'sample_{idx}'):
|
| 43 |
self.process_new_image(sample_image_path, image, kbvqa)
|
| 44 |
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
# Image uploader
|
| 49 |
uploaded_image = self.col1.file_uploader("Or upload an Image", type=["png", "jpg", "jpeg"])
|
| 50 |
-
|
| 51 |
-
|
| 52 |
if uploaded_image is not None:
|
| 53 |
self.process_new_image(uploaded_image.name, Image.open(uploaded_image), kbvqa)
|
| 54 |
|
|
@@ -67,7 +92,6 @@ class InferenceRunner(StateManager):
|
|
| 67 |
caption, detected_objects_str, image_with_boxes = self.analyze_image(image_data['image'], kbvqa)
|
| 68 |
self.update_image_data(image_key, caption, detected_objects_str, True)
|
| 69 |
st.session_state['loading_in_progress'] = False
|
| 70 |
-
|
| 71 |
|
| 72 |
# Initialize qa_history for each image
|
| 73 |
qa_history = image_data.get('qa_history', [])
|
|
@@ -87,7 +111,6 @@ class InferenceRunner(StateManager):
|
|
| 87 |
# Use the selected sample question or the custom question
|
| 88 |
question = custom_question if selected_question == "Custom question..." else selected_question
|
| 89 |
|
| 90 |
-
|
| 91 |
if not question:
|
| 92 |
nested_col22.warning("Please select or enter a question.")
|
| 93 |
else:
|
|
@@ -100,20 +123,19 @@ class InferenceRunner(StateManager):
|
|
| 100 |
st.session_state['loading_in_progress'] = False
|
| 101 |
self.add_to_qa_history(image_key, question, answer)
|
| 102 |
|
| 103 |
-
|
| 104 |
# Display Q&A history for each image
|
| 105 |
for num, (q, a) in enumerate(qa_history):
|
| 106 |
nested_col22.text(f"Q{num+1}: {q}\nA{num+1}: {a}\n")
|
| 107 |
|
| 108 |
-
def display_message(self, message, warning=False, write=False, text=False):
|
| 109 |
-
pass
|
| 110 |
-
|
| 111 |
-
|
| 112 |
|
| 113 |
def run_inference(self):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
|
| 115 |
self.set_up_widgets()
|
| 116 |
-
|
| 117 |
load_fine_tuned_model = False
|
| 118 |
fine_tuned_model_already_loaded = False
|
| 119 |
reload_detection_model = False
|
|
@@ -123,27 +145,21 @@ class InferenceRunner(StateManager):
|
|
| 123 |
if st.session_state['settings_changed']:
|
| 124 |
self.col1.warning("Model settings have changed, please reload the model, this will take a second .. ")
|
| 125 |
|
| 126 |
-
|
| 127 |
st.session_state.button_label = "Reload Model" if self.is_model_loaded() and self.settings_changed else "Load Model"
|
| 128 |
|
| 129 |
with self.col1:
|
| 130 |
-
|
| 131 |
if st.session_state.method == "Fine-Tuned Model":
|
| 132 |
-
|
| 133 |
with st.container():
|
| 134 |
nested_col11, nested_col12 = st.columns([0.5, 0.5])
|
| 135 |
if nested_col11.button(st.session_state.button_label, on_click=self.disable_widgets, disabled=self.is_widget_disabled):
|
| 136 |
-
|
| 137 |
if st.session_state.button_label == "Load Model":
|
| 138 |
if self.is_model_loaded():
|
| 139 |
free_gpu_resources()
|
| 140 |
fine_tuned_model_already_loaded = True
|
| 141 |
-
|
| 142 |
else:
|
| 143 |
load_fine_tuned_model = True
|
| 144 |
else:
|
| 145 |
reload_detection_model = True
|
| 146 |
-
|
| 147 |
if nested_col12.button("Force Reload", on_click=self.disable_widgets, disabled=self.is_widget_disabled):
|
| 148 |
force_reload_full_model = True
|
| 149 |
|
|
@@ -172,14 +188,12 @@ class InferenceRunner(StateManager):
|
|
| 172 |
st.session_state['time_taken_to_load_model'] = int(time.time()-t1)
|
| 173 |
st.session_state['loading_in_progress'] = False
|
| 174 |
st.session_state['model_loaded'] = True
|
| 175 |
-
|
| 176 |
elif st.session_state.method == "In-Context Learning (n-shots)":
|
| 177 |
self.col1.warning(f'Model using {st.session_state.method} is not deployed yet, will be ready later.')
|
| 178 |
st.session_state['loading_in_progress'] = False
|
| 179 |
-
|
| 180 |
|
| 181 |
if self.is_model_loaded():
|
| 182 |
-
|
| 183 |
free_gpu_resources()
|
| 184 |
st.session_state['loading_in_progress'] = False
|
| 185 |
self.image_qa_app(self.get_model())
|
|
|
|
| 16 |
|
| 17 |
|
| 18 |
class InferenceRunner(StateManager):
|
| 19 |
+
|
| 20 |
+
"""
|
| 21 |
+
InferenceRunner manages the user interface and interactions for a Streamlit-based
|
| 22 |
+
Knowledge-Based Visual Question Answering (KBVQA) application. It handles image uploads,
|
| 23 |
+
displays sample images, and facilitates the question-answering process using the KBVQA model.
|
| 24 |
+
It inherits from the StateManager class.
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
def __init__(self):
|
| 28 |
+
"""
|
| 29 |
+
Initializes the InferenceRunner instance, setting up the necessary state.
|
| 30 |
+
"""
|
| 31 |
|
| 32 |
super().__init__()
|
| 33 |
self.initialize_state()
|
|
|
|
| 34 |
|
| 35 |
|
| 36 |
def answer_question(self, caption, detected_objects_str, question, model):
|
| 37 |
+
"""
|
| 38 |
+
Generates an answer to a given question based on the image's caption and detected objects.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
caption (str): The caption generated for the image.
|
| 42 |
+
detected_objects_str (str): String representation of objects detected in the image.
|
| 43 |
+
question (str): The user's question about the image.
|
| 44 |
+
model (KBVQA): The loaded KBVQA model used for generating the answer.
|
| 45 |
+
|
| 46 |
+
Returns:
|
| 47 |
+
str: The generated answer to the question.
|
| 48 |
+
"""
|
| 49 |
free_gpu_resources()
|
| 50 |
answer = model.generate_answer(question, caption, detected_objects_str)
|
| 51 |
free_gpu_resources()
|
|
|
|
| 53 |
|
| 54 |
|
| 55 |
def image_qa_app(self, kbvqa):
|
| 56 |
+
"""
|
| 57 |
+
Main application interface for image-based question answering. It handles the display
|
| 58 |
+
of sample images and the upload of new images, and facilitates the QA process.
|
| 59 |
+
|
| 60 |
+
Args:
|
| 61 |
+
kbvqa (KBVQA): The loaded KBVQA model used for image analysis and question answering.
|
| 62 |
+
"""
|
| 63 |
+
|
| 64 |
# Display sample images as clickable thumbnails
|
| 65 |
self.col1.write("Choose from sample images:")
|
| 66 |
+
cols = self.col1.columns(len(config.SAMPLE_IMAGES))
|
| 67 |
+
for idx, sample_image_path in enumerate(config.SAMPLE_IMAGES):
|
| 68 |
with cols[idx]:
|
| 69 |
image = Image.open(sample_image_path)
|
| 70 |
image_for_display = self.resize_image(sample_image_path, 80, 80)
|
|
|
|
| 72 |
if st.button(f'Select Sample Image {idx + 1}', key=f'sample_{idx}'):
|
| 73 |
self.process_new_image(sample_image_path, image, kbvqa)
|
| 74 |
|
|
|
|
|
|
|
|
|
|
| 75 |
# Image uploader
|
| 76 |
uploaded_image = self.col1.file_uploader("Or upload an Image", type=["png", "jpg", "jpeg"])
|
|
|
|
|
|
|
| 77 |
if uploaded_image is not None:
|
| 78 |
self.process_new_image(uploaded_image.name, Image.open(uploaded_image), kbvqa)
|
| 79 |
|
|
|
|
| 92 |
caption, detected_objects_str, image_with_boxes = self.analyze_image(image_data['image'], kbvqa)
|
| 93 |
self.update_image_data(image_key, caption, detected_objects_str, True)
|
| 94 |
st.session_state['loading_in_progress'] = False
|
|
|
|
| 95 |
|
| 96 |
# Initialize qa_history for each image
|
| 97 |
qa_history = image_data.get('qa_history', [])
|
|
|
|
| 111 |
# Use the selected sample question or the custom question
|
| 112 |
question = custom_question if selected_question == "Custom question..." else selected_question
|
| 113 |
|
|
|
|
| 114 |
if not question:
|
| 115 |
nested_col22.warning("Please select or enter a question.")
|
| 116 |
else:
|
|
|
|
| 123 |
st.session_state['loading_in_progress'] = False
|
| 124 |
self.add_to_qa_history(image_key, question, answer)
|
| 125 |
|
|
|
|
| 126 |
# Display Q&A history for each image
|
| 127 |
for num, (q, a) in enumerate(qa_history):
|
| 128 |
nested_col22.text(f"Q{num+1}: {q}\nA{num+1}: {a}\n")
|
| 129 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
|
| 131 |
def run_inference(self):
|
| 132 |
+
"""
|
| 133 |
+
Sets up the widgets and manages the inference process. This method handles model loading,
|
| 134 |
+
reloading, and the overall flow of the inference process based on user interactions.
|
| 135 |
+
|
| 136 |
+
"""
|
| 137 |
|
| 138 |
self.set_up_widgets()
|
|
|
|
| 139 |
load_fine_tuned_model = False
|
| 140 |
fine_tuned_model_already_loaded = False
|
| 141 |
reload_detection_model = False
|
|
|
|
| 145 |
if st.session_state['settings_changed']:
|
| 146 |
self.col1.warning("Model settings have changed, please reload the model, this will take a second .. ")
|
| 147 |
|
|
|
|
| 148 |
st.session_state.button_label = "Reload Model" if self.is_model_loaded() and self.settings_changed else "Load Model"
|
| 149 |
|
| 150 |
with self.col1:
|
|
|
|
| 151 |
if st.session_state.method == "Fine-Tuned Model":
|
|
|
|
| 152 |
with st.container():
|
| 153 |
nested_col11, nested_col12 = st.columns([0.5, 0.5])
|
| 154 |
if nested_col11.button(st.session_state.button_label, on_click=self.disable_widgets, disabled=self.is_widget_disabled):
|
|
|
|
| 155 |
if st.session_state.button_label == "Load Model":
|
| 156 |
if self.is_model_loaded():
|
| 157 |
free_gpu_resources()
|
| 158 |
fine_tuned_model_already_loaded = True
|
|
|
|
| 159 |
else:
|
| 160 |
load_fine_tuned_model = True
|
| 161 |
else:
|
| 162 |
reload_detection_model = True
|
|
|
|
| 163 |
if nested_col12.button("Force Reload", on_click=self.disable_widgets, disabled=self.is_widget_disabled):
|
| 164 |
force_reload_full_model = True
|
| 165 |
|
|
|
|
| 188 |
st.session_state['time_taken_to_load_model'] = int(time.time()-t1)
|
| 189 |
st.session_state['loading_in_progress'] = False
|
| 190 |
st.session_state['model_loaded'] = True
|
| 191 |
+
|
| 192 |
elif st.session_state.method == "In-Context Learning (n-shots)":
|
| 193 |
self.col1.warning(f'Model using {st.session_state.method} is not deployed yet, will be ready later.')
|
| 194 |
st.session_state['loading_in_progress'] = False
|
|
|
|
| 195 |
|
| 196 |
if self.is_model_loaded():
|
|
|
|
| 197 |
free_gpu_resources()
|
| 198 |
st.session_state['loading_in_progress'] = False
|
| 199 |
self.image_qa_app(self.get_model())
|