shikunl committed on
Commit
c2be56c
·
1 Parent(s): 07ccab1

Update with md5sum and half precision inference

Browse files
Files changed (1) hide show
  1. patch +0 -82
patch DELETED
@@ -1,82 +0,0 @@
1
- diff --git a/dataset/caption_dataset.py b/dataset/caption_dataset.py
2
- index 266fdda..0cc5d3f 100644
3
- --- a/dataset/caption_dataset.py
4
- +++ b/dataset/caption_dataset.py
5
- @@ -50,7 +50,7 @@ class Caption(Dataset):
6
- elif self.dataset == 'demo':
7
- img_path_split = self.data_list[index]['image'].split('/')
8
- img_name = img_path_split[-2] + '/' + img_path_split[-1]
9
- - image, labels, labels_info = get_expert_labels('', self.label_path, img_name, 'helpers', self.experts)
10
- + image, labels, labels_info = get_expert_labels('prismer', self.label_path, img_name, 'helpers', self.experts)
11
-
12
- experts = self.transform(image, labels)
13
- experts = post_label_process(experts, labels_info)
14
- diff --git a/dataset/utils.py b/dataset/utils.py
15
- index b368aac..418358c 100644
16
- --- a/dataset/utils.py
17
- +++ b/dataset/utils.py
18
- @@ -5,6 +5,7 @@
19
- # https://github.com/NVlabs/prismer/blob/main/LICENSE
20
-
21
- import os
22
- +import pathlib
23
- import re
24
- import json
25
- import torch
26
- @@ -14,10 +15,12 @@ import torchvision.transforms as transforms
27
- import torchvision.transforms.functional as transforms_f
28
- from dataset.randaugment import RandAugment
29
-
30
- -COCO_FEATURES = torch.load('dataset/coco_features.pt')['features']
31
- -ADE_FEATURES = torch.load('dataset/ade_features.pt')['features']
32
- -DETECTION_FEATURES = torch.load('dataset/detection_features.pt')['features']
33
- -BACKGROUND_FEATURES = torch.load('dataset/background_features.pt')
34
- +cur_dir = pathlib.Path(__file__).parent
35
- +
36
- +COCO_FEATURES = torch.load(cur_dir / 'coco_features.pt')['features']
37
- +ADE_FEATURES = torch.load(cur_dir / 'ade_features.pt')['features']
38
- +DETECTION_FEATURES = torch.load(cur_dir / 'detection_features.pt')['features']
39
- +BACKGROUND_FEATURES = torch.load(cur_dir / 'background_features.pt')
40
-
41
-
42
- class Transform:
43
- diff --git a/model/prismer.py b/model/prismer.py
44
- index 080253a..02362a4 100644
45
- --- a/model/prismer.py
46
- +++ b/model/prismer.py
47
- @@ -5,6 +5,7 @@
48
- # https://github.com/NVlabs/prismer/blob/main/LICENSE
49
-
50
- import json
51
- +import pathlib
52
- import torch.nn as nn
53
-
54
- from model.modules.vit import load_encoder
55
- @@ -12,6 +13,9 @@ from model.modules.roberta import load_decoder
56
- from transformers import RobertaTokenizer, RobertaConfig
57
-
58
-
59
- +cur_dir = pathlib.Path(__file__).parent
60
- +
61
- +
62
- class Prismer(nn.Module):
63
- def __init__(self, config):
64
- super().__init__()
65
- @@ -26,7 +30,7 @@ class Prismer(nn.Module):
66
- elif exp in ['obj_detection', 'ocr_detection']:
67
- self.experts[exp] = 64
68
-
69
- - prismer_config = json.load(open('configs/prismer.json', 'r'))[config['prismer_model']]
70
- + prismer_config = json.load(open(f'{cur_dir.parent}/configs/prismer.json', 'r'))[config['prismer_model']]
71
- roberta_config = RobertaConfig.from_dict(prismer_config['roberta_model'])
72
-
73
- self.tokenizer = RobertaTokenizer.from_pretrained(prismer_config['roberta_model']['model_name'])
74
- @@ -35,7 +39,7 @@ class Prismer(nn.Module):
75
-
76
- self.prepare_to_train(config['freeze'])
77
- self.ignored_modules = self.get_ignored_modules(config['freeze'])
78
- -
79
- +
80
- def prepare_to_train(self, mode='none'):
81
- for name, params in self.named_parameters():
82
- if mode == 'freeze_lang':