Update my_model/KBVQA.py
Browse files- my_model/KBVQA.py +17 -11
my_model/KBVQA.py
CHANGED
|
@@ -159,23 +159,29 @@ class KBVQA():
|
|
| 159 |
|
| 160 |
return output_text.capitalize()
|
| 161 |
|
| 162 |
-
def prepare_kbvqa_model(detection_model):
|
| 163 |
free_gpu_resources()
|
| 164 |
kbvqa = KBVQA()
|
| 165 |
kbvqa.detection_model = detection_model
|
| 166 |
# Progress bar for model loading
|
| 167 |
with st.spinner('Loading model...'):
|
| 168 |
-
|
| 169 |
-
progress_bar = st.progress(0)
|
| 170 |
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
|
| 180 |
if kbvqa.all_models_loaded:
|
| 181 |
st.success('Model loaded successfully!')
|
|
|
|
| 159 |
|
| 160 |
return output_text.capitalize()
|
| 161 |
|
| 162 |
+
def prepare_kbvqa_model(detection_model, only_reload_detection_model=False):
|
| 163 |
free_gpu_resources()
|
| 164 |
kbvqa = KBVQA()
|
| 165 |
kbvqa.detection_model = detection_model
|
| 166 |
# Progress bar for model loading
|
| 167 |
with st.spinner('Loading model...'):
|
|
|
|
|
|
|
| 168 |
|
| 169 |
+
if not only_reload_detection_model:
|
| 170 |
+
progress_bar = st.progress(0)
|
| 171 |
+
|
| 172 |
+
kbvqa.load_detector(kbvqa.detection_model)
|
| 173 |
+
progress_bar.progress(33)
|
| 174 |
+
kbvqa.load_caption_model()
|
| 175 |
+
free_gpu_resources()
|
| 176 |
+
progress_bar.progress(66)
|
| 177 |
+
kbvqa.load_fine_tuned_model()
|
| 178 |
+
free_gpu_resources()
|
| 179 |
+
progress_bar.progress(100)
|
| 180 |
+
|
| 181 |
+
else:
|
| 182 |
+
progress_bar = st.progress(0)
|
| 183 |
+
kbvqa.load_detector(kbvqa.detection_model)
|
| 184 |
+
progress_bar.progress(100)
|
| 185 |
|
| 186 |
if kbvqa.all_models_loaded:
|
| 187 |
st.success('Model loaded successfully!')
|