Update my_model/KBVQA.py
Browse files
- my_model/KBVQA.py +14 -4
my_model/KBVQA.py
CHANGED
|
@@ -99,6 +99,7 @@ class KBVQA:
|
|
| 99 |
|
| 100 |
self.captioner = ImageCaptioningModel()
|
| 101 |
self.captioner.load_model()
|
|
|
|
| 102 |
|
| 103 |
def get_caption(self, img: Image.Image) -> str:
|
| 104 |
"""
|
|
@@ -110,8 +111,9 @@ class KBVQA:
|
|
| 110 |
Returns:
|
| 111 |
str: The generated caption for the image.
|
| 112 |
"""
|
| 113 |
-
|
| 114 |
-
|
|
|
|
| 115 |
|
| 116 |
def load_detector(self, model: str) -> None:
|
| 117 |
"""
|
|
@@ -123,6 +125,7 @@ class KBVQA:
|
|
| 123 |
|
| 124 |
self.detector = ObjectDetector()
|
| 125 |
self.detector.load_model(model)
|
|
|
|
| 126 |
|
| 127 |
def detect_objects(self, img: Image.Image) -> Tuple[Image.Image, str]:
|
| 128 |
"""
|
|
@@ -136,8 +139,11 @@ class KBVQA:
|
|
| 136 |
"""
|
| 137 |
|
| 138 |
image = self.detector.process_image(img)
|
|
|
|
| 139 |
detected_objects_string, detected_objects_list = self.detector.detect_objects(image, threshold=st.session_state['confidence_level'])
|
|
|
|
| 140 |
image_with_boxes = self.detector.draw_boxes(img, detected_objects_list)
|
|
|
|
| 141 |
return image_with_boxes, detected_objects_string
|
| 142 |
|
| 143 |
def load_fine_tuned_model(self) -> None:
|
|
@@ -150,6 +156,8 @@ class KBVQA:
|
|
| 150 |
low_cpu_mem_usage=True,
|
| 151 |
quantization_config=self.bnb_config,
|
| 152 |
token=self.access_token)
|
|
|
|
|
|
|
| 153 |
|
| 154 |
self.kbvqa_tokenizer = AutoTokenizer.from_pretrained(self.kbvqa_model_name,
|
| 155 |
use_fast=self.use_fast,
|
|
@@ -157,7 +165,7 @@ class KBVQA:
|
|
| 157 |
trust_remote_code=self.trust_remote,
|
| 158 |
add_eos_token=self.add_eos_token,
|
| 159 |
token=self.access_token)
|
| 160 |
-
|
| 161 |
|
| 162 |
@property
|
| 163 |
def all_models_loaded(self):
|
|
@@ -225,7 +233,7 @@ class KBVQA:
|
|
| 225 |
Returns:
|
| 226 |
str: The generated answer to the question.
|
| 227 |
"""
|
| 228 |
-
|
| 229 |
prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)
|
| 230 |
num_tokens = len(self.kbvqa_tokenizer.tokenize(prompt))
|
| 231 |
self.current_prompt_length = num_tokens
|
|
@@ -234,8 +242,10 @@ class KBVQA:
|
|
| 234 |
return
|
| 235 |
|
| 236 |
model_inputs = self.kbvqa_tokenizer(prompt, add_special_tokens=False, return_tensors="pt").to('cuda')
|
|
|
|
| 237 |
input_ids = model_inputs["input_ids"]
|
| 238 |
output_ids = self.kbvqa_model.generate(input_ids)
|
|
|
|
| 239 |
index = input_ids.shape[1] # needed to avoid printing the input prompt
|
| 240 |
history = self.kbvqa_tokenizer.decode(output_ids[0], skip_special_tokens=False)
|
| 241 |
output_text = self.kbvqa_tokenizer.decode(output_ids[0][index:], skip_special_tokens=True)
|
|
|
|
| 99 |
|
| 100 |
self.captioner = ImageCaptioningModel()
|
| 101 |
self.captioner.load_model()
|
| 102 |
+
free_gpu_resources()
|
| 103 |
|
| 104 |
def get_caption(self, img: Image.Image) -> str:
|
| 105 |
"""
|
|
|
|
| 111 |
Returns:
|
| 112 |
str: The generated caption for the image.
|
| 113 |
"""
|
| 114 |
+
caption = self.captioner.generate_caption(img)
|
| 115 |
+
free_gpu_resources()
|
| 116 |
+
return caption
|
| 117 |
|
| 118 |
def load_detector(self, model: str) -> None:
|
| 119 |
"""
|
|
|
|
| 125 |
|
| 126 |
self.detector = ObjectDetector()
|
| 127 |
self.detector.load_model(model)
|
| 128 |
+
free_gpu_resources()
|
| 129 |
|
| 130 |
def detect_objects(self, img: Image.Image) -> Tuple[Image.Image, str]:
|
| 131 |
"""
|
|
|
|
| 139 |
"""
|
| 140 |
|
| 141 |
image = self.detector.process_image(img)
|
| 142 |
+
free_gpu_resources()
|
| 143 |
detected_objects_string, detected_objects_list = self.detector.detect_objects(image, threshold=st.session_state['confidence_level'])
|
| 144 |
+
free_gpu_resources()
|
| 145 |
image_with_boxes = self.detector.draw_boxes(img, detected_objects_list)
|
| 146 |
+
free_gpu_resources()
|
| 147 |
return image_with_boxes, detected_objects_string
|
| 148 |
|
| 149 |
def load_fine_tuned_model(self) -> None:
|
|
|
|
| 156 |
low_cpu_mem_usage=True,
|
| 157 |
quantization_config=self.bnb_config,
|
| 158 |
token=self.access_token)
|
| 159 |
+
|
| 160 |
+
free_gpu_resources()
|
| 161 |
|
| 162 |
self.kbvqa_tokenizer = AutoTokenizer.from_pretrained(self.kbvqa_model_name,
|
| 163 |
use_fast=self.use_fast,
|
|
|
|
| 165 |
trust_remote_code=self.trust_remote,
|
| 166 |
add_eos_token=self.add_eos_token,
|
| 167 |
token=self.access_token)
|
| 168 |
+
free_gpu_resources()
|
| 169 |
|
| 170 |
@property
|
| 171 |
def all_models_loaded(self):
|
|
|
|
| 233 |
Returns:
|
| 234 |
str: The generated answer to the question.
|
| 235 |
"""
|
| 236 |
+
free_gpu_resources()
|
| 237 |
prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)
|
| 238 |
num_tokens = len(self.kbvqa_tokenizer.tokenize(prompt))
|
| 239 |
self.current_prompt_length = num_tokens
|
|
|
|
| 242 |
return
|
| 243 |
|
| 244 |
model_inputs = self.kbvqa_tokenizer(prompt, add_special_tokens=False, return_tensors="pt").to('cuda')
|
| 245 |
+
free_gpu_resources()
|
| 246 |
input_ids = model_inputs["input_ids"]
|
| 247 |
output_ids = self.kbvqa_model.generate(input_ids)
|
| 248 |
+
free_gpu_resources()
|
| 249 |
index = input_ids.shape[1] # needed to avoid printing the input prompt
|
| 250 |
history = self.kbvqa_tokenizer.decode(output_ids[0], skip_special_tokens=False)
|
| 251 |
output_text = self.kbvqa_tokenizer.decode(output_ids[0][index:], skip_special_tokens=True)
|