Add VQA
- label_prettify.py  +1 -1
- prismer_model.py +8 -6
label_prettify.py
CHANGED
@@ -87,7 +87,7 @@ def ocr_detection_prettify(rgb_path, file_name):
     ocr_labels_dict = torch.load(file_name.replace('.png', '.pt'))
 
     plt.imshow(rgb)
-    plt.imshow(
+    plt.imshow(ocr_labels, cmap='gray', alpha=0.8)
 
     for i in np.unique(ocr_labels)[:-1]:
         text_idx_all = np.where(ocr_labels == i)
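
The one-line change above replaces a call that the diff view shows truncated at plt.imshow( with an overlay of the OCR label map on top of the already-plotted RGB image, drawn with a grayscale colormap at 80% opacity. A minimal standalone sketch of that overlay pattern, assuming a NumPy label map with the same height and width as the image (rgb and ocr_labels below are placeholder arrays, not the Space's actual data):

# Minimal sketch of the overlay pattern, not the Space's actual helper:
# draw the RGB image first, then blend the OCR label map on top with a
# grayscale colormap and partial transparency.
import matplotlib.pyplot as plt
import numpy as np

rgb = np.random.rand(64, 64, 3)      # placeholder RGB image
ocr_labels = np.zeros((64, 64))      # placeholder OCR label map
ocr_labels[20:30, 10:50] = 1.0       # pretend one text region was detected

plt.imshow(rgb)
plt.imshow(ocr_labels, cmap='gray', alpha=0.8)  # semi-transparent overlay, as in the diff
plt.axis('off')
plt.savefig('ocr_overlay.png', bbox_inches='tight')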
prismer_model.py
CHANGED
@@ -75,11 +75,13 @@ class Model:
         if exp_name == self.exp_name:
             return
 
+        # remap model name
         if self.exp_name == 'Prismer-Base':
-
+            self.exp_name = 'prismer_base'
         elif self.exp_name == 'Prismer-Large':
-
+            self.exp_name = 'prismer_large'
 
+        # load checkpoints
         if self.mode == 'caption':
             config = {
                 'dataset': 'demo',
@@ -87,12 +89,12 @@
                 'label_path': 'prismer/helpers/labels',
                 'experts': ['depth', 'normal', 'seg_coco', 'edge', 'obj_detection', 'ocr_detection'],
                 'image_resolution': 480,
-                'prismer_model':
+                'prismer_model': self.exp_name,
                 'freeze': 'freeze_vision',
                 'prefix': '',
             }
             model = PrismerCaption(config)
-            state_dict = torch.load(f'prismer/logging/pretrain_{
+            state_dict = torch.load(f'prismer/logging/pretrain_{self.exp_name}/pytorch_model.bin', map_location='cuda:0')
 
         elif self.mode == 'vqa':
             config = {
@@ -101,12 +103,12 @@
                 'label_path': 'prismer/helpers/labels',
                 'experts': ['depth', 'normal', 'seg_coco', 'edge', 'obj_detection', 'ocr_detection'],
                 'image_resolution': 480,
-                'prismer_model':
+                'prismer_model': self.exp_name,
                 'freeze': 'freeze_vision',
             }
 
             model = PrismerVQA(config)
-            state_dict = torch.load(f'prismer/logging/vqa_{
+            state_dict = torch.load(f'prismer/logging/vqa_{self.exp_name}/pytorch_model.bin', map_location='cuda:0')
 
             model.load_state_dict(state_dict)
             model.eval()
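
Taken together, the prismer_model.py changes remap the UI-facing experiment name ('Prismer-Base' / 'Prismer-Large') to the checkpoint naming scheme ('prismer_base' / 'prismer_large') and thread that name into both the model config ('prismer_model') and the checkpoint path under prismer/logging/. A condensed sketch of that flow, using a hypothetical helper (load_checkpoint and NAME_MAP are illustrative and not part of the Space's API; the mapping and paths are taken from the diff):

import torch

# Hypothetical helper condensing the flow shown above; the name mapping and the
# checkpoint layout under prismer/logging/ come straight from the diff.
NAME_MAP = {'Prismer-Base': 'prismer_base', 'Prismer-Large': 'prismer_large'}

def load_checkpoint(exp_name: str, mode: str) -> dict:
    """Return the state dict for the requested experiment name and task mode."""
    model_name = NAME_MAP[exp_name]                      # e.g. 'Prismer-Base' -> 'prismer_base'
    prefix = 'pretrain' if mode == 'caption' else 'vqa'  # captioning vs. VQA checkpoints
    path = f'prismer/logging/{prefix}_{model_name}/pytorch_model.bin'
    return torch.load(path, map_location='cuda:0')

# Usage mirroring the two branches in Model: caption -> PrismerCaption weights,
# vqa -> PrismerVQA weights (assumes the checkpoint files have been downloaded):
# state_dict = load_checkpoint('Prismer-Base', 'vqa')
# model.load_state_dict(state_dict)
# model.eval()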