Alberto Carmona committed on
Commit 23d1bbb · 1 Parent(s): ebd4e51

Track error cloning the repo

Files changed (46)
  1. clip/__init__.py +0 -1
  2. clip/bpe_simple_vocab_16e6.txt.gz +0 -3
  3. clip/clip.py +0 -193
  4. clip/model.py +0 -437
  5. clip/simple_tokenizer.py +0 -132
  6. configs/phase1/FineCapEval_clipRN50_mle.yml +0 -60
  7. configs/phase1/clipRN50_mle.yml +0 -52
  8. configs/phase1/transformer.yml +0 -41
  9. configs/phase2/FineCapEval_clipRN50_cider.yml +0 -61
  10. configs/phase2/FineCapEval_clipRN50_cider_clips.yml +0 -65
  11. configs/phase2/FineCapEval_clipRN50_clips.yml +0 -64
  12. configs/phase2/FineCapEval_clipRN50_clips_grammar.yml +0 -64
  13. configs/phase2/clipRN50_cider.yml +0 -58
  14. configs/phase2/clipRN50_cider_clips.yml +0 -61
  15. configs/phase2/clipRN50_clips.yml +0 -58
  16. configs/phase2/clipRN50_clips_grammar.yml +0 -64
  17. configs/phase2/transformer.yml +0 -41
  18. data/README.md +0 -1
  19. retrieval/README.md +0 -5
  20. retrieval/caption_data.py +0 -500
  21. retrieval/clip_model.py +0 -350
  22. retrieval/configs/clip_negative_text.yaml +0 -14
  23. retrieval/param.py +0 -209
  24. retrieval/pth_loader.py +0 -334
  25. retrieval/text_utils.py +0 -74
  26. retrieval/train_pl.py +0 -661
  27. save/README.md +0 -1
  28. scripts/build_bpe_subword_nmt.py +0 -214
  29. scripts/clip_prepro_feats.py +0 -170
  30. scripts/clipscore_prepro_feats.py +0 -162
  31. scripts/copy_model.sh +0 -9
  32. scripts/dump_to_h5df.py +0 -56
  33. scripts/dump_to_lmdb.py +0 -241
  34. scripts/make_bu_data.py +0 -52
  35. scripts/prepro_feats.py +0 -103
  36. scripts/prepro_labels.py +0 -206
  37. scripts/prepro_ngrams.py +0 -94
  38. scripts/prepro_reference_json.py +0 -69
  39. scripts_FineCapEval/clip_prepro_feats.py +0 -163
  40. scripts_FineCapEval/clipscore_prepro_feats.py +0 -154
  41. scripts_FineCapEval/prepro_labels.py +0 -209
  42. tools/eval.py +0 -125
  43. tools/eval_clip_retrieval.py +0 -231
  44. tools/eval_finecapeval.py +0 -204
  45. tools/finecapeval_inference.py +0 -186
  46. tools/train_pl.py +0 -709
clip/__init__.py DELETED
@@ -1 +0,0 @@
- from .clip import *
 
 
clip/bpe_simple_vocab_16e6.txt.gz DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
- size 1356917
 
 
clip/clip.py DELETED
@@ -1,193 +0,0 @@
1
- import hashlib
2
- import os
3
- import urllib
4
- import warnings
5
- from typing import Union, List
6
-
7
- import torch
8
- from PIL import Image
9
- from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
10
- from tqdm import tqdm
11
-
12
- from .model import build_model
13
- from .simple_tokenizer import SimpleTokenizer as _Tokenizer
14
-
15
- __all__ = ["available_models", "load", "tokenize"]
16
- _tokenizer = _Tokenizer()
17
-
18
- _MODELS = {
19
- "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
20
- "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
21
- "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
22
- "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
23
- }
24
-
25
-
26
- def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
27
- os.makedirs(root, exist_ok=True)
28
- filename = os.path.basename(url)
29
-
30
- expected_sha256 = url.split("/")[-2]
31
- download_target = os.path.join(root, filename)
32
-
33
- if os.path.exists(download_target) and not os.path.isfile(download_target):
34
- raise RuntimeError(f"{download_target} exists and is not a regular file")
35
-
36
- if os.path.isfile(download_target):
37
- if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
38
- return download_target
39
- else:
40
- warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
41
-
42
- with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
43
- with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
44
- while True:
45
- buffer = source.read(8192)
46
- if not buffer:
47
- break
48
-
49
- output.write(buffer)
50
- loop.update(len(buffer))
51
-
52
- if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
53
- raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not match")
54
-
55
- return download_target
56
-
57
-
58
- def _transform(n_px):
59
- return Compose([
60
- Resize(n_px, interpolation=Image.BICUBIC),
61
- CenterCrop(n_px),
62
- lambda image: image.convert("RGB"),
63
- ToTensor(),
64
- Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
65
- ])
66
-
67
-
68
- def available_models() -> List[str]:
69
- """Returns the names of available CLIP models"""
70
- return list(_MODELS.keys())
71
-
72
-
73
- def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True):
74
- """Load a CLIP model
75
-
76
- Parameters
77
- ----------
78
- name : str
79
- A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
80
-
81
- device : Union[str, torch.device]
82
- The device to put the loaded model
83
-
84
- jit : bool
85
- Whether to load the optimized JIT model (default) or more hackable non-JIT model.
86
-
87
- Returns
88
- -------
89
- model : torch.nn.Module
90
- The CLIP model
91
-
92
- preprocess : Callable[[PIL.Image], torch.Tensor]
93
- A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
94
- """
95
- if name in _MODELS:
96
- model_path = _download(_MODELS[name])
97
- elif os.path.isfile(name):
98
- model_path = name
99
- else:
100
- raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
101
-
102
- try:
103
- # loading JIT archive
104
- model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
105
- state_dict = None
106
- except RuntimeError:
107
- # loading saved state dict
108
- if jit:
109
- warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
110
- jit = False
111
- state_dict = torch.load(model_path, map_location="cpu")
112
-
113
- if not jit:
114
- model = build_model(state_dict or model.state_dict()).to(device)
115
- if str(device) == "cpu":
116
- model.float()
117
- return model, _transform(model.visual.input_resolution)
118
-
119
- # patch the device names
120
- device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
121
- device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
122
-
123
- def patch_device(module):
124
- graphs = [module.graph] if hasattr(module, "graph") else []
125
- if hasattr(module, "forward1"):
126
- graphs.append(module.forward1.graph)
127
-
128
- for graph in graphs:
129
- for node in graph.findAllNodes("prim::Constant"):
130
- if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
131
- node.copyAttributes(device_node)
132
-
133
- model.apply(patch_device)
134
- patch_device(model.encode_image)
135
- patch_device(model.encode_text)
136
-
137
- # patch dtype to float32 on CPU
138
- if str(device) == "cpu":
139
- float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
140
- float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
141
- float_node = float_input.node()
142
-
143
- def patch_float(module):
144
- graphs = [module.graph] if hasattr(module, "graph") else []
145
- if hasattr(module, "forward1"):
146
- graphs.append(module.forward1.graph)
147
-
148
- for graph in graphs:
149
- for node in graph.findAllNodes("aten::to"):
150
- inputs = list(node.inputs())
151
- for i in [1, 2]: # dtype can be the second or third argument to aten::to()
152
- if inputs[i].node()["value"] == 5:
153
- inputs[i].node().copyAttributes(float_node)
154
-
155
- model.apply(patch_float)
156
- patch_float(model.encode_image)
157
- patch_float(model.encode_text)
158
-
159
- model.float()
160
-
161
- return model, _transform(model.input_resolution.item())
162
-
163
-
164
- def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
165
- """
166
- Returns the tokenized representation of given input string(s)
167
-
168
- Parameters
169
- ----------
170
- texts : Union[str, List[str]]
171
- An input string or a list of input strings to tokenize
172
-
173
- context_length : int
174
- The context length to use; all CLIP models use 77 as the context length
175
-
176
- Returns
177
- -------
178
- A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
179
- """
180
- if isinstance(texts, str):
181
- texts = [texts]
182
-
183
- sot_token = _tokenizer.encoder["<|startoftext|>"]
184
- eot_token = _tokenizer.encoder["<|endoftext|>"]
185
- all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
186
- result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
187
-
188
- for i, tokens in enumerate(all_tokens):
189
- if len(tokens) > context_length:
190
- raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
191
- result[i, :len(tokens)] = torch.tensor(tokens)
192
-
193
- return result
 
 
clip/model.py DELETED
@@ -1,437 +0,0 @@
1
- from collections import OrderedDict
2
- from typing import Tuple, Union
3
-
4
- import torch
5
- import torch.nn.functional as F
6
- from torch import nn
7
-
8
-
9
- class Bottleneck(nn.Module):
10
- expansion = 4
11
-
12
- def __init__(self, inplanes, planes, stride=1):
13
- super().__init__()
14
-
15
- # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
16
- self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
17
- self.bn1 = nn.BatchNorm2d(planes)
18
-
19
- self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
20
- self.bn2 = nn.BatchNorm2d(planes)
21
-
22
- self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
23
-
24
- self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
25
- self.bn3 = nn.BatchNorm2d(planes * self.expansion)
26
-
27
- self.relu = nn.ReLU(inplace=True)
28
- self.downsample = None
29
- self.stride = stride
30
-
31
- if stride > 1 or inplanes != planes * Bottleneck.expansion:
32
- # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
33
- self.downsample = nn.Sequential(OrderedDict([
34
- ("-1", nn.AvgPool2d(stride)),
35
- ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
36
- ("1", nn.BatchNorm2d(planes * self.expansion))
37
- ]))
38
-
39
- def forward(self, x: torch.Tensor):
40
- identity = x
41
-
42
- out = self.relu(self.bn1(self.conv1(x)))
43
- out = self.relu(self.bn2(self.conv2(out)))
44
- out = self.avgpool(out)
45
- out = self.bn3(self.conv3(out))
46
-
47
- if self.downsample is not None:
48
- identity = self.downsample(x)
49
-
50
- out += identity
51
- out = self.relu(out)
52
- return out
53
-
54
-
55
- class AttentionPool2d(nn.Module):
56
- def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
57
- super().__init__()
58
- self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
59
- self.k_proj = nn.Linear(embed_dim, embed_dim)
60
- self.q_proj = nn.Linear(embed_dim, embed_dim)
61
- self.v_proj = nn.Linear(embed_dim, embed_dim)
62
- self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
63
- self.num_heads = num_heads
64
-
65
- def forward(self, x):
66
- x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
67
- # print(x.shape, self.positional_embedding.shape)
68
- x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
69
- x = x + self.positional_embedding[0, :, None, :].to(x.dtype) # (HW+1)NC
70
- x, _ = F.multi_head_attention_forward(
71
- query=x, key=x, value=x,
72
- embed_dim_to_check=x.shape[-1],
73
- num_heads=self.num_heads,
74
- q_proj_weight=self.q_proj.weight,
75
- k_proj_weight=self.k_proj.weight,
76
- v_proj_weight=self.v_proj.weight,
77
- in_proj_weight=None,
78
- in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
79
- bias_k=None,
80
- bias_v=None,
81
- add_zero_attn=False,
82
- dropout_p=0,
83
- out_proj_weight=torch.ones_like(self.q_proj.weight),
84
- out_proj_bias=torch.zeros_like(self.q_proj.bias),
85
- # out_proj_weight=self.c_proj.weight,
86
- # out_proj_bias=self.c_proj.bias,
87
- use_separate_proj_weight=True,
88
- training=self.training,
89
- need_weights=False
90
- )
91
-
92
- return x[0]
93
-
94
-
95
- class ModifiedResNet(nn.Module):
96
- """
97
- A ResNet class that is similar to torchvision's but contains the following changes:
98
- - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
99
- - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
100
- - The final pooling layer is a QKV attention instead of an average pool
101
- """
102
-
103
- def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
104
- super().__init__()
105
- self.output_dim = output_dim
106
- self.input_resolution = input_resolution
107
-
108
- # the 3-layer stem
109
- self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
110
- self.bn1 = nn.BatchNorm2d(width // 2)
111
- self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
112
- self.bn2 = nn.BatchNorm2d(width // 2)
113
- self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
114
- self.bn3 = nn.BatchNorm2d(width)
115
- self.avgpool = nn.AvgPool2d(2)
116
- self.relu = nn.ReLU(inplace=True)
117
-
118
- # residual layers
119
- self._inplanes = width # this is a *mutable* variable used during construction
120
- self.layer1 = self._make_layer(width, layers[0])
121
- self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
122
- self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
123
- self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
124
-
125
- embed_dim = width * 32 # the ResNet feature dimension
126
- self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
127
-
128
- def _make_layer(self, planes, blocks, stride=1):
129
- layers = [Bottleneck(self._inplanes, planes, stride)]
130
-
131
- self._inplanes = planes * Bottleneck.expansion
132
- for _ in range(1, blocks):
133
- layers.append(Bottleneck(self._inplanes, planes))
134
-
135
- return nn.Sequential(*layers)
136
-
137
- def forward(self, x):
138
- def stem(x):
139
- for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
140
- x = self.relu(bn(conv(x)))
141
- x = self.avgpool(x)
142
- return x
143
-
144
- x = x.type(self.conv1.weight.dtype)
145
- x = stem(x)
146
- x = self.layer1(x)
147
- x = self.layer2(x)
148
- x = self.layer3(x)
149
- x = self.layer4(x)
150
- # print(x.shape)
151
- # x = self.attnpool(x)
152
- attnpool = self.attnpool(x)
153
-
154
- return (x, attnpool)
155
-
156
-
157
- class LayerNorm(nn.LayerNorm):
158
- """Subclass torch's LayerNorm to handle fp16."""
159
-
160
- def forward(self, x: torch.Tensor):
161
- orig_type = x.dtype
162
- ret = super().forward(x.type(torch.float32))
163
- return ret.type(orig_type)
164
-
165
-
166
- class QuickGELU(nn.Module):
167
- def forward(self, x: torch.Tensor):
168
- return x * torch.sigmoid(1.702 * x)
169
-
170
-
171
- class ResidualAttentionBlock(nn.Module):
172
- def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
173
- super().__init__()
174
-
175
- self.attn = nn.MultiheadAttention(d_model, n_head)
176
- self.ln_1 = LayerNorm(d_model)
177
- self.mlp = nn.Sequential(OrderedDict([
178
- ("c_fc", nn.Linear(d_model, d_model * 4)),
179
- ("gelu", QuickGELU()),
180
- ("c_proj", nn.Linear(d_model * 4, d_model))
181
- ]))
182
- self.ln_2 = LayerNorm(d_model)
183
- self.attn_mask = attn_mask
184
-
185
- def attention(self, x: torch.Tensor):
186
- self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
187
- return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
188
-
189
- def forward(self, x: torch.Tensor):
190
- x = x + self.attention(self.ln_1(x))
191
- x = x + self.mlp(self.ln_2(x))
192
- return x
193
-
194
-
195
- class Transformer(nn.Module):
196
- def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
197
- super().__init__()
198
- self.width = width
199
- self.layers = layers
200
- self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
201
-
202
- def forward(self, x: torch.Tensor):
203
- return self.resblocks(x)
204
-
205
-
206
- class VisualTransformer(nn.Module):
207
- def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
208
- super().__init__()
209
- self.input_resolution = input_resolution
210
- self.output_dim = output_dim
211
- self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
212
-
213
- scale = width ** -0.5
214
- self.class_embedding = nn.Parameter(scale * torch.randn(width))
215
- self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
216
- self.ln_pre = LayerNorm(width)
217
-
218
- self.transformer = Transformer(width, layers, heads)
219
-
220
- self.ln_post = LayerNorm(width)
221
- self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
222
-
223
- def forward(self, x: torch.Tensor):
224
- x = self.conv1(x) # shape = [*, width, grid, grid]
225
- x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
226
- x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
227
- x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
228
- x = x + self.positional_embedding.to(x.dtype)
229
- x = self.ln_pre(x)
230
-
231
- x = x.permute(1, 0, 2) # NLD -> LND
232
- x = self.transformer(x)
233
- x = x.permute(1, 0, 2) # LND -> NLD
234
-
235
- # x = self.ln_post(x[:, 0, :])
236
-
237
- x = self.ln_post(x)
238
- # if self.proj is not None:
239
- # x = x @ self.proj
240
-
241
- return x
242
-
243
-
244
- class CLIP(nn.Module):
245
- def __init__(self,
246
- embed_dim: int,
247
- # vision
248
- image_resolution: int,
249
- vision_layers: Union[Tuple[int, int, int, int], int],
250
- vision_width: int,
251
- vision_patch_size: int,
252
- # text
253
- context_length: int,
254
- vocab_size: int,
255
- transformer_width: int,
256
- transformer_heads: int,
257
- transformer_layers: int
258
- ):
259
- super().__init__()
260
-
261
- self.context_length = context_length
262
-
263
- if isinstance(vision_layers, (tuple, list)):
264
- vision_heads = vision_width * 32 // 64
265
- self.visual = ModifiedResNet(
266
- layers=vision_layers,
267
- output_dim=embed_dim,
268
- heads=vision_heads,
269
- input_resolution=image_resolution,
270
- width=vision_width
271
- )
272
- else:
273
- vision_heads = vision_width // 64
274
- self.visual = VisualTransformer(
275
- input_resolution=image_resolution,
276
- patch_size=vision_patch_size,
277
- width=vision_width,
278
- layers=vision_layers,
279
- heads=vision_heads,
280
- output_dim=embed_dim
281
- )
282
-
283
- self.transformer = Transformer(
284
- width=transformer_width,
285
- layers=transformer_layers,
286
- heads=transformer_heads,
287
- attn_mask=self.build_attention_mask()
288
- )
289
-
290
- self.vocab_size = vocab_size
291
- self.token_embedding = nn.Embedding(vocab_size, transformer_width)
292
- self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
293
- self.ln_final = LayerNorm(transformer_width)
294
-
295
- self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
296
- self.logit_scale = nn.Parameter(torch.ones([]))
297
-
298
- self.initialize_parameters()
299
-
300
- def initialize_parameters(self):
301
- nn.init.normal_(self.token_embedding.weight, std=0.02)
302
- nn.init.normal_(self.positional_embedding, std=0.01)
303
-
304
- if isinstance(self.visual, ModifiedResNet):
305
- if self.visual.attnpool is not None:
306
- std = self.visual.attnpool.c_proj.in_features ** -0.5
307
- nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
308
- nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
309
- nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
310
- nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
311
-
312
- for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
313
- for name, param in resnet_block.named_parameters():
314
- if name.endswith("bn3.weight"):
315
- nn.init.zeros_(param)
316
-
317
- proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
318
- attn_std = self.transformer.width ** -0.5
319
- fc_std = (2 * self.transformer.width) ** -0.5
320
- for block in self.transformer.resblocks:
321
- nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
322
- nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
323
- nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
324
- nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
325
-
326
- if self.text_projection is not None:
327
- nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
328
-
329
- def build_attention_mask(self):
330
- # lazily create causal attention mask, with full attention between the vision tokens
331
- # pytorch uses additive attention mask; fill with -inf
332
- mask = torch.empty(self.context_length, self.context_length)
333
- mask.fill_(float("-inf"))
334
- mask.triu_(1) # zero out the lower diagonal
335
- return mask
336
-
337
- @property
338
- def dtype(self):
339
- return self.visual.conv1.weight.dtype
340
-
341
- def encode_image(self, image):
342
- return self.visual(image.type(self.dtype))
343
-
344
- def encode_text(self, text):
345
- x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
346
-
347
- x = x + self.positional_embedding.type(self.dtype)
348
- x = x.permute(1, 0, 2) # NLD -> LND
349
- x = self.transformer(x)
350
- x = x.permute(1, 0, 2) # LND -> NLD
351
- x = self.ln_final(x).type(self.dtype)
352
-
353
- # x.shape = [batch_size, n_ctx, transformer.width]
354
- # take features from the eot embedding (eot_token is the highest number in each sequence)
355
- x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
356
-
357
- return x
358
-
359
- def forward(self, image, text):
360
- image_features = self.encode_image(image)
361
- text_features = self.encode_text(text)
362
-
363
- # normalized features
364
- image_features = image_features / image_features.norm(dim=-1, keepdim=True)
365
- text_features = text_features / text_features.norm(dim=-1, keepdim=True)
366
-
367
- # cosine similarity as logits
368
- logit_scale = self.logit_scale.exp()
369
- logits_per_image = logit_scale * image_features @ text_features.t()
370
- logits_per_text = logit_scale * text_features @ image_features.t()
371
-
372
- # shape = [global_batch_size, global_batch_size]
373
- return logits_per_image, logits_per_text
374
-
375
-
376
- def convert_weights(model: nn.Module):
377
- """Convert applicable model parameters to fp16"""
378
-
379
- def _convert_weights_to_fp16(l):
380
- if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
381
- l.weight.data = l.weight.data.half()
382
- if l.bias is not None:
383
- l.bias.data = l.bias.data.half()
384
-
385
- if isinstance(l, nn.MultiheadAttention):
386
- for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
387
- tensor = getattr(l, attr)
388
- if tensor is not None:
389
- tensor.data = tensor.data.half()
390
-
391
- for name in ["text_projection", "proj"]:
392
- if hasattr(l, name):
393
- attr = getattr(l, name)
394
- if attr is not None:
395
- attr.data = attr.data.half()
396
-
397
- model.apply(_convert_weights_to_fp16)
398
-
399
-
400
- def build_model(state_dict: dict):
401
- vit = "visual.proj" in state_dict
402
-
403
- if vit:
404
- vision_width = state_dict["visual.conv1.weight"].shape[0]
405
- vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
406
- vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
407
- grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
408
- image_resolution = vision_patch_size * grid_size
409
- else:
410
- counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
411
- vision_layers = tuple(counts)
412
- vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
413
- output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
414
- vision_patch_size = None
415
- assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
416
- image_resolution = output_width * 32
417
-
418
- embed_dim = state_dict["text_projection"].shape[1]
419
- context_length = state_dict["positional_embedding"].shape[0]
420
- vocab_size = state_dict["token_embedding.weight"].shape[0]
421
- transformer_width = state_dict["ln_final.weight"].shape[0]
422
- transformer_heads = transformer_width // 64
423
- transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
424
-
425
- model = CLIP(
426
- embed_dim,
427
- image_resolution, vision_layers, vision_width, vision_patch_size,
428
- context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
429
- )
430
-
431
- for key in ["input_resolution", "context_length", "vocab_size"]:
432
- if key in state_dict:
433
- del state_dict[key]
434
-
435
- convert_weights(model)
436
- model.load_state_dict(state_dict)
437
- return model.eval()
 
 
clip/simple_tokenizer.py DELETED
@@ -1,132 +0,0 @@
1
- import gzip
2
- import html
3
- import os
4
- from functools import lru_cache
5
-
6
- import ftfy
7
- import regex as re
8
-
9
-
10
- @lru_cache()
11
- def default_bpe():
12
- return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13
-
14
-
15
- @lru_cache()
16
- def bytes_to_unicode():
17
- """
18
- Returns list of utf-8 byte and a corresponding list of unicode strings.
19
- The reversible bpe codes work on unicode strings.
20
- This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21
- When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22
- This is a significant percentage of your normal, say, 32K bpe vocab.
23
- To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24
- And avoids mapping to whitespace/control characters the bpe code barfs on.
25
- """
26
- bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27
- cs = bs[:]
28
- n = 0
29
- for b in range(2**8):
30
- if b not in bs:
31
- bs.append(b)
32
- cs.append(2**8+n)
33
- n += 1
34
- cs = [chr(n) for n in cs]
35
- return dict(zip(bs, cs))
36
-
37
-
38
- def get_pairs(word):
39
- """Return set of symbol pairs in a word.
40
- Word is represented as tuple of symbols (symbols being variable-length strings).
41
- """
42
- pairs = set()
43
- prev_char = word[0]
44
- for char in word[1:]:
45
- pairs.add((prev_char, char))
46
- prev_char = char
47
- return pairs
48
-
49
-
50
- def basic_clean(text):
51
- text = ftfy.fix_text(text)
52
- text = html.unescape(html.unescape(text))
53
- return text.strip()
54
-
55
-
56
- def whitespace_clean(text):
57
- text = re.sub(r'\s+', ' ', text)
58
- text = text.strip()
59
- return text
60
-
61
-
62
- class SimpleTokenizer(object):
63
- def __init__(self, bpe_path: str = default_bpe()):
64
- self.byte_encoder = bytes_to_unicode()
65
- self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66
- merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67
- merges = merges[1:49152-256-2+1]
68
- merges = [tuple(merge.split()) for merge in merges]
69
- vocab = list(bytes_to_unicode().values())
70
- vocab = vocab + [v+'</w>' for v in vocab]
71
- for merge in merges:
72
- vocab.append(''.join(merge))
73
- vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74
- self.encoder = dict(zip(vocab, range(len(vocab))))
75
- self.decoder = {v: k for k, v in self.encoder.items()}
76
- self.bpe_ranks = dict(zip(merges, range(len(merges))))
77
- self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78
- self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79
-
80
- def bpe(self, token):
81
- if token in self.cache:
82
- return self.cache[token]
83
- word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84
- pairs = get_pairs(word)
85
-
86
- if not pairs:
87
- return token+'</w>'
88
-
89
- while True:
90
- bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91
- if bigram not in self.bpe_ranks:
92
- break
93
- first, second = bigram
94
- new_word = []
95
- i = 0
96
- while i < len(word):
97
- try:
98
- j = word.index(first, i)
99
- new_word.extend(word[i:j])
100
- i = j
101
- except:
102
- new_word.extend(word[i:])
103
- break
104
-
105
- if word[i] == first and i < len(word)-1 and word[i+1] == second:
106
- new_word.append(first+second)
107
- i += 2
108
- else:
109
- new_word.append(word[i])
110
- i += 1
111
- new_word = tuple(new_word)
112
- word = new_word
113
- if len(word) == 1:
114
- break
115
- else:
116
- pairs = get_pairs(word)
117
- word = ' '.join(word)
118
- self.cache[token] = word
119
- return word
120
-
121
- def encode(self, text):
122
- bpe_tokens = []
123
- text = whitespace_clean(basic_clean(text)).lower()
124
- for token in re.findall(self.pat, text):
125
- token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126
- bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127
- return bpe_tokens
128
-
129
- def decode(self, tokens):
130
- text = ''.join([self.decoder[token] for token in tokens])
131
- text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132
- return text
 
 
configs/phase1/FineCapEval_clipRN50_mle.yml DELETED
@@ -1,60 +0,0 @@
1
- caption_model: transformer
2
- noamopt: true
3
- noamopt_warmup: 20000
4
- label_smoothing: 0.0
5
- input_json: data/FineCapEval.json
6
- input_label_h5: none
7
- input_fc_dir: data/FineCapEval_clip_RN50_fc
8
- input_att_dir: data/FineCapEval_clip_RN50_att
9
- input_clipscore_vis_dir: data/FineCapEval_clipscore_vis
10
-
11
- seq_per_img: 5
12
- batch_size: 200
13
- learning_rate: 0.0005
14
-
15
- checkpoint_path: ./save/clipRN50_mle/clipRN50_mle
16
-
17
- # clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt'
18
-
19
- # Notice: because I'm too lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
20
- # N=num_layers
21
- # d_model=input_encoding_size
22
- # d_ff=rnn_size
23
-
24
- # will be ignored
25
- num_layers: 6
26
- input_encoding_size: 512
27
- rnn_size: 2048
28
-
29
- # Transformer config
30
- N_enc: 6
31
- N_dec: 6
32
- d_model: 512
33
- d_ff: 2048
34
- num_att_heads: 8
35
- dropout: 0.1
36
-
37
-
38
- learning_rate_decay_start: 0
39
- scheduled_sampling_start: -1
40
- save_checkpoint_every: 3000
41
- language_eval: 1
42
- val_images_use: 5000
43
- max_epochs: 15
44
- train_sample_n: 5
45
-
46
- REFORWARD: false
47
-
48
- # _BASE_: transformer.yml
49
- reduce_on_plateau: false
50
- noamopt: false
51
- learning_rate: 0.000005
52
- learning_rate_decay_start: -1
53
-
54
- self_critical_after: 15
55
- max_epochs: 50
56
-
57
- verbose: false
58
- precision: 32
59
-
60
- use_clipscore: false
 
 
configs/phase1/clipRN50_mle.yml DELETED
@@ -1,52 +0,0 @@
1
- caption_model: transformer
2
- noamopt: true
3
- # noamopt: false
4
- noamopt_warmup: 20000
5
- label_smoothing: 0.0
6
- input_json: data/cocotalk.json
7
- input_label_h5: data/cocotalk_label.h5
8
- input_fc_dir: data/cocotalk_clip_RN50_fc
9
- input_att_dir: data/cocotalk_clip_RN50_att
10
- input_clipscore_vis_dir: data/cocotalk_clipscore_vis
11
- seq_per_img: 5
12
- # batch_size: 600
13
- batch_size: 200
14
-
15
- learning_rate: 0.0005
16
-
17
- # checkpoint_path: ./save/trans_clip_rn50_sc_pl
18
- checkpoint_path: save/clipRN50_mle/clipRN50_mle
19
-
20
- # Notice: because I'm too lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
21
- # N=num_layers
22
- # d_model=input_encoding_size
23
- # d_ff=rnn_size
24
-
25
- # will be ignored
26
- num_layers: 6
27
- input_encoding_size: 512
28
- rnn_size: 2048
29
-
30
- # Transformer config
31
- N_enc: 6
32
- N_dec: 6
33
- d_model: 512
34
- d_ff: 2048
35
- num_att_heads: 8
36
- dropout: 0.1
37
-
38
-
39
- learning_rate_decay_start: 0
40
- scheduled_sampling_start: -1
41
- save_checkpoint_every: 3000
42
- language_eval: 1
43
- val_images_use: 5000
44
- # max_epochs: 15
45
- max_epochs: 25
46
- train_sample_n: 5
47
-
48
- REFORWARD: false
49
-
50
-
51
- verbose: false
52
- precision: 16
 
 
configs/phase1/transformer.yml DELETED
@@ -1,41 +0,0 @@
1
- caption_model: transformer
2
- noamopt: true
3
- noamopt_warmup: 20000
4
- label_smoothing: 0.0
5
- input_json: data/cocotalk.json
6
- input_label_h5: data/cocotalk_label.h5
7
- input_att_dir: data/cocotalk_att
8
- seq_per_img: 5
9
- batch_size: 10
10
- learning_rate: 0.0005
11
-
12
- checkpoint_path: ./save/trans_rn50_sc
13
-
14
- # Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
15
- # N=num_layers
16
- # d_model=input_encoding_size
17
- # d_ff=rnn_size
18
-
19
- # will be ignored
20
- num_layers: 6
21
- input_encoding_size: 512
22
- rnn_size: 2048
23
-
24
- # Transformer config
25
- N_enc: 6
26
- N_dec: 6
27
- d_model: 512
28
- d_ff: 2048
29
- num_att_heads: 8
30
- dropout: 0.1
31
-
32
-
33
- learning_rate_decay_start: 0
34
- scheduled_sampling_start: -1
35
- save_checkpoint_every: 3000
36
- language_eval: 1
37
- val_images_use: 5000
38
- max_epochs: 15
39
- train_sample_n: 5
40
-
41
- REFORWARD: false
 
 
configs/phase2/FineCapEval_clipRN50_cider.yml DELETED
@@ -1,61 +0,0 @@
1
- caption_model: transformer
2
- noamopt: true
3
- noamopt_warmup: 20000
4
- label_smoothing: 0.0
5
- input_json: data/FineCapEval.json
6
- input_label_h5: none
7
- input_fc_dir: data/FineCapEval_clip_RN50_fc
8
- input_att_dir: data/FineCapEval_clip_RN50_att
9
- input_clipscore_vis_dir: data/FineCapEval_clipscore_vis
10
-
11
- seq_per_img: 5
12
- batch_size: 200
13
- learning_rate: 0.0005
14
-
15
- checkpoint_path: ./save/clipRN50_cider/clipRN50_cider
16
-
17
- # clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt'
18
-
19
- # Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
20
- # N=num_layers
21
- # d_model=input_encoding_size
22
- # d_ff=rnn_size
23
-
24
- # will be ignored
25
- num_layers: 6
26
- input_encoding_size: 512
27
- rnn_size: 2048
28
-
29
- # Transformer config
30
- N_enc: 6
31
- N_dec: 6
32
- d_model: 512
33
- d_ff: 2048
34
- num_att_heads: 8
35
- dropout: 0.1
36
-
37
-
38
- learning_rate_decay_start: 0
39
- scheduled_sampling_start: -1
40
- save_checkpoint_every: 3000
41
- language_eval: 1
42
- val_images_use: 5000
43
- max_epochs: 15
44
- train_sample_n: 5
45
-
46
- REFORWARD: false
47
-
48
- # _BASE_: transformer.yml
49
- reduce_on_plateau: false
50
- noamopt: false
51
- learning_rate: 0.000005
52
- learning_rate_decay_start: -1
53
-
54
- self_critical_after: 15
55
- max_epochs: 50
56
-
57
- verbose: false
58
- precision: 32
59
-
60
- # use_clipscore: true
61
- use_clipscore: false
 
 
configs/phase2/FineCapEval_clipRN50_cider_clips.yml DELETED
@@ -1,65 +0,0 @@
1
- caption_model: transformer
2
- noamopt: true
3
- noamopt_warmup: 20000
4
- label_smoothing: 0.0
5
- input_json: data/FineCapEval.json
6
- input_label_h5: none
7
- input_fc_dir: data/FineCapEval_clip_RN50_fc
8
- input_att_dir: data/FineCapEval_clip_RN50_att
9
- input_clipscore_vis_dir: data/FineCapEval_clipscore_vis
10
-
11
- seq_per_img: 5
12
- batch_size: 200
13
- learning_rate: 0.0005
14
-
15
- checkpoint_path: ./save/clipRN50_cider_clips/clipRN50_cider_clips
16
-
17
- # clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt'
18
-
19
- # Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
20
- # N=num_layers
21
- # d_model=input_encoding_size
22
- # d_ff=rnn_size
23
-
24
- # will be ignored
25
- num_layers: 6
26
- input_encoding_size: 512
27
- rnn_size: 2048
28
-
29
- # Transformer config
30
- N_enc: 6
31
- N_dec: 6
32
- d_model: 512
33
- d_ff: 2048
34
- num_att_heads: 8
35
- dropout: 0.1
36
-
37
-
38
- learning_rate_decay_start: 0
39
- scheduled_sampling_start: -1
40
- save_checkpoint_every: 3000
41
- language_eval: 1
42
- val_images_use: 5000
43
- max_epochs: 15
44
- train_sample_n: 5
45
-
46
- REFORWARD: false
47
-
48
- # _BASE_: transformer.yml
49
- reduce_on_plateau: false
50
- noamopt: false
51
- learning_rate: 0.000005
52
- learning_rate_decay_start: -1
53
-
54
- self_critical_after: 15
55
- max_epochs: 50
56
-
57
- verbose: false
58
- precision: 32
59
-
60
- # use_clipscore: true
61
- use_clipscore: false
62
- clipscore_reward_weight: 2.0
63
- clipscore_mode: clip_s
64
-
65
- use_multi_rewards: true
 
 
configs/phase2/FineCapEval_clipRN50_clips.yml DELETED
@@ -1,64 +0,0 @@
1
- caption_model: transformer
2
- noamopt: true
3
- noamopt_warmup: 20000
4
- label_smoothing: 0.0
5
- input_json: data/FineCapEval.json
6
- input_label_h5: none
7
- input_fc_dir: data/FineCapEval_clip_RN50_fc
8
- input_att_dir: data/FineCapEval_clip_RN50_att
9
- input_clipscore_vis_dir: data/FineCapEval_clipscore_vis
10
- seq_per_img: 5
11
- batch_size: 160
12
- learning_rate: 0.0005
13
-
14
- checkpoint_path: ./save/clipRN50_clips/clipRN50_clips
15
-
16
- use_multi_rewards: false
17
- use_grammar: false
18
- use_grammar_baseline: false
19
- # clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt'
20
-
21
- # Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
22
- # N=num_layers
23
- # d_model=input_encoding_size
24
- # d_ff=rnn_size
25
-
26
- # will be ignored
27
- num_layers: 6
28
- input_encoding_size: 512
29
- rnn_size: 2048
30
-
31
- # Transformer config
32
- N_enc: 6
33
- N_dec: 6
34
- d_model: 512
35
- d_ff: 2048
36
- num_att_heads: 8
37
- dropout: 0.1
38
-
39
-
40
- learning_rate_decay_start: 0
41
- scheduled_sampling_start: -1
42
- save_checkpoint_every: 3000
43
- language_eval: 0
44
- val_images_use: 5000
45
- max_epochs: 15
46
- train_sample_n: 5
47
-
48
- REFORWARD: false
49
-
50
- # _BASE_: transformer.yml
51
- reduce_on_plateau: false
52
- noamopt: false
53
- learning_rate: 0.000005
54
- learning_rate_decay_start: -1
55
-
56
- self_critical_after: 15
57
- max_epochs: 50
58
-
59
- verbose: false
60
- precision: 32
61
-
62
- # use_clipscore: true
63
- use_clipscore: false
64
- clipscore_reward_weight: 2.0
 
 
configs/phase2/FineCapEval_clipRN50_clips_grammar.yml DELETED
@@ -1,64 +0,0 @@
1
- caption_model: transformer
2
- noamopt: true
3
- noamopt_warmup: 20000
4
- label_smoothing: 0.0
5
- input_json: data/FineCapEval.json
6
- input_label_h5: none
7
- input_fc_dir: data/FineCapEval_clip_RN50_fc
8
- input_att_dir: data/FineCapEval_clip_RN50_att
9
- input_clipscore_vis_dir: data/FineCapEval_clipscore_vis
10
- seq_per_img: 5
11
- batch_size: 160
12
- learning_rate: 0.0005
13
-
14
- checkpoint_path: ./save/clipRN50_clips_grammar/clipRN50_clips_grammar
15
-
16
- use_multi_rewards: true
17
- use_grammar: true
18
- use_grammar_baseline: true
19
- # clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt'
20
-
21
- # Notice: because I'm too lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
22
- # N=num_layers
23
- # d_model=input_encoding_size
24
- # d_ff=rnn_size
25
-
26
- # will be ignored
27
- num_layers: 6
28
- input_encoding_size: 512
29
- rnn_size: 2048
30
-
31
- # Transformer config
32
- N_enc: 6
33
- N_dec: 6
34
- d_model: 512
35
- d_ff: 2048
36
- num_att_heads: 8
37
- dropout: 0.1
38
-
39
-
40
- learning_rate_decay_start: 0
41
- scheduled_sampling_start: -1
42
- save_checkpoint_every: 3000
43
- language_eval: 0
44
- val_images_use: 5000
45
- max_epochs: 15
46
- train_sample_n: 5
47
-
48
- REFORWARD: false
49
-
50
- # _BASE_: transformer.yml
51
- reduce_on_plateau: false
52
- noamopt: false
53
- learning_rate: 0.000005
54
- learning_rate_decay_start: -1
55
-
56
- self_critical_after: 15
57
- max_epochs: 50
58
-
59
- verbose: false
60
- precision: 32
61
-
62
- # use_clipscore: true
63
- use_clipscore: false
64
- clipscore_reward_weight: 2.0
 
 
configs/phase2/clipRN50_cider.yml DELETED
@@ -1,58 +0,0 @@
1
- caption_model: transformer
2
- noamopt: true
3
- noamopt_warmup: 20000
4
- label_smoothing: 0.0
5
- input_json: data/cocotalk.json
6
- input_label_h5: data/cocotalk_label.h5
7
- input_fc_dir: data/cocotalk_clip_RN50_fc
8
- input_att_dir: data/cocotalk_clip_RN50_att
9
- # used only for evaluation
10
- input_clipscore_vis_dir: data/cocotalk_clipscore_vis
11
-
12
- seq_per_img: 5
13
- batch_size: 200
14
- learning_rate: 0.0005
15
-
16
- # checkpoint_path: ./save/trans_clip_rn50_sc_pl_scst_cider
17
- checkpoint_path: save/clipRN50_cider/clipRN50_cider
18
-
19
- # Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
20
- # N=num_layers
21
- # d_model=input_encoding_size
22
- # d_ff=rnn_size
23
-
24
- # will be ignored
25
- num_layers: 6
26
- input_encoding_size: 512
27
- rnn_size: 2048
28
-
29
- # Transformer config
30
- N_enc: 6
31
- N_dec: 6
32
- d_model: 512
33
- d_ff: 2048
34
- num_att_heads: 8
35
- dropout: 0.1
36
-
37
-
38
- learning_rate_decay_start: 0
39
- scheduled_sampling_start: -1
40
- save_checkpoint_every: 3000
41
- language_eval: 1
42
- val_images_use: 5000
43
- max_epochs: 15
44
- train_sample_n: 5
45
-
46
- REFORWARD: false
47
-
48
- # _BASE_: transformer.yml
49
- reduce_on_plateau: false
50
- noamopt: false
51
- learning_rate: 0.000005
52
- learning_rate_decay_start: -1
53
-
54
- self_critical_after: 15
55
- max_epochs: 40
56
-
57
- verbose: false
58
- precision: 32
 
 
configs/phase2/clipRN50_cider_clips.yml DELETED
@@ -1,61 +0,0 @@
1
- caption_model: transformer
2
- noamopt: true
3
- noamopt_warmup: 20000
4
- label_smoothing: 0.0
5
- input_json: data/cocotalk.json
6
- input_label_h5: data/cocotalk_label.h5
7
- input_fc_dir: data/cocotalk_clip_RN50_fc
8
- input_att_dir: data/cocotalk_clip_RN50_att
9
- input_clipscore_vis_dir: data/cocotalk_clipscore_vis
10
- seq_per_img: 5
11
- batch_size: 160
12
- learning_rate: 0.0005
13
-
14
- checkpoint_path: save/clipRN50_cider_clips/clipRN50_cider_clips
15
-
16
- # Notice: because I'm too lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
17
- # N=num_layers
18
- # d_model=input_encoding_size
19
- # d_ff=rnn_size
20
-
21
- # will be ignored
22
- num_layers: 6
23
- input_encoding_size: 512
24
- rnn_size: 2048
25
-
26
- # Transformer config
27
- N_enc: 6
28
- N_dec: 6
29
- d_model: 512
30
- d_ff: 2048
31
- num_att_heads: 8
32
- dropout: 0.1
33
-
34
-
35
- learning_rate_decay_start: 0
36
- scheduled_sampling_start: -1
37
- save_checkpoint_every: 3000
38
- language_eval: 1
39
- val_images_use: 5000
40
- max_epochs: 15
41
- train_sample_n: 5
42
-
43
- REFORWARD: false
44
-
45
- # _BASE_: transformer.yml
46
- reduce_on_plateau: false
47
- noamopt: false
48
- learning_rate: 0.000005
49
- learning_rate_decay_start: -1
50
-
51
- self_critical_after: 15
52
- max_epochs: 40
53
-
54
- verbose: false
55
- precision: 32
56
-
57
- use_clipscore: true
58
- clipscore_reward_weight: 2.0
59
- clipscore_mode: clip_s
60
-
61
- use_multi_rewards: true
 
 
configs/phase2/clipRN50_clips.yml DELETED
@@ -1,58 +0,0 @@
1
- caption_model: transformer
2
- noamopt: true
3
- noamopt_warmup: 20000
4
- label_smoothing: 0.0
5
- input_json: data/cocotalk.json
6
- input_label_h5: data/cocotalk_label.h5
7
- input_fc_dir: data/cocotalk_clip_RN50_fc
8
- input_att_dir: data/cocotalk_clip_RN50_att
9
- input_clipscore_vis_dir: data/cocotalk_clipscore_vis
10
- seq_per_img: 5
11
- batch_size: 160
12
- learning_rate: 0.0005
13
-
14
- checkpoint_path: save/clipRN50_clips/clipRN50_clips
15
-
16
- # Notice: because I'm too lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
17
- # N=num_layers
18
- # d_model=input_encoding_size
19
- # d_ff=rnn_size
20
-
21
- # will be ignored
22
- num_layers: 6
23
- input_encoding_size: 512
24
- rnn_size: 2048
25
-
26
- # Transformer config
27
- N_enc: 6
28
- N_dec: 6
29
- d_model: 512
30
- d_ff: 2048
31
- num_att_heads: 8
32
- dropout: 0.1
33
-
34
-
35
- learning_rate_decay_start: 0
36
- scheduled_sampling_start: -1
37
- save_checkpoint_every: 3000
38
- language_eval: 1
39
- val_images_use: 5000
40
- max_epochs: 15
41
- train_sample_n: 5
42
-
43
- REFORWARD: false
44
-
45
- # _BASE_: transformer.yml
46
- reduce_on_plateau: false
47
- noamopt: false
48
- learning_rate: 0.000005
49
- learning_rate_decay_start: -1
50
-
51
- self_critical_after: 15
52
- max_epochs: 40
53
-
54
- verbose: false
55
- precision: 32
56
-
57
- use_clipscore: true
58
- clipscore_reward_weight: 2.0
 
 
configs/phase2/clipRN50_clips_grammar.yml DELETED
@@ -1,64 +0,0 @@
1
- caption_model: transformer
2
- noamopt: true
3
- noamopt_warmup: 20000
4
- label_smoothing: 0.0
5
- input_json: data/cocotalk.json
6
- input_label_h5: data/cocotalk_label.h5
7
- input_fc_dir: data/cocotalk_clip_RN50_fc
8
- input_att_dir: data/cocotalk_clip_RN50_att
9
- input_clipscore_vis_dir: data/cocotalk_clipscore_vis
10
- seq_per_img: 5
11
- batch_size: 160
12
- learning_rate: 0.0005
13
-
14
- checkpoint_path: save/clipRN50_clips_grammar/clipRN50_clips_grammar
15
-
16
- use_multi_rewards: true
17
- use_grammar: true
18
- use_grammar_baseline: true
19
- # clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt'
20
- clip_load_path: 'retrieval/save/clip_negative_text/clip_negative_text-epoch=12.ckpt'
21
-
22
- # Notice: because I'm too lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
23
- # N=num_layers
24
- # d_model=input_encoding_size
25
- # d_ff=rnn_size
26
-
27
- # will be ignored
28
- num_layers: 6
29
- input_encoding_size: 512
30
- rnn_size: 2048
31
-
32
- # Transformer config
33
- N_enc: 6
34
- N_dec: 6
35
- d_model: 512
36
- d_ff: 2048
37
- num_att_heads: 8
38
- dropout: 0.1
39
-
40
-
41
- learning_rate_decay_start: 0
42
- scheduled_sampling_start: -1
43
- save_checkpoint_every: 3000
44
- language_eval: 1
45
- val_images_use: 5000
46
- max_epochs: 15
47
- train_sample_n: 5
48
-
49
- REFORWARD: false
50
-
51
- # _BASE_: transformer.yml
52
- reduce_on_plateau: false
53
- noamopt: false
54
- learning_rate: 0.000005
55
- learning_rate_decay_start: -1
56
-
57
- self_critical_after: 15
58
- max_epochs: 40
59
-
60
- verbose: false
61
- precision: 32
62
-
63
- use_clipscore: true
64
- clipscore_reward_weight: 2.0
 
 
configs/phase2/transformer.yml DELETED
@@ -1,41 +0,0 @@
1
- caption_model: transformer
2
- noamopt: true
3
- noamopt_warmup: 20000
4
- label_smoothing: 0.0
5
- input_json: data/cocotalk.json
6
- input_label_h5: data/cocotalk_label.h5
7
- input_att_dir: data/cocotalk_att
8
- seq_per_img: 5
9
- batch_size: 10
10
- learning_rate: 0.0005
11
-
12
- checkpoint_path: ./save/trans_rn50_sc
13
-
14
- # Notice: because I'm too lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
15
- # N=num_layers
16
- # d_model=input_encoding_size
17
- # d_ff=rnn_size
18
-
19
- # will be ignored
20
- num_layers: 6
21
- input_encoding_size: 512
22
- rnn_size: 2048
23
-
24
- # Transformer config
25
- N_enc: 6
26
- N_dec: 6
27
- d_model: 512
28
- d_ff: 2048
29
- num_att_heads: 8
30
- dropout: 0.1
31
-
32
-
33
- learning_rate_decay_start: 0
34
- scheduled_sampling_start: -1
35
- save_checkpoint_every: 3000
36
- language_eval: 1
37
- val_images_use: 5000
38
- max_epochs: 15
39
- train_sample_n: 5
40
-
41
- REFORWARD: false
 
 
data/README.md DELETED
@@ -1 +0,0 @@
- directory to store preprocessed files
 
 
retrieval/README.md DELETED
@@ -1,5 +0,0 @@
- # Finetuning CLIP reward model
-
- ```bash
- python train_pl.py --cfg clip_negative_text --id clip_negative_text
- ```
 
 
retrieval/caption_data.py DELETED
@@ -1,500 +0,0 @@
1
- from torch.utils.data import DataLoader, Dataset, Sampler
2
- from pathlib import Path
3
- import json
4
- from multiprocessing import Pool
5
- from tqdm import tqdm
6
- from PIL import Image
7
- import random
8
- import numpy as np
9
- import torch
10
- import torchvision
11
- import torchvision.transforms as T
12
-
13
- from torch.utils.data.distributed import DistributedSampler
14
-
15
- from transformers import T5Tokenizer, BertTokenizer, BertTokenizerFast, CLIPTokenizer
16
-
17
- import text_utils
18
-
19
- project_dir = Path(__file__).parent.resolve()
20
- workspace_dir = project_dir.parent.parent
21
- dataset_dir = workspace_dir.joinpath('datasets/').resolve()
22
- # coco_dir = dataset_dir.joinpath('COCO')
23
- # vg_dir = dataset_dir.joinpath('VG')
24
- coco_img_dir = dataset_dir.joinpath('COCO/images/')
25
- coco_data_dir = project_dir.parent.joinpath('CLIP-ViL/CLIP-ViL-Direct/caption/data/')
26
- # coco_feature_dir = coco_dir.joinpath('features')
27
-
28
-
29
- class COCORetrievalDataset(Dataset):
30
- def __init__(self, split='karpathy_train', rank=-1, topk=-1, verbose=True, args=None, mode='train'):
31
- super().__init__()
32
-
33
- self.topk = topk
34
- self.verbose = verbose
35
- self.args = args
36
- self.rank = rank
37
- self.mode = mode
38
-
39
- # Loading datasets to data
40
- self.source = split
41
- if self.verbose:
42
- print('Data source: ', self.source)
43
-
44
- # if self.args.tokenizer is None:
45
- # self.args.tokenizer = self.args.decoder_backbone
46
-
47
- # if 'bert' in self.args.tokenizer:
48
- # self.tokenizer = BertTokenizerFast.from_pretrained(
49
- # self.args.tokenizer,
50
- # # max_length=self.args.max_text_length,
51
- # # do_lower_case=self.args.do_lower_case
52
- # )
53
- # elif 'clip' in self.args.tokenizer:
54
- # self.tokenizer = CLIPTokenizer.from_pretrained(
55
- # self.args.tokenizer,
56
- # # max_length=self.args.max_text_length,
57
- # # do_lower_case=self.args.do_lower_case
58
- # )
59
-
60
- self.tokenizer = CLIPTokenizer.from_pretrained(
61
- self.args.tokenizer,
62
- # max_length=self.args.max_text_length,
63
- # do_lower_case=self.args.do_lower_case
64
- )
65
-
66
- with open(coco_data_dir.joinpath('cocotalk.json')) as f:
67
- self.vocab = list(json.load(f)['ix_to_word'].values())
68
- popped = self.vocab.pop(-1)
69
- assert popped == 'UNK'
70
- if self.verbose:
71
- print('vocab size: ', len(self.vocab))
72
-
73
-
74
- data_info_path = coco_data_dir.joinpath('dataset_coco.json')
75
- with open(data_info_path) as f:
76
- karpathy_data = json.load(f)
77
-
78
- split_rename = {
79
- 'train': 'train',
80
- 'restval': 'train',
81
- 'val': 'val',
82
- 'test': 'test'
83
- }
84
-
85
- n_images = 0
86
-
87
- data = []
88
- # self.vocab = set()
89
- for datum in karpathy_data['images']:
90
- re_split = split_rename[datum['split']]
91
-
92
- # if re_split == 'train':
93
- # for d in datum['sentences']:
94
- # self.vocab = self.vocab.union(set(d['tokens']))
95
-
96
- if re_split != self.source.split('_')[-1]:
97
- continue
98
-
99
- if re_split == 'train':
100
- # for d in datum['sentences']:
101
- # img_id = datum['filename'].split('.')[0]
102
- # new_datum = {
103
- # 'filename': datum['filename'],
104
- # 'img_id': img_id,
105
- # 'sent': d['raw'].strip(),
106
- # 'targets': [d['raw'].strip() for d in datum['sentences']],
107
- # 'is_train': True,
108
- # 'cocoid': datum['cocoid']
109
- # }
110
- # data.append(new_datum)
111
- img_id = datum['filename'].split('.')[0]
112
- new_datum = {
113
- 'filename': datum['filename'],
114
- 'img_id': img_id,
115
- # 'sent': d['raw'],
116
- # 'targets': [d['raw'].strip() for d in datum['sentences']],
117
- 'targets': [" ".join(d['tokens']) for d in datum['sentences']],
118
- 'is_train': True,
119
- 'cocoid': datum['cocoid']
120
- }
121
- data.append(new_datum)
122
-
123
- else:
124
- img_id = datum['filename'].split('.')[0]
125
- new_datum = {
126
- 'filename': datum['filename'],
127
- 'img_id': img_id,
128
- # 'sent': d['raw'],
129
- # 'targets': [d['raw'].strip() for d in datum['sentences']],
130
- 'targets': [" ".join(d['tokens']) for d in datum['sentences']],
131
- 'is_train': False,
132
- 'cocoid': datum['cocoid']
133
- }
134
- data.append(new_datum)
135
-
136
- n_images += 1
137
-
138
- if self.verbose:
139
- print(f"{self.source} has {n_images} images")
140
- # print(f"Loaded {len(data)} data from", split)
141
-
142
- self.n_gpus = torch.cuda.device_count()
143
-
144
- if self.topk > 0:
145
- data = data[:self.topk]
146
- if self.verbose:
147
- print(f"Use only {self.topk} data")
148
-
149
- self.data = data
150
-
151
- # if self.verbose:
152
- # print("# all sentences:", len(self.data))
153
-
154
- if self.args.load_feat:
155
- # feat_dir = coco_dir.joinpath(''
156
- # self.feat_loader = HybridLoader('/scratch-space/CLIP-ViL/CLIP-ViL-Direct/caption/data/cocotalk_clipscore_vis', ext='.npy', in_memory=False)
157
- self.feat_loader = HybridLoader(
158
- coco_data_dir.joinpath('cocotalk_clipscore_vis'),
159
- ext='.npy', in_memory=False)
160
- else:
161
- if 'openai/clip' in self.args.encoder_backbone:
162
- # from transformers import CLIPProcessor
163
- # self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32",
164
- # size=args.image_size,
165
- # do_resize=True,
166
- # do_center_crop=False,
167
- # )
168
- # self.img_transform = lambda image: self.processor.feature_extractor(
169
- # image,
170
- # return_tensors='pt')['pixel_values'][0]
171
-
172
- self.image_mean = [0.48145466, 0.4578275, 0.40821073]
173
- self.image_std = [0.26862954, 0.26130258, 0.27577711]
174
-
175
- # captioning
176
- # self.img_transform = T.Compose([
177
- # T.Resize((self.args.image_size, self.args.image_size))
178
- # ])
179
-
180
- # retrieval
181
- self.img_transform = T.Compose([
182
- T.Resize(self.args.image_size, interpolation=T.functional.InterpolationMode.BICUBIC),
183
- T.CenterCrop(self.args.image_size)
184
- ])
185
-
186
- self.img_tensor_transform = T.Compose([
187
- # T.RandomCrop(224),
188
- # T.RandomHorizontalFlip(p=0.3),
189
- T.ConvertImageDtype(torch.float),
190
- T.Normalize(self.image_mean, self.image_std)
191
- ]
192
- )
193
- # elif 'google/vit' in self.args.encoder_backbone:
194
- # self.image_mean = [0.5, 0.5, 0.5]
195
- # self.image_std = [0.5, 0.5, 0.5]
196
-
197
- # self.img_transform = T.Compose([
198
- # # T.PILToTensor(),
199
- # T.Resize((self.args.image_size, self.args.image_size))
200
- # ])
201
-
202
- # self.img_tensor_transform = T.Compose([
203
- # # T.RandomCrop(224),
204
- # # T.RandomHorizontalFlip(p=0.3),
205
- # T.ConvertImageDtype(torch.float),
206
- # T.Normalize(self.image_mean, self.image_std)
207
- # ]
208
- # )
209
-
210
- def get_negative_text(self, text):
211
- neg_type = random.choice(['repeat', 'remove', 'insert', 'swap', 'shuffle'])
212
-
213
- if neg_type == 'repeat':
214
- text = text_utils.repeat(text)
215
- elif neg_type == 'remove':
216
- text = text_utils.remove(text)
217
- elif neg_type == 'insert':
218
- text = text_utils.insert(text, self.vocab)
219
- elif neg_type == 'swap':
220
- text = text_utils.swap(text, self.vocab)
221
- elif neg_type == 'shuffle':
222
- text = text_utils.shuffle(text)
223
-
224
- return text, neg_type
225
-
226
- def __len__(self):
227
- return len(self.data)
228
-
229
- def __getitem__(self, idx):
230
- datum = self.data[idx]
231
- return self.process_datum(datum)
232
-
233
- def process_datum(self, datum):
234
- out_dict = {}
235
-
236
- ###### Image ######
237
-
238
- if self.args.load_feat:
239
- cocoid = datum['cocoid']
240
- out_dict['cocoid'] = str(cocoid)
241
- img_feat = self.feat_loader.get(str(cocoid))
242
- out_dict['img_feat'] = torch.from_numpy(img_feat)
243
-
244
- else:
245
- img_id = datum['img_id']
246
- out_dict['img_id'] = img_id
247
-
248
- if 'train' in datum['filename']:
249
- img_split = 'train2014'
250
- elif 'val' in datum['filename']:
251
- img_split = 'val2014'
252
- img_path = coco_img_dir.joinpath(img_split).joinpath(datum['filename']).with_suffix('.jpg')
253
- assert img_path.exists()
254
- img_path = str(img_path)
255
- out_dict['img_path'] = img_path
256
-
257
- img_tensor = torchvision.io.read_image(img_path)
258
- # out_dict['img_tensor'] = img
259
-
260
- # img = Image.open(img_path).convert('RGB')
261
- # img_tensor = torch.as_tensor(np.asarray(img))
262
- out_dict['img_tensor'] = self.img_transform(img_tensor)
263
- # self.img_transform(img_tensor)
264
- # out_dict['img_tensor'] = self.img_transform(img)
265
-
266
- ###### Text #####
267
- # if datum['is_train']:
268
- # sent = datum['sent'].strip()
269
-
270
- sent = random.choice(datum['targets'])
271
-
272
- # target_ids = self.tokenizer.encode(
273
- # sent, max_length=self.args.gen_max_length, truncation=True)
274
-
275
- # assert len(target_ids) <= self.args.gen_max_length, len(target_ids)
276
- out_dict['sent'] = sent
277
- # out_dict['target_ids'] = torch.LongTensor(target_ids)
278
- # out_dict['target_length'] = len(target_ids)
279
-
280
-
281
- # negative sample
282
- neg_sent, neg_type = self.get_negative_text(sent)
283
-
284
- # neg_target_ids = self.tokenizer.encode(
285
- # neg_sent, max_length=self.args.gen_max_length, truncation=True)
286
-
287
- # assert len(neg_target_ids) <= self.args.gen_max_length, len(neg_target_ids)
288
- out_dict['neg_sent'] = neg_sent
289
- out_dict['neg_type'] = neg_type
290
- # out_dict['neg_target_ids'] = torch.LongTensor(neg_target_ids)
291
- # out_dict['neg_target_length'] = len(neg_target_ids)
292
-
293
-
294
- if 'targets' in datum:
295
- out_dict['targets'] = datum['targets']
296
-
297
- return out_dict
298
-
299
- def collate_fn(self, batch):
300
- batch_entry = {}
301
-
302
- B = len(batch)
303
-
304
- # if 'target_ids' in batch[0]:
305
- # T_W_L = max(entry['target_length'] for entry in batch)
306
- # target_ids = torch.ones(
307
- # B, T_W_L, dtype=torch.long) * self.tokenizer.pad_token_id
308
-
309
- # if 'target_ids' in batch[0]:
310
- # T_W_L = max(entry['target_length'] for entry in batch)
311
- # target_ids = torch.ones(
312
- # B, T_W_L, dtype=torch.long) * self.tokenizer.pad_token_id
313
-
314
-
315
-
316
- targets = []
317
- img_ids = []
318
- img_paths = []
319
-
320
- coco_ids = []
321
-
322
- if self.args.load_feat:
323
- img_feats = torch.zeros(B, 512, dtype=torch.float)
324
- else:
325
- # imgs = []
326
- img_tensor = torch.zeros(B, 3, self.args.image_size, self.args.image_size, dtype=torch.uint8)
327
-
328
- for i, entry in enumerate(batch):
329
-
330
- if self.args.load_feat:
331
- coco_ids.append(entry['cocoid'])
332
- img_feats[i] = entry['img_feat']
333
-
334
- else:
335
-
336
- img_ids.append(entry['img_id'])
337
- img_paths.append(entry['img_path'])
338
- img_tensor[i] = entry['img_tensor']
339
-
340
- # if 'target_ids' in entry:
341
- # target_ids[i, :entry['target_length']] = entry['target_ids']
342
-
343
- if 'targets' in entry:
344
- targets.append(entry['targets'])
345
-
346
- if 'sent' in batch[0]:
347
- # word_mask = target_ids != self.tokenizer.pad_token_id
348
- # target_ids[~word_mask] = -100
349
- # batch_entry['target_ids'] = target_ids
350
-
351
- tokenized = self.tokenizer([entry['sent'] for entry in batch], truncation=True, padding=True, return_tensors='pt')
352
- neg_tokenized = self.tokenizer([entry['neg_sent'] for entry in batch], truncation=True, padding=True, return_tensors='pt')
353
- # sent, max_length=self.args.gen_max_length, truncation=True)
354
-
355
- batch_entry['text'] = (tokenized.input_ids, tokenized.attention_mask)
356
- batch_entry['neg_text'] = (neg_tokenized.input_ids, neg_tokenized.attention_mask)
357
-
358
-
359
- if self.args.load_feat:
360
- batch_entry['coco_ids'] = coco_ids
361
- batch_entry['img_feats'] = img_feats
362
-
363
- else:
364
-
365
- img_tensor = self.img_tensor_transform(img_tensor)
366
-
367
- batch_entry['img_id'] = img_ids
368
- batch_entry['img_paths'] = img_paths
369
- batch_entry['img_tensor'] = img_tensor
370
-
371
- batch_entry['targets'] = targets
372
-
373
- # print('batch created')
374
-
375
- # batch_entry['task'] = 'caption'
376
-
377
- return batch_entry
378
-
379
-
380
- # def get_loader(args, split='karpathy_train', mode='train',
381
- # batch_size=32, workers=4, distributed=False, gpu=0,
382
- # topk=-1):
383
-
384
- # verbose = (gpu == 0)
385
-
386
- # dataset = COCORetrievalDataset(
387
- # split,
388
- # rank=gpu,
389
- # topk=topk,
390
- # verbose=verbose,
391
- # args=args,
392
- # mode=mode)
393
-
394
- # # if distributed:
395
- # # sampler = DistributedSampler(dataset)
396
- # # else:
397
- # # sampler = None
398
-
399
- # if mode == 'train':
400
- # loader = DataLoader(
401
- # dataset, batch_size=batch_size, shuffle=(sampler is None),
402
- # num_workers=workers, pin_memory=True, sampler=sampler,
403
- # collate_fn=dataset.collate_fn)
404
- # else:
405
- # loader = DataLoader(
406
- # dataset,
407
- # batch_size=batch_size, shuffle=False,
408
- # num_workers=workers, pin_memory=True,
409
- # sampler=sampler,
410
- # collate_fn=dataset.collate_fn,
411
- # drop_last=False)
412
-
413
- # # if verbose:
414
- # # loader.evaluator = COCOCaptionEvaluator()
415
-
416
- # # loader.task = 'caption'
417
-
418
- # return loader
419
-
420
-
421
- # class COCOCaptionEvaluator:
422
- # def __init__(self):
423
- # import language_evaluation
424
- # self.evaluator = language_evaluation.CocoEvaluator(verbose=False)
425
-
426
- # def evaluate(self, predicts, answers):
427
-
428
- # results = self.evaluator.run_evaluation(predicts, answers)
429
-
430
- # return results
431
-
432
- import six
433
- import os
434
- import h5py
435
-
436
- class HybridLoader:
437
- """
438
- If db_path is a directory, then use normal file loading
439
- If lmdb, then load from lmdb
440
- The loading method depends on the extension.
441
-
442
- in_memory: if in_memory is True, we save all the features in memory
443
- For individual np(y|z)s, we don't need to do that because the system will do this for us.
444
- Should be useful for lmdb or h5.
445
- (Copied this idea from vilbert)
446
- """
447
-
448
- def __init__(self, db_path, ext='.npy', in_memory=False):
449
- self.db_path = db_path
450
- self.ext = ext
451
- if self.ext == '.npy':
452
- self.loader = lambda x: np.load(six.BytesIO(x))
453
- else:
454
- self.loader = lambda x: np.load(six.BytesIO(x))['feat']
455
- # if db_path.endswith('.lmdb'):
456
- # self.db_type = 'lmdb'
457
- # self.lmdb = lmdbdict(db_path, unsafe=True)
458
- # self.lmdb._key_dumps = DUMPS_FUNC['ascii']
459
- # self.lmdb._value_loads = LOADS_FUNC['identity']
460
- # elif db_path.endswith('.pth'): # Assume a key,value dictionary
461
- # self.db_type = 'pth'
462
- # self.feat_file = torch.load(db_path)
463
- # self.loader = lambda x: x
464
- # print('HybridLoader: ext is ignored')
465
- # elif db_path.endswith('h5'):
466
- # self.db_type = 'h5'
467
- # self.loader = lambda x: np.array(x).astype('float32')
468
- # else:
469
- # self.db_type = 'dir'
470
-
471
- self.in_memory = in_memory
472
- if self.in_memory:
473
- self.features = {}
474
-
475
- def get(self, key):
476
-
477
- # if self.in_memory and key in self.features:
478
- # # We save f_input because we want to save the
479
- # # compressed bytes to save memory
480
- # f_input = self.features[key]
481
- # elif self.db_type == 'lmdb':
482
- # f_input = self.lmdb[key]
483
- # elif self.db_type == 'pth':
484
- # f_input = self.feat_file[key]
485
- # elif self.db_type == 'h5':
486
- # f_input = h5py.File(self.db_path, 'r')[key]
487
- # else:
488
- # f_input = open(os.path.join(
489
- # self.db_path, key + self.ext), 'rb').read()
490
-
491
- f_input = open(os.path.join(
492
- self.db_path, key + self.ext), 'rb').read()
493
-
494
- if self.in_memory and key not in self.features:
495
- self.features[key] = f_input
496
-
497
- # load image
498
- feat = self.loader(f_input)
499
-
500
- return feat
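
Editor's note: the deleted COCORetrievalDataset above pairs each COCO image (or its precomputed CLIP visual feature) with one reference caption and one automatically corrupted negative caption. Below is a minimal usage sketch, not code from the repository; it assumes the COCO annotation and feature files referenced above exist on disk, and the hypothetical `Args` stand-in only carries the attributes the class actually reads.

# --- editor's sketch (not part of the diff) ---------------------------------
import torch
from torch.utils.data import DataLoader

class Args:                                   # hypothetical minimal config object
    load_feat = True                          # read precomputed CLIP features (.npy)
    tokenizer = 'openai/clip-vit-base-patch32'
    image_size = 224
    encoder_backbone = 'openai/clip-vit-base-patch32'

dataset = COCORetrievalDataset(split='karpathy_val', args=Args(), mode='val', verbose=True)
loader = DataLoader(dataset, batch_size=8, shuffle=False, collate_fn=dataset.collate_fn)

batch = next(iter(loader))
input_ids, attention_mask = batch['text']     # tokenized positive captions
neg_ids, neg_mask = batch['neg_text']         # tokenized corrupted captions
print(batch['img_feats'].shape)               # (8, 512) when load_feat is True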
 
retrieval/clip_model.py DELETED
@@ -1,350 +0,0 @@
1
- from transformers import CLIPModel, CLIPTokenizer
2
- import os
3
- import json
4
- import argparse
5
- from random import shuffle, seed
6
- import string
7
- # non-standard dependencies:
8
- import h5py
9
- from six.moves import cPickle
10
- import numpy as np
11
- import torch
12
- import torchvision.models as models
13
- import skimage.io
14
-
15
- from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
16
- from PIL import Image
17
- from torch import nn
18
-
19
-
20
- class CLIPScore(nn.Module):
21
- def __init__(self, clipscore_w=2.5, image_size=224, mode='clip_s', use_grammar=False, joint_out=False):
22
- super(CLIPScore, self).__init__()
23
- # from transformers import CLIPModel, CLIPTokenizer
24
- self.clip_model = CLIPModel.from_pretrained(
25
- 'openai/clip-vit-base-patch32')
26
- self.tokenizer = CLIPTokenizer.from_pretrained(
27
- 'openai/clip-vit-base-patch32')
28
-
29
- self.clip_model.eval()
30
-
31
- self.clipscore_w = clipscore_w
32
-
33
- self.image_transform = self._transform(image_size)
34
-
35
- self.mode = mode
36
- assert mode in ['clip_s', 'refclip_s']
37
-
38
- self.use_grammar = use_grammar
39
- self.joint_out = joint_out
40
-
41
- if self.use_grammar and self.joint_out is False:
42
- self.grammar_score_head = nn.Sequential(
43
- nn.Linear(self.clip_model.text_embed_dim, self.clip_model.projection_dim, bias=False),
44
- nn.ReLU(),
45
- nn.Linear(self.clip_model.projection_dim, 2, bias=False)
46
- )
47
-
48
- def _transform(self, n_px):
49
- return Compose([
50
- Resize(n_px, interpolation=Image.BICUBIC),
51
- CenterCrop(n_px),
52
- lambda image: image.convert("RGB"),
53
- ToTensor(),
54
- Normalize((0.48145466, 0.4578275, 0.40821073),
55
- (0.26862954, 0.26130258, 0.27577711)),
56
- ])
57
-
58
- def load_image(self, image_path):
59
- image = Image.open(image_path)
60
- return image
61
-
62
- # @torch.no_grad()
63
- def image_extract(self, image):
64
- if isinstance(image, str):
65
- image = self.load_image(image)
66
- if not isinstance(image, torch.Tensor):
67
- image = self.image_transform(image)
68
-
69
- img_tensor = image.view(-1, 3, 224, 224)
70
- device = next(self.clip_model.parameters()).device
71
- img_tensor = img_tensor.to(device)
72
-
73
- clip_model = self.clip_model
74
-
75
- img_feat = clip_model.vision_model(img_tensor).pooler_output
76
- img_feat = clip_model.visual_projection(img_feat)
77
- img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
78
-
79
- return img_feat
80
-
81
- # @torch.no_grad()
82
- def text_extract(self, text, prompt="A photo depicts", proj_norm=True):
83
- if isinstance(text, str):
84
- text_batch = [" ".join([prompt, text])]
85
- elif isinstance(text, list):
86
- text_batch = [" ".join([prompt, txt]) for txt in text]
87
-
88
- if isinstance(text, tuple) and isinstance(text[0], torch.Tensor):
89
- input_ids, attention_mask = text
90
- else:
91
- input_text = text_batch
92
-
93
- tokenized = self.tokenizer(
94
- input_text, return_tensors='pt', padding=True)
95
-
96
- input_ids = tokenized.input_ids
97
- attention_mask = tokenized.attention_mask
98
-
99
- clip_model = self.clip_model
100
- device = next(self.clip_model.parameters()).device
101
- input_ids = input_ids.to(device)
102
- attention_mask = attention_mask.to(device)
103
-
104
- text_feat = clip_model.text_model(input_ids, attention_mask).pooler_output
105
-
106
- if proj_norm:
107
- text_feat = clip_model.text_projection(text_feat)
108
- text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True)
109
-
110
- return text_feat
111
-
112
- # @torch.no_grad()
113
- def calc_clip_s(self, img_feat, text_feat):
114
- return self.clipscore_w * torch.relu((img_feat * text_feat).sum(dim=-1))
115
-
116
- # @torch.no_grad()
117
- def calc_refclip_s(self, img_feat=None, text_feat=None, ref_text_feat=None, ref_text_mask=None, clip_s=None):
118
-
119
- if clip_s is None:
120
- clip_s = self.calc_clip_s(img_feat, text_feat)
121
-
122
- B, dim = img_feat.size()
123
-
124
- ref_text_feat = ref_text_feat.view(B, -1, dim)
125
-
126
- K = ref_text_feat.size(1)
127
-
128
- text_feat = text_feat.view(B, 1, dim).expand(-1, K, -1)
129
- assert ref_text_feat.size() == text_feat.size(
130
- ), (ref_text_feat.size(), text_feat.size())
131
-
132
- ref_score = self.calc_clip_s(text_feat, ref_text_feat)
133
- if ref_text_mask is not None:
134
- if not isinstance(ref_text_mask, torch.Tensor):
135
- ref_text_mask = torch.tensor(
136
- ref_text_mask, dtype=ref_score.dtype, device=ref_score.device)
137
- ref_score = ref_score.view(B, K) * ref_text_mask.view(B, K)
138
-
139
- ref_score = ref_score.view(B, K).max(dim=1).values
140
-
141
- assert clip_s.size() == (B,)
142
- assert clip_s.size() == ref_score.size()
143
-
144
- # harmonic mean
145
- refclip_s = 2 / (1 / clip_s + 1 / ref_score)
146
- return refclip_s
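
Editor's note: calc_refclip_s scores a candidate caption by the harmonic mean of its CLIP-S (image-text similarity) and its best similarity against the reference captions. A self-contained sketch of that final combination with dummy unit-norm features (shapes and the clamp are illustrative only; real CLIP features give positive similarities):

# --- editor's sketch (not part of the diff) ---------------------------------
import torch
import torch.nn.functional as F

B, K, d, w = 4, 5, 512, 2.5                          # batch, #refs, feat dim, clipscore_w
img  = F.normalize(torch.randn(B, d), dim=-1)
txt  = F.normalize(torch.randn(B, d), dim=-1)
refs = F.normalize(torch.randn(B, K, d), dim=-1)

clip_s    = (w * torch.relu((img * txt).sum(-1))).clamp(min=1e-8)
ref_score = (w * torch.relu((txt.unsqueeze(1) * refs).sum(-1))).max(dim=1).values.clamp(min=1e-8)
refclip_s = 2 / (1 / clip_s + 1 / ref_score)         # harmonic mean, as in calc_refclip_s
print(refclip_s.shape)                               # torch.Size([4])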
147
-
148
- # # @torch.no_grad()
149
- # def forward(self,
150
- # images=None, text=None,
151
- # img_feat=None, text_feat=None,
152
- # ref_text=None, ref_text_feat=None, ref_text_mask=None,
153
- # prompt="A photo depicts",
154
- # mode=None):
155
- # if img_feat is None:
156
- # img_feat = self.image_extract(images)
157
- # img_feat = img_feat.view(-1, 512)
158
-
159
- # if text_feat is None:
160
- # text_feat = self.text_extract(text, prompt=prompt)
161
- # text_feat = text_feat.view(-1, 512)
162
-
163
- # if mode is None:
164
- # mode = self.mode
165
- # assert mode in ['clip_s', 'refclip_s']
166
-
167
- # if mode == 'clip_s':
168
- # clip_s = self.calc_clip_s(img_feat, text_feat)
169
- # return clip_s
170
- # elif mode == 'refclip_s':
171
- # if ref_text_feat is None:
172
- # ref_text_feat = self.text_extract(ref_text, prompt=prompt)
173
- # ref_text_feat = ref_text_feat.view(-1, 512)
174
-
175
- # refclip_s = self.calc_refclip_s(
176
- # img_feat, text_feat, ref_text_feat, ref_text_mask=ref_text_mask)
177
- # return refclip_s
178
-
179
-
180
- def train_step(self,
181
- images=None, text=None,
182
- img_feat=None, text_feat=None,
183
- neg_text=None, neg_text_feat=None,
184
- # ref_text=None, ref_text_feat=None, ref_text_mask=None,
185
- prompt="A photo depicts",
186
- # return_loss=True,
187
- **kwargs):
188
-
189
- if img_feat is None:
190
- img_feat = self.image_extract(images)
191
- img_feat = img_feat.view(-1, 512)
192
-
193
- B = img_feat.size(0)
194
-
195
- if self.joint_out:
196
- pos_text_feat = self.text_extract(text, prompt=prompt, proj_norm=False).view(B, 512)
197
- neg_text_feat = self.text_extract(neg_text, prompt=prompt, proj_norm=False).view(-1, 512)
198
- neg_B = neg_text_feat.size(0)
199
-
200
- # [B+neg_B, 512]
201
- text_feat = torch.cat([pos_text_feat, neg_text_feat], dim=0)
202
-
203
- text_cont_feat = self.clip_model.text_projection(text_feat)
204
- text_cont_feat = text_cont_feat / text_cont_feat.norm(dim=-1, keepdim=True)
205
-
206
- text_cont_feat = text_cont_feat.view(B+neg_B, 512)
207
-
208
- logit_scale = self.clip_model.logit_scale.exp()
209
-
210
- # [B+neg_B * B]
211
- logits_per_text = torch.matmul(text_cont_feat, img_feat.t()) * logit_scale
212
-
213
- # image-to-text label: positive text
214
- caption_loss = -torch.diag(nn.functional.log_softmax(logits_per_text, dim=0)[:B]).mean()
215
-
216
- # calculate text-to-image only on positive text
217
- image_loss = -torch.diag(nn.functional.log_softmax(logits_per_text[:B], dim=1)).mean()
218
-
219
- clip_loss = (caption_loss + image_loss) / 2.0
220
-
221
- out = {
222
- 'clip_loss': clip_loss,
223
- 'img_feat': img_feat,
224
- 'text_feat': text_cont_feat[:B].detach(),
225
- # 'neg_text_feat': neg_text_feat,
226
- }
227
-
228
- return out
229
-
230
-
231
- else:
232
- if text_feat is None:
233
- text_feat = self.text_extract(text, prompt=prompt, proj_norm=False)
234
-
235
- text_cont_feat = self.clip_model.text_projection(text_feat)
236
- text_cont_feat = text_cont_feat / \
237
- text_cont_feat.norm(dim=-1, keepdim=True)
238
-
239
- text_cont_feat = text_cont_feat.view(B, 512)
240
-
241
-
242
- # cosine similarity as logits
243
- logit_scale = self.clip_model.logit_scale.exp()
244
- logits_per_text = torch.matmul(text_cont_feat, img_feat.t()) * logit_scale
245
- # logits_per_image = logits_per_text.T
246
-
247
- clip_loss = clip_loss_fn(logits_per_text)
248
-
249
-
250
- # negative sampling
251
- pos_text_feat = text_feat.view(B, 512)
252
- neg_text_feat = self.text_extract(neg_text, prompt=prompt, proj_norm=False).view(B, 512)
253
-
254
- grammar_text_feat = torch.cat([pos_text_feat, neg_text_feat], dim=0)
255
-
256
- # 2B, 1
257
- grammar_text_logit = self.grammar_score_head(grammar_text_feat)
258
- grammar_labels = torch.LongTensor([1] * B + [0] * B).to(grammar_text_logit.device).view(2 * B)
259
-
260
- grammar_loss = torch.nn.functional.cross_entropy(grammar_text_logit, grammar_labels)
261
-
262
- grammar_pred = grammar_text_logit.argmax(dim=1, keepdim=False)
263
- grammar_pos_pred = grammar_pred[:B]
264
- grammar_neg_pred = grammar_pred[B:]
265
- # grammar_acc = (grammar_pred == grammar_labels).float().mean()
266
-
267
- out = {
268
- 'clip_loss': clip_loss,
269
- 'grammar_loss': grammar_loss,
270
- 'img_feat': img_feat,
271
- 'text_feat': text_cont_feat,
272
- 'neg_text_feat': neg_text_feat,
273
- 'grammar_pos_pred': grammar_pos_pred,
274
- 'grammar_neg_pred': grammar_neg_pred,
275
- }
276
-
277
- return out
278
-
279
- def train_step_old(self,
280
- images=None, text=None,
281
- img_feat=None, text_feat=None,
282
- neg_text=None, neg_text_feat=None,
283
- # ref_text=None, ref_text_feat=None, ref_text_mask=None,
284
- prompt="A photo depicts",
285
- # return_loss=True,
286
- **kwargs):
287
-
288
- if img_feat is None:
289
- img_feat = self.image_extract(images)
290
- img_feat = img_feat.view(-1, 512)
291
-
292
- B = img_feat.size(0)
293
-
294
-
295
-
296
- if text_feat is None:
297
- text_feat = self.text_extract(text, prompt=prompt, proj_norm=False)
298
-
299
- text_cont_feat = self.clip_model.text_projection(text_feat)
300
- text_cont_feat = text_cont_feat / text_cont_feat.norm(dim=-1, keepdim=True)
301
- text_cont_feat = text_cont_feat.view(B, 512)
302
-
303
- # cosine similarity as logits
304
- logit_scale = self.clip_model.logit_scale.exp()
305
- logits_per_text = torch.matmul(text_cont_feat, img_feat.t()) * logit_scale
306
- # logits_per_image = logits_per_text.T
307
-
308
- clip_loss = clip_loss_fn(logits_per_text)
309
-
310
-
311
- # negative sampling
312
- pos_text_feat = text_feat.view(B, 512)
313
- neg_text_feat = self.text_extract(neg_text, prompt=prompt, proj_norm=False).view(B, 512)
314
-
315
- grammar_text_feat = torch.cat([pos_text_feat, neg_text_feat], dim=0)
316
-
317
- # 2B, 1
318
- grammar_text_logit = self.grammar_score_head(grammar_text_feat)
319
- grammar_labels = torch.LongTensor([1] * B + [0] * B).to(grammar_text_logit.device).view(2 * B)
320
-
321
- grammar_loss = torch.nn.functional.cross_entropy(grammar_text_logit, grammar_labels)
322
-
323
- grammar_pred = grammar_text_logit.argmax(dim=1, keepdim=False)
324
- grammar_pos_pred = grammar_pred[:B]
325
- grammar_neg_pred = grammar_pred[B:]
326
- # grammar_acc = (grammar_pred == grammar_labels).float().mean()
327
-
328
- out = {
329
- 'clip_loss': clip_loss,
330
- 'grammar_loss': grammar_loss,
331
- 'img_feat': img_feat,
332
- 'text_feat': text_cont_feat,
333
- 'neg_text_feat': neg_text_feat,
334
- 'grammar_pos_pred': grammar_pos_pred,
335
- 'grammar_neg_pred': grammar_neg_pred,
336
- }
337
-
338
- return out
339
-
340
- # contrastive loss function, adapted from
341
- # https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
342
- def contrastive_loss(logits: torch.Tensor, dim: int) -> torch.Tensor:
343
- neg_ce = torch.diag(nn.functional.log_softmax(logits, dim=dim))
344
- return -neg_ce.mean()
345
-
346
-
347
- def clip_loss_fn(similarity: torch.Tensor) -> torch.Tensor:
348
- caption_loss = contrastive_loss(similarity, dim=0)
349
- image_loss = contrastive_loss(similarity, dim=1)
350
- return (caption_loss + image_loss) / 2.0
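
Editor's note: clip_loss_fn above is the standard symmetric contrastive (InfoNCE) objective: cross-entropy over the text-to-image similarity matrix in both directions, averaged. A quick sanity check, assuming the two functions above are in scope:

# --- editor's sketch (not part of the diff) ---------------------------------
import torch
from torch import nn

B, d = 8, 512
text  = nn.functional.normalize(torch.randn(B, d), dim=-1)
image = nn.functional.normalize(torch.randn(B, d), dim=-1)

logits_per_text = text @ image.t() * 100.0    # diagonal entries are the matched pairs
print(clip_loss_fn(logits_per_text))          # random features -> large loss
print(clip_loss_fn(torch.eye(B) * 100.0))     # perfectly aligned batch -> near zero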
 
retrieval/configs/clip_negative_text.yaml DELETED
@@ -1,14 +0,0 @@
1
- checkpoint_dir: ./save/clip_negative_text/
2
-
3
- losses_log_every: 25
4
- precision: 32
5
- load_feat: true
6
- data_in_memory: false
7
-
8
- batch_size: 1600
9
- valid_batch_size: 200
10
- clip_grad_norm: 0
11
-
12
- epochs: 30
13
- use_grammar: true
14
- joint_out: false
 
retrieval/param.py DELETED
@@ -1,209 +0,0 @@
1
- import argparse
2
- import random
3
-
4
- import numpy as np
5
- import torch
6
-
7
- import pprint
8
- import yaml
9
-
10
-
11
- def str2bool(v):
12
- if v.lower() in ('yes', 'true', 't', 'y', '1'):
13
- return True
14
- elif v.lower() in ('no', 'false', 'f', 'n', '0'):
15
- return False
16
- else:
17
- raise argparse.ArgumentTypeError('Boolean value expected.')
18
-
19
-
20
- def is_interactive():
21
- import __main__ as main
22
- return not hasattr(main, '__file__')
23
-
24
-
25
- def get_optimizer(optim, verbose=False):
26
- # Bind the optimizer
27
- if optim == 'rms':
28
- if verbose:
29
- print("Optimizer: Using RMSProp")
30
- optimizer = torch.optim.RMSprop
31
- elif optim == 'adam':
32
- if verbose:
33
- print("Optimizer: Using Adam")
34
- optimizer = torch.optim.Adam
35
- elif optim == 'adamw':
36
- if verbose:
37
- print("Optimizer: Using AdamW")
38
- # optimizer = torch.optim.AdamW
39
- optimizer = 'adamw'
40
- elif optim == 'adamax':
41
- if verbose:
42
- print("Optimizer: Using Adamax")
43
- optimizer = torch.optim.Adamax
44
- elif optim == 'sgd':
45
- if verbose:
46
- print("Optimizer: SGD")
47
- optimizer = torch.optim.SGD
48
- else:
49
- assert False, "Please add your optimizer %s in the list." % optim
50
-
51
- return optimizer
52
-
53
-
54
- def parse_args(parse=True, **optional_kwargs):
55
- parser = argparse.ArgumentParser()
56
-
57
- parser.add_argument('--seed', type=int, default=9595, help='random seed')
58
-
59
- # Data Splits
60
- parser.add_argument("--train", default='karpathy_train')
61
- parser.add_argument("--valid", default='karpathy_val')
62
- parser.add_argument("--test", default='karpathy_test')
63
- # parser.add_argument('--test_only', action='store_true')
64
-
65
- # Quick experiments
66
- parser.add_argument('--train_topk', type=int, default=-1)
67
- parser.add_argument('--valid_topk', type=int, default=-1)
68
-
69
- # Checkpoint
70
- parser.add_argument('--output', type=str, default='snap/test')
71
- parser.add_argument('--load', type=str, default=None, help='Load the model (usually the fine-tuned model).')
72
- parser.add_argument('--from_scratch', action='store_true')
73
-
74
- # CPU/GPU
75
- parser.add_argument("--multiGPU", action='store_const', default=False, const=True)
76
- parser.add_argument('--fp16', action='store_true')
77
- parser.add_argument("--distributed", action='store_true')
78
- parser.add_argument("--num_workers", default=0, type=int)
79
- parser.add_argument('--local_rank', type=int, default=-1)
80
- # parser.add_argument('--rank', type=int, default=-1)
81
-
82
- # Model Config
83
- # parser.add_argument('--encoder_backbone', type=str, default='openai/clip-vit-base-patch32')
84
- # parser.add_argument('--decoder_backbone', type=str, default='bert-base-uncased')
85
- parser.add_argument('--tokenizer', type=str, default='openai/clip-vit-base-patch32')
86
-
87
- # parser.add_argument('--position_embedding_type', type=str, default='absolute')
88
-
89
- # parser.add_argument('--encoder_transform', action='store_true')
90
-
91
- parser.add_argument('--max_text_length', type=int, default=40)
92
-
93
- # parser.add_argument('--image_size', type=int, default=224)
94
- # parser.add_argument('--patch_size', type=int, default=32)
95
-
96
- # parser.add_argument('--decoder_num_layers', type=int, default=12)
97
-
98
- # Training
99
- parser.add_argument('--batch_size', type=int, default=256)
100
- parser.add_argument('--valid_batch_size', type=int, default=None)
101
-
102
- parser.add_argument('--optim', default='adamw')
103
-
104
- parser.add_argument('--warmup_ratio', type=float, default=0.05)
105
- parser.add_argument('--weight_decay', type=float, default=0.01)
106
- parser.add_argument('--clip_grad_norm', type=float, default=-1.0)
107
- parser.add_argument('--gradient_accumulation_steps', type=int, default=1)
108
- parser.add_argument('--lr', type=float, default=1e-4)
109
- parser.add_argument('--adam_eps', type=float, default=1e-6)
110
- parser.add_argument('--adam_beta1', type=float, default=0.9)
111
- parser.add_argument('--adam_beta2', type=float, default=0.999)
112
-
113
- parser.add_argument('--epochs', type=int, default=20)
114
- # parser.add_argument('--dropout', type=float, default=0.1)
115
-
116
-
117
- # Inference
118
- # parser.add_argument('--num_beams', type=int, default=1)
119
- # parser.add_argument('--gen_max_length', type=int, default=20)
120
-
121
- parser.add_argument('--start_from', type=str, default=None)
122
-
123
- # Data
124
- # parser.add_argument('--do_lower_case', type=str2bool, default=None)
125
-
126
- # parser.add_argument('--prefix', type=str, default=None)
127
-
128
-
129
- # COCO Caption
130
- # parser.add_argument('--no_prefix', action='store_true')
131
-
132
- parser.add_argument('--no_cls', action='store_true')
133
-
134
- parser.add_argument('--cfg', type=str, default=None)
135
- parser.add_argument('--id', type=str, default=None)
136
-
137
- # Etc.
138
- parser.add_argument('--comment', type=str, default='')
139
- parser.add_argument("--dry", action='store_true')
140
-
141
- # Parse the arguments.
142
- if parse:
143
- args = parser.parse_args()
144
- # For interactive environments (e.g., Jupyter)
145
- else:
146
- args = parser.parse_known_args()[0]
147
-
148
- loaded_kwargs = {}
149
- if args.cfg is not None:
150
- cfg_path = f'configs/{args.cfg}.yaml'
151
- with open(cfg_path, 'r') as f:
152
- loaded_kwargs = yaml.safe_load(f)
153
-
154
- # Namespace => Dictionary
155
- parsed_kwargs = vars(args)
156
- parsed_kwargs.update(optional_kwargs)
157
-
158
- kwargs = {}
159
- kwargs.update(parsed_kwargs)
160
- kwargs.update(loaded_kwargs)
161
-
162
- args = Config(**kwargs)
163
-
164
- # Bind optimizer class.
165
- verbose = False
166
- args.optimizer = get_optimizer(args.optim, verbose=verbose)
167
-
168
- # Set seeds
169
- torch.manual_seed(args.seed)
170
- random.seed(args.seed)
171
- np.random.seed(args.seed)
172
-
173
- return args
174
-
175
-
176
- class Config(object):
177
- def __init__(self, **kwargs):
178
- """Configuration Class: set kwargs as class attributes with setattr"""
179
- for k, v in kwargs.items():
180
- setattr(self, k, v)
181
-
182
- @property
183
- def config_str(self):
184
- return pprint.pformat(self.__dict__)
185
-
186
- def __repr__(self):
187
- """Pretty-print configurations in alphabetical order"""
188
- config_str = 'Configurations\n'
189
- config_str += self.config_str
190
- return config_str
191
-
192
- # def update(self, **kwargs):
193
- # for k, v in kwargs.items():
194
- # setattr(self, k, v)
195
-
196
- # def save(self, path):
197
- # with open(path, 'w') as f:
198
- # yaml.dump(self.__dict__, f, default_flow_style=False)
199
-
200
- # @classmethod
201
- # def load(cls, path):
202
- # with open(path, 'r') as f:
203
- # kwargs = yaml.load(f)
204
-
205
- # return Config(**kwargs)
206
-
207
-
208
- if __name__ == '__main__':
209
- args = parse_args(True)
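
Editor's note: in parse_args, the YAML file named by --cfg is merged after the CLI namespace, so keys in configs/<cfg>.yaml (such as clip_negative_text.yaml above) override the argparse defaults. A small sketch of that precedence, with an inline YAML string standing in for the file and the Config class above assumed in scope:

# --- editor's sketch (not part of the diff) ---------------------------------
import yaml

parsed_kwargs = {'batch_size': 256, 'use_grammar': False, 'epochs': 20}   # argparse defaults
loaded_kwargs = yaml.safe_load("""
batch_size: 1600
use_grammar: true
joint_out: false
""")

kwargs = {}
kwargs.update(parsed_kwargs)      # command-line values first ...
kwargs.update(loaded_kwargs)      # ... then the YAML config wins, as in parse_args
args = Config(**kwargs)
print(args.batch_size, args.use_grammar)      # 1600 True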
 
retrieval/pth_loader.py DELETED
@@ -1,334 +0,0 @@
1
- from __future__ import absolute_import
2
- from __future__ import division
3
- from __future__ import print_function
4
-
5
- import json
6
- import h5py
7
- from lmdbdict import lmdbdict
8
- from lmdbdict.methods import DUMPS_FUNC, LOADS_FUNC
9
- import os
10
- import numpy as np
11
- import numpy.random as npr
12
- import random
13
-
14
- import torch
15
- import torch.utils.data as data
16
-
17
- import multiprocessing
18
- import six
19
-
20
- verbose = True
21
- # import torch
22
- # if torch.cuda.current_device() in [0, -1]:
23
- if 'LOCAL_RANK' in os.environ and os.environ['LOCAL_RANK'] != '0':
24
- verbose = False
25
-
26
- class HybridLoader:
27
- """
28
- If db_path is a directory, then use normal file loading
29
- If lmdb, then load from lmdb
30
- The loading method depends on the extension.
31
-
32
- in_memory: if in_memory is True, we save all the features in memory
33
- For individual np(y|z)s, we don't need to do that because the system will do this for us.
34
- Should be useful for lmdb or h5.
35
- (Copied this idea from vilbert)
36
- """
37
- def __init__(self, db_path, ext, in_memory=False):
38
- self.db_path = db_path
39
- self.ext = ext
40
- if self.ext == '.npy':
41
- self.loader = lambda x: np.load(six.BytesIO(x))
42
- else:
43
- self.loader = lambda x: np.load(six.BytesIO(x))['feat']
44
- if db_path.endswith('.lmdb'):
45
- self.db_type = 'lmdb'
46
- self.lmdb = lmdbdict(db_path, unsafe=True)
47
- self.lmdb._key_dumps = DUMPS_FUNC['ascii']
48
- self.lmdb._value_loads = LOADS_FUNC['identity']
49
- elif db_path.endswith('.pth'): # Assume a key,value dictionary
50
- self.db_type = 'pth'
51
- self.feat_file = torch.load(db_path)
52
- self.loader = lambda x: x
53
- print('HybridLoader: ext is ignored')
54
- elif db_path.endswith('h5'):
55
- self.db_type = 'h5'
56
- self.loader = lambda x: np.array(x).astype('float32')
57
- else:
58
- self.db_type = 'dir'
59
-
60
- self.in_memory = in_memory
61
- if self.in_memory:
62
- self.features = {}
63
-
64
- def get(self, key):
65
-
66
- if self.in_memory and key in self.features:
67
- # We save f_input because we want to save the
68
- # compressed bytes to save memory
69
- f_input = self.features[key]
70
- elif self.db_type == 'lmdb':
71
- f_input = self.lmdb[key]
72
- elif self.db_type == 'pth':
73
- f_input = self.feat_file[key]
74
- elif self.db_type == 'h5':
75
- f_input = h5py.File(self.db_path, 'r')[key]
76
- else:
77
- f_input = open(os.path.join(self.db_path, key + self.ext), 'rb').read()
78
-
79
- if self.in_memory and key not in self.features:
80
- self.features[key] = f_input
81
-
82
- # load image
83
- feat = self.loader(f_input)
84
-
85
- return feat
86
-
87
- class CaptionDataset(data.Dataset):
88
-
89
- def get_vocab_size(self):
90
- return self.vocab_size
91
-
92
- def get_vocab(self):
93
- return self.ix_to_word
94
-
95
- def get_seq_length(self):
96
- return self.seq_length
97
-
98
- def __init__(self, opt):
99
- self.opt = opt
100
- self.seq_per_img = opt.seq_per_img
101
-
102
- # feature related options
103
- self.use_fc = getattr(opt, 'use_fc', True)
104
- self.use_att = getattr(opt, 'use_att', True)
105
- self.use_box = getattr(opt, 'use_box', 0)
106
- self.norm_att_feat = getattr(opt, 'norm_att_feat', 0)
107
- self.norm_box_feat = getattr(opt, 'norm_box_feat', 0)
108
-
109
- # load the json file which contains additional information about the dataset
110
- if verbose:
111
- print('DataLoader loading json file: ', opt.input_json)
112
- self.info = json.load(open(self.opt.input_json))
113
- if 'ix_to_word' in self.info:
114
- self.ix_to_word = self.info['ix_to_word']
115
- self.vocab_size = len(self.ix_to_word)
116
- if verbose:
117
- print('vocab size is ', self.vocab_size)
118
-
119
- # open the hdf5 file
120
- if verbose:
121
- print('DataLoader loading h5 file: ', opt.input_fc_dir, opt.input_att_dir, opt.input_box_dir, opt.input_label_h5)
122
- """
123
- Setting input_label_h5 to none is used when only doing generation.
124
- For example, when you need to test on coco test set.
125
- """
126
- if self.opt.input_label_h5 != 'none':
127
- self.h5_label_file = h5py.File(self.opt.input_label_h5, 'r', driver='core')
128
- # load in the sequence data
129
- seq_size = self.h5_label_file['labels'].shape
130
- self.label = self.h5_label_file['labels'][:]
131
- self.seq_length = seq_size[1]
132
- if verbose:
133
- print('max sequence length in data is', self.seq_length)
134
- # load the pointers in full to RAM (should be small enough)
135
- self.label_start_ix = self.h5_label_file['label_start_ix'][:]
136
- self.label_end_ix = self.h5_label_file['label_end_ix'][:]
137
- else:
138
- self.seq_length = 1
139
-
140
- self.data_in_memory = getattr(opt, 'data_in_memory', False)
141
- self.fc_loader = HybridLoader(self.opt.input_fc_dir, '.npy', in_memory=self.data_in_memory)
142
- self.att_loader = HybridLoader(self.opt.input_att_dir, '.npz', in_memory=self.data_in_memory)
143
- self.box_loader = HybridLoader(self.opt.input_box_dir, '.npy', in_memory=self.data_in_memory)
144
-
145
- self.use_clipscore = getattr(opt, 'use_clipscore', False)
146
- if self.use_clipscore:
147
- self.clipscore_loader = HybridLoader(self.opt.input_clipscore_vis_dir, '.npy', in_memory=self.data_in_memory)
148
-
149
-
150
- self.num_images = len(self.info['images']) # self.label_start_ix.shape[0]
151
- if verbose:
152
- print('read %d image features' %(self.num_images))
153
-
154
- # separate out indexes for each of the provided splits
155
- self.split_ix = {'train': [], 'val': [], 'test': []}
156
- for ix in range(len(self.info['images'])):
157
- img = self.info['images'][ix]
158
- if not 'split' in img:
159
- self.split_ix['train'].append(ix)
160
- self.split_ix['val'].append(ix)
161
- self.split_ix['test'].append(ix)
162
- elif img['split'] == 'train':
163
- self.split_ix['train'].append(ix)
164
- elif img['split'] == 'val':
165
- self.split_ix['val'].append(ix)
166
- elif img['split'] == 'test':
167
- self.split_ix['test'].append(ix)
168
- elif opt.train_only == 0: # restval
169
- self.split_ix['train'].append(ix)
170
-
171
- if verbose:
172
- print('assigned %d images to split train' %len(self.split_ix['train']))
173
- print('assigned %d images to split val' %len(self.split_ix['val']))
174
- print('assigned %d images to split test' %len(self.split_ix['test']))
175
-
176
- def get_captions(self, ix, seq_per_img):
177
- # fetch the sequence labels
178
- ix1 = self.label_start_ix[ix] - 1 #label_start_ix starts from 1
179
- ix2 = self.label_end_ix[ix] - 1
180
- ncap = ix2 - ix1 + 1 # number of captions available for this image
181
- assert ncap > 0, 'an image does not have any label. this can be handled but right now isn\'t'
182
-
183
- if ncap < seq_per_img:
184
- # we need to subsample (with replacement)
185
- seq = np.zeros([seq_per_img, self.seq_length], dtype = 'int')
186
- for q in range(seq_per_img):
187
- ixl = random.randint(ix1,ix2)
188
- seq[q, :] = self.label[ixl, :self.seq_length]
189
- else:
190
- ixl = random.randint(ix1, ix2 - seq_per_img + 1)
191
- seq = self.label[ixl: ixl + seq_per_img, :self.seq_length]
192
-
193
- return seq
194
-
195
- def collate_func(self, batch):
196
- seq_per_img = self.seq_per_img
197
-
198
- fc_batch = []
199
- att_batch = []
200
- label_batch = []
201
-
202
- clip_vis_feat_batch = []
203
-
204
- wrapped = False
205
-
206
- infos = []
207
- gts = []
208
-
209
- for sample in batch:
210
- # fetch image
211
- if self.use_clipscore:
212
- tmp_fc, tmp_att, tmp_seq, \
213
- ix, tmp_clip_vis_feat = sample
214
-
215
- clip_vis_feat_batch.append(tmp_clip_vis_feat)
216
- else:
217
- tmp_fc, tmp_att, tmp_seq, \
218
- ix = sample
219
-
220
- fc_batch.append(tmp_fc)
221
- att_batch.append(tmp_att)
222
-
223
- tmp_label = np.zeros([seq_per_img, self.seq_length + 2], dtype = 'int')
224
- if hasattr(self, 'h5_label_file'):
225
- # if there is ground truth
226
- tmp_label[:, 1 : self.seq_length + 1] = tmp_seq
227
- label_batch.append(tmp_label)
228
-
229
- # Used for reward evaluation
230
- if hasattr(self, 'h5_label_file'):
231
- # if there is ground truth
232
- gts.append(self.label[self.label_start_ix[ix] - 1: self.label_end_ix[ix]])
233
- else:
234
- gts.append([])
235
-
236
- # record associated info as well
237
- info_dict = {}
238
- info_dict['ix'] = ix
239
- info_dict['id'] = self.info['images'][ix]['id']
240
- info_dict['file_path'] = self.info['images'][ix].get('file_path', '')
241
- infos.append(info_dict)
242
-
243
- # #sort by att_feat length
244
- # fc_batch, att_batch, label_batch, gts, infos = \
245
- # zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: len(x[1]), reverse=True))
246
- if self.use_clipscore:
247
- fc_batch, att_batch, label_batch, clip_vis_feat_batch, gts, infos = \
248
- zip(*sorted(zip(fc_batch, att_batch, label_batch, clip_vis_feat_batch, gts, infos), key=lambda x: 0, reverse=True))
249
- else:
250
- fc_batch, att_batch, label_batch, gts, infos = \
251
- zip(*sorted(zip(fc_batch, att_batch, label_batch, gts, infos), key=lambda x: 0, reverse=True))
252
- data = {}
253
- data['fc_feats'] = np.stack(fc_batch)
254
- # merge att_feats
255
- max_att_len = max([_.shape[0] for _ in att_batch])
256
- data['att_feats'] = np.zeros([len(att_batch), max_att_len, att_batch[0].shape[1]], dtype = 'float32')
257
- for i in range(len(att_batch)):
258
- data['att_feats'][i, :att_batch[i].shape[0]] = att_batch[i]
259
- data['att_masks'] = np.zeros(data['att_feats'].shape[:2], dtype='float32')
260
- for i in range(len(att_batch)):
261
- data['att_masks'][i, :att_batch[i].shape[0]] = 1
262
- # set att_masks to None if attention features have same length
263
- if data['att_masks'].sum() == data['att_masks'].size:
264
- data['att_masks'] = None
265
-
266
- if self.use_clipscore:
267
- data['clip_vis_feats'] = np.stack(clip_vis_feat_batch)
268
-
269
- data['labels'] = np.vstack(label_batch)
270
- # generate mask
271
- nonzeros = np.array(list(map(lambda x: (x != 0).sum()+2, data['labels'])))
272
- mask_batch = np.zeros([data['labels'].shape[0], self.seq_length + 2], dtype = 'float32')
273
- for ix, row in enumerate(mask_batch):
274
- row[:nonzeros[ix]] = 1
275
- data['masks'] = mask_batch
276
- data['labels'] = data['labels'].reshape(len(batch), seq_per_img, -1)
277
- data['masks'] = data['masks'].reshape(len(batch), seq_per_img, -1)
278
-
279
- data['gts'] = gts # all ground truth captions of each images
280
- data['infos'] = infos
281
-
282
- data = {k:torch.from_numpy(v) if type(v) is np.ndarray else v for k,v in data.items()} # Turn all ndarray to torch tensor
283
-
284
- return data
285
-
286
- def __getitem__(self, ix):
287
- """This function returns a tuple that is further passed to collate_fn
288
- """
289
- if self.use_att:
290
- att_feat = self.att_loader.get(str(self.info['images'][ix]['id']))
291
- # Reshape to K x C
292
- att_feat = att_feat.reshape(-1, att_feat.shape[-1])
293
- if self.norm_att_feat:
294
- att_feat = att_feat / np.linalg.norm(att_feat, 2, 1, keepdims=True)
295
- if self.use_box:
296
- box_feat = self.box_loader.get(str(self.info['images'][ix]['id']))
297
- # divided by image width and height
298
- x1,y1,x2,y2 = np.hsplit(box_feat, 4)
299
- h,w = self.info['images'][ix]['height'], self.info['images'][ix]['width']
300
- box_feat = np.hstack((x1/w, y1/h, x2/w, y2/h, (x2-x1)*(y2-y1)/(w*h))) # question? x2-x1+1??
301
- if self.norm_box_feat:
302
- box_feat = box_feat / np.linalg.norm(box_feat, 2, 1, keepdims=True)
303
- att_feat = np.hstack([att_feat, box_feat])
304
- # sort the features by the size of boxes
305
- att_feat = np.stack(sorted(att_feat, key=lambda x:x[-1], reverse=True))
306
- else:
307
- att_feat = np.zeros((0,0), dtype='float32')
308
- if self.use_fc:
309
- try:
310
- fc_feat = self.fc_loader.get(str(self.info['images'][ix]['id']))
311
- except:
312
- # Use average of attention when there is no fc provided (For bottomup feature)
313
- fc_feat = att_feat.mean(0)
314
- else:
315
- fc_feat = np.zeros((0), dtype='float32')
316
- if hasattr(self, 'h5_label_file'):
317
- seq = self.get_captions(ix, self.seq_per_img)
318
- else:
319
- seq = None
320
-
321
- if self.use_clipscore:
322
- clip_vis_feat = self.clipscore_loader.get(
323
- str(self.info['images'][ix]['id']))
324
-
325
- return (fc_feat,
326
- att_feat, seq,
327
- ix, clip_vis_feat)
328
-
329
- return (fc_feat,
330
- att_feat, seq,
331
- ix)
332
-
333
- def __len__(self):
334
- return len(self.info['images'])
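
Editor's note: HybridLoader dispatches on the storage backend (directory of .npy/.npz files, .lmdb, .pth dictionary, or HDF5) and can optionally cache the raw bytes in memory. A self-contained sketch of the plain-directory branch it falls back to, using a temporary folder and a made-up key:

# --- editor's sketch (not part of the diff) ---------------------------------
import io, os, tempfile
import numpy as np

feat_dir = tempfile.mkdtemp()
np.save(os.path.join(feat_dir, '12345.npy'), np.random.rand(512).astype('float32'))

raw = open(os.path.join(feat_dir, '12345' + '.npy'), 'rb').read()   # bytes, as cached when in_memory=True
feat = np.load(io.BytesIO(raw))                                     # the '.npy' loader used above
print(feat.shape)                                                   # (512,)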
 
retrieval/text_utils.py DELETED
@@ -1,74 +0,0 @@
1
- import random
2
-
3
- def repeat(text, n_max_gram=3, n_max_repeat=3):
4
- """repeat n-grams"""
5
- tokens = text.split()
6
-
7
- n_gram = random.randint(1, n_max_gram)
8
-
9
- repeat_token_idx = random.randint(0, len(tokens) - n_gram)
10
-
11
- repeated_tokens = tokens[repeat_token_idx:repeat_token_idx+n_gram]
12
-
13
- n_repeat = random.randint(1, n_max_repeat)
14
- for _ in range(n_repeat):
15
- insert_idx = random.randint(0, len(tokens))
16
- tokens = tokens[:insert_idx] + \
17
- repeated_tokens + tokens[insert_idx:]
18
-
19
- new_text = " ".join(tokens)
20
- return new_text
21
-
22
- def remove(text, n_max_gram=3):
23
- """remove n-grams"""
24
- tokens = text.split()
25
-
26
- n_gram = random.randint(1, n_max_gram)
27
-
28
- remove_token_idx = random.randint(0, len(tokens) - n_gram)
29
-
30
- tokens = tokens[:remove_token_idx] + tokens[remove_token_idx + n_gram:]
31
-
32
- new_text = " ".join(tokens)
33
- return new_text
34
-
35
- def insert(text, vocab, n_max_tokens=3):
36
- """Insert tokens"""
37
- tokens = text.split()
38
-
39
- n_insert_token = random.randint(1, n_max_tokens)
40
-
41
- for _ in range(n_insert_token):
42
- insert_token_idx = random.randint(0, len(tokens) - 1)
43
- insert_token = random.choice(vocab)
44
- tokens = tokens[:insert_token_idx] + [insert_token] + tokens[insert_token_idx:]
45
-
46
- new_text = " ".join(tokens)
47
- return new_text
48
-
49
- def swap(text, vocab, n_max_tokens=3):
50
- """Swap tokens"""
51
- tokens = text.split()
52
-
53
- n_swap_tokens = random.randint(1, n_max_tokens)
54
-
55
- for _ in range(n_swap_tokens):
56
- swap_token_idx = random.randint(0, len(tokens) - 1)
57
-
58
- swap_token = random.choice(vocab)
59
- while swap_token == tokens[swap_token_idx]:
60
- swap_token = random.choice(vocab)
61
-
62
- tokens[swap_token_idx] = swap_token
63
-
64
- new_text = " ".join(tokens)
65
- return new_text
66
-
67
- def shuffle(text):
68
- """shuffle tokens"""
69
- tokens = text.split()
70
-
71
- random.shuffle(tokens)
72
-
73
- new_text = " ".join(tokens)
74
- return new_text
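
Editor's note: the five perturbations above manufacture ungrammatical or mismatched negatives (repeat an n-gram, drop an n-gram, insert random vocabulary tokens, swap tokens, shuffle the word order) for the grammar head and the negative-text contrastive setup. A quick demonstration, assuming these functions are in scope and using a tiny made-up vocabulary:

# --- editor's sketch (not part of the diff) ---------------------------------
import random
random.seed(0)

caption = "a man riding a horse on the beach"
vocab = ["dog", "red", "running", "table"]    # illustrative only

print(repeat(caption))                        # an n-gram repeated up to 3 extra times
print(remove(caption))                        # an n-gram dropped
print(insert(caption, vocab))                 # random vocab tokens spliced in
print(swap(caption, vocab))                   # tokens replaced by vocab tokens
print(shuffle(caption))                       # word order destroyed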
 
retrieval/train_pl.py DELETED
@@ -1,661 +0,0 @@
1
- from ast import parse
2
- import torch
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
- import torch.optim as optim
6
-
7
- import numpy as np
8
-
9
- import time
10
- import os
11
- from collections import defaultdict
12
-
13
- # import captioning.utils.opts as opts
14
- # import captioning.models as models
15
- # from captioning.data.pth_loader import CaptionDataset
16
- # import captioning.utils.eval_utils as eval_utils
17
- # import captioning.utils.misc as utils
18
- # from captioning.utils.rewards import init_scorer, get_self_critical_reward
19
- # from captioning.modules.loss_wrapper import LossWrapper
20
-
21
- from clip_model import CLIPScore
22
- from caption_data import COCORetrievalDataset
23
-
24
- import pytorch_lightning as pl
25
-
26
- import detectron2.utils.comm as d2comm
27
- from detectron2.utils.env import seed_all_rng
28
- seed_all_rng(1234)
29
-
30
-
31
- class LitModel(pl.LightningModule):
32
- def __init__(self, opt):
33
- super().__init__()
34
- self.opt = opt
35
- self.args = args
36
- # Initialize dataset
37
- # self.dataset = CaptionDataset(opt)
38
-
39
- # self.dataset =
40
-
41
- # opt.vocab_size = self.dataset.vocab_size
42
- # opt.seq_length = self.dataset.seq_length
43
- # self.batch_size = opt.batch_size
44
-
45
- # Build model
46
- # opt.vocab = self.dataset.get_vocab()
47
- # model = models.setup(opt)
48
- # print(model)
49
- # del opt.vocab
50
-
51
- # wrapper with loss in it.
52
- # lw_model = LossWrapper(model, opt)
53
-
54
- self.model = CLIPScore(use_grammar=opt.use_grammar, joint_out=opt.joint_out)
55
- # self.lw_model = lw_model
56
-
57
- for p in self.model.clip_model.vision_model.parameters():
58
- p.requires_grad = False
59
- for p in self.model.clip_model.visual_projection.parameters():
60
- p.requires_grad = False
61
-
62
- # self.struc_flag = None
63
- # self.sc_flag = None
64
-
65
-
66
- def forward(self, *args, **kwargs):
67
- """
68
- I hate this design. Never treat it as a plain nn.Module.
69
- """
70
- raise NotImplementedError
71
-
72
- def train_dataloader(self):
73
- # train_dataset = torch.utils.data.Subset(
74
- # self.dataset,
75
- # self.dataset.split_ix['train']
76
- # )
77
-
78
- # train_loader = torch.utils.data.DataLoader(
79
- # dataset=train_dataset,
80
- # batch_size=self.batch_size,
81
- # shuffle=True,
82
- # num_workers=4,
83
- # collate_fn=self.dataset.collate_func
84
- # )
85
-
86
- train_dataset = COCORetrievalDataset(
87
- split='karpathy_train', mode='train',
88
- args=opt,
89
- verbose=verbose
90
- )
91
-
92
- train_loader = torch.utils.data.DataLoader(
93
- dataset=train_dataset,
94
- batch_size=opt.batch_size,
95
- shuffle=True,
96
- num_workers=4,
97
- collate_fn=train_dataset.collate_fn
98
- )
99
-
100
- return train_loader
101
-
102
- def val_dataloader(self, split='karpathy_val'):
103
- # val_dataset = torch.utils.data.Subset(
104
- # self.dataset,
105
- # self.dataset.split_ix[split]
106
- # )
107
- # val_loader = torch.utils.data.DataLoader(
108
- # val_dataset,
109
- # batch_size=self.batch_size,
110
- # shuffle=False,
111
- # num_workers=4,
112
- # drop_last=False,
113
- # collate_fn=self.dataset.collate_func
114
- # )
115
-
116
- val_dataset = COCORetrievalDataset(
117
- split=split, mode='val',
118
- args=opt,
119
- verbose=verbose
120
- )
121
-
122
- val_loader = torch.utils.data.DataLoader(
123
- dataset=val_dataset,
124
- batch_size=opt.valid_batch_size,
125
- shuffle=False,
126
- num_workers=4,
127
- drop_last=False,
128
- collate_fn=val_dataset.collate_fn
129
- )
130
-
131
- return val_loader
132
-
133
- def test_dataloader(self):
134
-
135
- return self.val_dataloader('karpathy_test')
136
-
137
- def training_step(self, data, batch_idx):
138
-
139
-
140
- batch = data
141
- self.model.train()
142
-
143
- model_out = self.model.train_step(
144
- img_feat=batch['img_feats'],
145
- text=batch['text'],
146
- neg_text=batch['neg_text'],
147
- )
148
-
149
- clip_loss = model_out['clip_loss']
150
-
151
- if self.opt.joint_out:
152
- loss = clip_loss
153
- else:
154
- grammar_loss = model_out['grammar_loss']
155
- loss = clip_loss + grammar_loss
156
-
157
-
158
- data_time = self.trainer.profiler.recorded_durations["get_train_batch"][-1]
159
- data_time = torch.tensor(data_time)
160
-
161
- # print('batch_idx', batch_idx)
162
- # print('loss:', loss)
163
-
164
- # logger_logs = model_out.copy()
165
- logger_logs = {}
166
-
167
- logger_logs['loss'] = loss.detach()
168
-
169
- logger_logs['clip_loss'] = clip_loss.detach()
170
-
171
- if not self.opt.joint_out:
172
- logger_logs['grammar_loss'] = grammar_loss.detach()
173
-
174
- logger_logs['data_time'] = data_time.detach()
175
-
176
- # UserWarning: The {progress_bar:dict keyword} was deprecated in 0.9.1 and will be removed in 1.0.0
177
- # Please use self.log(...) inside the lightningModule instead.
178
-
179
- # # log on a step or aggregate epoch metric to the logger and/or progress bar
180
- # # (inside LightningModule)
181
- # self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
182
- # warnings.warn(*args, **kwargs)
183
- # UserWarning: The {log:dict keyword} was deprecated in 0.9.1 and will be removed in 1.0.0
184
- # Please use self.log(...) inside the lightningModule instead.
185
-
186
- # output = {
187
- # 'loss': loss,
188
- # 'log': logger_logs,
189
- # 'progress_bar': {'data_time': data_time}
190
- # }
191
-
192
- for k, v in logger_logs.items():
193
- if k in ['data_time', 'clip_loss', 'grammar_loss']:
194
- self.log('train/'+k, v, prog_bar=True)
195
- else:
196
- self.log('train/'+k, v)
197
-
198
- # print('training step logged')
199
-
200
- return loss
201
-
202
- def validation_step(self, data, batch_idx):
203
-
204
- batch = data
205
- self.model.eval()
206
-
207
- with torch.no_grad():
208
- model_out = self.model.train_step(
209
- img_feat=batch['img_feats'],
210
- text=batch['text'],
211
- neg_text=batch['neg_text'],
212
- )
213
-
214
- if self.opt.joint_out:
215
- clip_loss = model_out['clip_loss']
216
- loss = clip_loss
217
-
218
- output = {
219
- # 'val_loss': loss,
220
- 'loss': loss.detach(),
221
- 'clip_loss': clip_loss.detach(),
222
- # 'grammar_loss': grammar_loss.detach(),
223
-
224
- 'img_feat': model_out['img_feat'].detach(),
225
- 'text_feat': model_out['text_feat'].detach(),
226
- # 'neg_text_feat': model_out['neg_text_feat'].detach(),
227
- # 'grammar_pos_pred': model_out['grammar_pos_pred'].detach(),
228
- # 'grammar_neg_pred': model_out['grammar_neg_pred'].detach(),
229
- # 'predictions': predictions,
230
- # 'n_predictions': n_predictions,
231
- }
232
- else:
233
- clip_loss = model_out['clip_loss']
234
- grammar_loss = model_out['grammar_loss']
235
- loss = clip_loss + grammar_loss
236
-
237
- output = {
238
- # 'val_loss': loss,
239
- 'loss': loss.detach(),
240
- 'clip_loss': clip_loss.detach(),
241
- 'grammar_loss': grammar_loss.detach(),
242
-
243
- 'img_feat': model_out['img_feat'].detach(),
244
- 'text_feat': model_out['text_feat'].detach(),
245
- # 'neg_text_feat': model_out['neg_text_feat'].detach(),
246
- 'grammar_pos_pred': model_out['grammar_pos_pred'].detach(),
247
- 'grammar_neg_pred': model_out['grammar_neg_pred'].detach(),
248
- # 'predictions': predictions,
249
- # 'n_predictions': n_predictions,
250
- }
251
- return output
252
-
253
- def test_step(self, *args, **kwargs):
254
- return self.validation_step(*args, **kwargs)
255
-
256
- def validation_epoch_end(self, outputs, split='val'):
257
- outputs = d2comm.gather(outputs)
258
- # master node
259
- if d2comm.is_main_process():
260
- assert self.trainer.node_rank == 0 and self.trainer.local_rank == 0
261
- outputs = sum(outputs, [])
262
-
263
- out = {}
264
-
265
- val_loss_mean = sum([_['loss'].cpu() for _ in outputs]) / len(outputs)
266
- val_clip_loss_mean = sum([_['clip_loss'].cpu() for _ in outputs]) / len(outputs)
267
- if not self.opt.joint_out:
268
- val_grammar_loss_mean = sum([_['grammar_loss'].cpu() for _ in outputs]) / len(outputs)
269
-
270
- print('loss', val_loss_mean.item())
271
- print('clip_loss', val_clip_loss_mean.item())
272
- if not self.opt.joint_out:
273
- print('grammar_loss', val_grammar_loss_mean.item())
274
-
275
- logit_scale = self.model.clip_model.logit_scale.exp().cpu()
276
-
277
- text_feats = torch.cat([_['text_feat'].cpu() for _ in outputs], dim=0)
278
- img_feats = torch.cat([_['img_feat'].cpu() for _ in outputs], dim=0)
279
-
280
- assert text_feats.size() == (5000, 512), text_feats.size()
281
- assert img_feats.size() == (5000, 512), img_feats.size()
282
-
283
- logits_per_text = torch.matmul(text_feats, img_feats.t()) * logit_scale
284
- logits_per_image = logits_per_text.T
285
-
286
- # text-to-image retrieval
287
- print('Text-to-Image retrieval')
288
- for k in [1, 5, 10]:
289
- text_to_image_topk = logits_per_text.topk(k, dim=1).indices
290
-
291
- n_text = len(text_to_image_topk)
292
-
293
- labels = torch.arange(0, n_text).view(-1, 1)
294
-
295
- n_retrieved = ((text_to_image_topk == labels).sum(dim=1) > 0).sum()
296
-
297
- recall_k = n_retrieved / n_text * 100
298
-
299
- out[f'text_to_image_recall_{k}'] = recall_k.item()
300
-
301
- print(f'R@{k}: {recall_k.item():.2f}%')
302
-
303
- # image-to-text retrieval
304
- print('Image-to-Text retrieval')
305
- for k in [1, 5, 10]:
306
- image_to_text_topk = logits_per_image.topk(k, dim=1).indices
307
-
308
- n_image = len(image_to_text_topk)
309
-
310
- labels = torch.arange(0, n_image).view(-1, 1)
311
-
312
- n_retrieved = ((image_to_text_topk == labels).sum(dim=1) > 0).sum()
313
-
314
- recall_k = n_retrieved / n_image * 100
315
-
316
- out[f'image_to_text_recall_{k}'] = recall_k.item()
317
-
318
- print(f'R@{k}: {recall_k.item():.2f}%')
319
-
320
- out.update({
321
- 'loss': val_loss_mean.item(),
322
- 'clip_loss': val_clip_loss_mean.item()
323
- })
324
-
325
- if not self.opt.joint_out:
326
- # grammar scoring
327
- grammar_pos_pred = torch.cat([_['grammar_pos_pred'].cpu() for _ in outputs], dim=0)
328
- grammar_neg_pred = torch.cat([_['grammar_neg_pred'].cpu() for _ in outputs], dim=0)
329
-
330
- TP = (grammar_pos_pred == 1).sum().item()
331
- FP = (grammar_pos_pred == 0).sum().item()
332
- FN = (grammar_neg_pred == 1).sum().item()
333
- TN = (grammar_neg_pred == 0).sum().item()
334
- print('Grammar check')
335
- print(f'TP: {TP} FP: {FP} FN: {FN} TN: {TN}')
336
-
337
- precision = TP / (TP + FP) * 100
338
- recall = TP / (TP + FN) * 100
339
- accuracy = (TP + TN) / (TP + FP + FN + TN) * 100
340
- f1 = 2 * precision * recall / (precision + recall)
341
- print(f'Precision: {precision:.2f}%')
342
- print(f'Recall: {recall:.2f}%')
343
- print(f'Accuracy: {accuracy:.2f}%')
344
- print(f'F1: {f1:.2f}%')
345
- print('Total: {}'.format(len(grammar_pos_pred)))
346
-
347
- out.update({
348
- 'grammar_loss': val_grammar_loss_mean,
349
-
350
- 'grammar_precision': precision,
351
- 'grammar_recall': recall,
352
- 'grammar_accuracy': accuracy,
353
- 'grammar_f1': f1,
354
-
355
- })
356
-
357
- else:
358
- out = {}
359
-
360
- out = d2comm.all_gather(out)[0] # Only the one from master node
361
- assert len(out) > 0 # make sure the head has index 0
362
-
363
- # must all be tensors
364
- out = {k: torch.tensor(v) if not torch.is_tensor(
365
- v) else v for k, v in out.items()}
366
-
367
- for k, v in out.items():
368
- self.log(f'{split}/{k}', v)
369
-
370
- def test_epoch_end(self, outputs):
371
-
372
- self.validation_epoch_end(outputs, 'test')
373
-
374
- def configure_optimizers(self):
375
- # opt = self.opt
376
- # model = self.model
377
-
378
- # parameters = [p for p in model.parameters() if p.requires_grad]
379
-
380
- # if opt.noamopt:
381
- # # assert opt.caption_model in ['transformer', 'bert', 'm2transformer'], 'noamopt can only work with transformer'
382
- # optimizer = utils.get_std_opt(
383
- # model, optim_func=opt.optim, factor=opt.noamopt_factor, warmup=opt.noamopt_warmup)
384
- # elif opt.reduce_on_plateau:
385
- # # optimizer = utils.build_optimizer(model.parameters(), opt)
386
- # optimizer = utils.build_optimizer(parameters, opt)
387
- # optimizer = utils.ReduceLROnPlateau(optimizer,
388
- # factor=opt.reduce_on_plateau_factor,
389
- # patience=opt.reduce_on_plateau_patience)
390
- # else:
391
- # # optimizer = utils.build_optimizer(model.parameters(), opt)
392
- # optimizer = utils.build_optimizer(parameters, opt)
393
-
394
-
395
- # from transformers.optimization import AdamW, get_linear_schedule_with_warmup
396
- # batch_per_epoch = len(self.train_loader)
397
- # t_total = batch_per_epoch // self.args.gradient_accumulation_steps * self.args.epochs
398
- # warmup_ratio = self.args.warmup_ratio
399
- # warmup_iters = int(t_total * warmup_ratio)
400
- # if self.verbose:
401
- # print("Batch per epoch: %d" % batch_per_epoch)
402
- # print("Total Iters: %d" % t_total)
403
- # print('Warmup ratio:', warmup_ratio)
404
- # print("Warm up Iters: %d" % warmup_iters)
405
-
406
- if self.args.optim == 'adamw':
407
- no_decay = ["bias", "LayerNorm.weight"]
408
- optimizer_grouped_parameters = [
409
- {
410
- "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
411
- "weight_decay": self.args.weight_decay,
412
- },
413
- {
414
- "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
415
- "weight_decay": 0.0,
416
- },
417
- ]
418
-
419
- for group in optimizer_grouped_parameters:
420
- group['params'] = [p for p in group['params'] if p.requires_grad]
421
-
422
- from transformers.optimization import AdamW
423
- optim = AdamW(optimizer_grouped_parameters,
424
- lr=self.args.lr, eps=self.args.adam_eps)
425
- # lr_scheduler = get_linear_schedule_with_warmup(
426
- # optim, warmup_iters, t_total)
427
-
428
- # optimizers = []
429
- optimizers = [optim]
430
- lr_schedulers = []
431
-
432
- return optimizers, lr_schedulers
433
-
434
- def optimizer_step(self, epoch, batch_idx, optimizer,
435
- optimizer_idx, *args, **kwargs):
436
- # # warm up lr
437
- # opt = self.opt
438
- # iteration = self.trainer.global_step
439
- # if opt.use_warmup and (iteration < opt.noamopt_warmup):
440
- # opt.current_lr = opt.learning_rate * \
441
- # (iteration+1) / opt.noamopt_warmup
442
- # utils.set_lr(optimizer, opt.current_lr)
443
-
444
- super().optimizer_step(epoch, batch_idx, optimizer,
445
- optimizer_idx, *args, **kwargs)
446
-
447
- # print('optimizer step')
448
-
449
- def state_dict(self):
450
- """
451
- Save the model state dict as well as opt and vocab
452
- """
453
- state_dict = self.model.state_dict()
454
- device = next(iter(state_dict.values())).device
455
- assert '_vocab' not in state_dict and '_opt' not in state_dict, 'Just in case'
456
- # state_dict.update({
457
- # '_vocab': utils.serialize_to_tensor(self.model.vocab).to(device),
458
- # '_opt': utils.serialize_to_tensor(self.opt).to(device)
459
- # })
460
- return state_dict
461
-
462
- def load_state_dict(self, state_dict=None, strict=True):
463
- # if '_vocab' in state_dict:
464
- # self.model.vocab = utils.deserialize(state_dict['_vocab'])
465
- # del state_dict['_vocab']
466
- # elif strict:
467
- # raise KeyError
468
- # if '_opt' in state_dict:
469
- # saved_model_opt = utils.deserialize(state_dict['_opt'])
470
- # del state_dict['_opt']
471
- # opt = self.opt
472
- # # Make sure the saved opt is compatible with the current opt
473
- # need_be_same = ["caption_model",
474
- # "rnn_type", "rnn_size", "num_layers"]
475
- # for checkme in need_be_same:
476
- # if getattr(saved_model_opt, checkme) in ['updown', 'topdown'] and \
477
- # getattr(opt, checkme) in ['updown', 'topdown']:
478
- # continue
479
- # assert getattr(saved_model_opt, checkme) == getattr(
480
- # opt, checkme), "Command line argument and saved model disagree on '%s' " % checkme
481
- # elif strict:
482
- # raise KeyError
483
- self.model.load_state_dict(state_dict, strict)
484
-
485
-
486
- class OnEpochStartCallback(pl.Callback):
487
-
488
- def on_epoch_start(self, trainer, pl_module):
489
- # Update lr/training stage/scheduled sampling prob etc.
490
- opt = pl_module.opt
491
- model = pl_module.model
492
- epoch = trainer.current_epoch
493
- optimizer = trainer.optimizers[0]
494
-
495
- # if not opt.noamopt and not opt.reduce_on_plateau:
496
- # # Assign the learning rate
497
- # if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
498
- # frac = (
499
- # epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
500
- # decay_factor = opt.learning_rate_decay_rate ** frac
501
- # opt.current_lr = opt.learning_rate * decay_factor
502
- # else:
503
- # opt.current_lr = opt.learning_rate
504
- # utils.set_lr(optimizer, opt.current_lr) # set the decayed rate
505
- # # Assign the scheduled sampling prob
506
- # if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
507
- # frac = (
508
- # epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
509
- # opt.ss_prob = min(opt.scheduled_sampling_increase_prob *
510
- # frac, opt.scheduled_sampling_max_prob)
511
- # model.ss_prob = opt.ss_prob
512
-
513
- # # If start self critical training
514
- # if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
515
- # sc_flag = True
516
- # init_scorer(opt.cached_tokens)
517
- # else:
518
- # sc_flag = False
519
-
520
- # # If start structure loss training
521
- # if opt.structure_after != -1 and epoch >= opt.structure_after:
522
- # struc_flag = True
523
- # init_scorer(opt.cached_tokens)
524
- # else:
525
- # struc_flag = False
526
-
527
- # pl_module.struc_flag = struc_flag
528
- # pl_module.sc_flag = sc_flag
529
-
530
-
531
- class ModelCheckpoint(pl.callbacks.ModelCheckpoint):
532
-
533
- def on_keyboard_interrupt(self, trainer, pl_module):
534
- # Save model when keyboard interrupt
535
- filepath = os.path.join(self.dirpath, self.prefix + 'interrupt.ckpt')
536
- self._save_model(filepath)
537
-
538
- from param import parse_args
539
- # opt = opts.parse_opt()
540
- args = parse_args()
541
- opt = args
542
-
543
- checkpoint_callback = ModelCheckpoint(
544
- filepath=opt.checkpoint_dir + '{epoch:02d}',
545
- # dirpath=opt.checkpoint_path,
546
- save_last=True,
547
- save_top_k=1,
548
- verbose=True,
549
- # monitor='to_monitor',
550
- # monitor='val/to_monitor',
551
- # monitor='val/CIDEr',
552
- monitor='val/loss',
553
- mode='min',
554
- # prefix=opt.id+'_',
555
- prefix=opt.id,
556
- # filename=f'{opt.id}_',
557
- )
558
-
559
- verbose = True
560
- # import torch
561
- # if torch.cuda.current_device() in [0, -1]:
562
- if 'LOCAL_RANK' in os.environ and os.environ['LOCAL_RANK'] != '0':
563
- verbose = False
564
-
565
- # if verbose:
566
- # print(opt)
567
- # print("""
568
- # val_image_use,
569
- # save_checkpoint_very
570
- # save_every_epoch,
571
- # save_history-ckpt will be ignored.
572
- # """)
573
-
574
- # Lightning defines batch size as batch size per gpu
575
- assert opt.batch_size % torch.cuda.device_count() == 0
576
- opt.batch_size = opt.batch_size // torch.cuda.device_count()
577
- opt.valid_batch_size = opt.valid_batch_size // torch.cuda.device_count()
578
-
579
- # If resume from last checkpoint
580
- # if opt.start_from is not None and os.path.isfile(os.path.join(opt.start_from, f'{opt.id}_last.ckpt')):
581
- # resume_from = os.path.join(opt.start_from, f'{opt.id}_last.ckpt')
582
- if opt.start_from is not None and os.path.isfile(os.path.join(opt.start_from, f'{opt.id}-last.ckpt')):
583
- resume_from = os.path.join(opt.start_from, f'{opt.id}-last.ckpt')
584
- if verbose:
585
- print('resume from', resume_from)
586
- else:
587
- resume_from = None
588
-
589
- from pytorch_lightning.loggers import WandbLogger
590
- wandb_logger = WandbLogger(
591
- # project='CLIP-ViL-COCOCaption',
592
- project='CLIP-Finetune-COCO',
593
- name=opt.id,
594
- )
595
-
596
- if verbose:
597
- wandb_logger.experiment.config.update(opt)
598
- from pathlib import Path
599
- import glob
600
- import wandb
601
- # src_dir = Path(__file__).resolve().parent.parent
602
- glob_str = "*.py"
603
- base_path = './'
604
- wandb.save(glob_str=glob_str, base_path=base_path)
605
-
606
- glob_str = "**/*.yaml"
607
- base_path = './'
608
- wandb.save(glob_str=glob_str, base_path=base_path)
609
-
610
- # code = wandb.Artifact('project-source', type='code')
611
- # for path in glob.glob('**/*.py', recursive=True):
612
- # code.add_file(path, name='source/'+path)
613
- # print(path)
614
- # wandb.run.use_artifact(code)
615
-
616
-
617
-
618
-
619
- lit = LitModel(opt)
620
- # warning grad_clip_mode is ignored.
621
- trainer = pl.Trainer(
622
- callbacks=[
623
- OnEpochStartCallback(),
624
- # pl.callbacks.lr_logger.LearningRateLogger()
625
- pl.callbacks.LearningRateMonitor()
626
- ],
627
- default_root_dir=opt.checkpoint_dir,
628
- resume_from_checkpoint=resume_from,
629
-
630
- distributed_backend='ddp',
631
- gpus=torch.cuda.device_count(),
632
-
633
- # gpus=1,
634
-
635
- check_val_every_n_epoch=1,
636
- # max_epochs=opt.max_epochs,
637
- max_epochs=opt.epochs,
638
- # gradient_clip_val=opt.grad_clip_value,
639
- gradient_clip_val=opt.clip_grad_norm,
640
-
641
- checkpoint_callback=checkpoint_callback,
642
- log_gpu_memory='min_max',
643
- # log_save_interval=opt.losses_log_every,
644
- log_every_n_steps=opt.losses_log_every,
645
- profiler=True,
646
- # profiler='simple',
647
- # row_log_interval=10, # what is it?
648
- flush_logs_every_n_steps=10,
649
- num_sanity_val_steps=0,
650
- # val_check_interval=0.01,
651
- # limit_train_batches=500,
652
- # progress_bar_refresh_rate=0,
653
- # fast_dev_run=True,
654
- precision=opt.precision,
655
- logger=wandb_logger
656
- )
657
-
658
- if os.getenv('EVALUATE', '0') == '1':
659
- trainer.test(lit)
660
- else:
661
- trainer.fit(lit)
 
save/README.md DELETED
@@ -1 +0,0 @@
1
- Directory for checkpoints
 
scripts/build_bpe_subword_nmt.py DELETED
@@ -1,214 +0,0 @@
1
- """
2
- Preprocess a raw json dataset into hdf5/json files for use in data_loader.lua
3
-
4
- Input: json file that has the form
5
- [{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...]
6
- example element in this list would look like
7
- {'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895}
8
-
9
- This script reads this json, does some basic preprocessing on the captions
10
- (e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays
11
-
12
- Output: a json file and an hdf5 file
13
- The hdf5 file contains several fields:
14
- /labels is (M,max_length) uint32 array of encoded labels, zero padded
15
- /label_start_ix and /label_end_ix are (N,) uint32 arrays of pointers to the
16
- first and last indices (in range 1..M) of labels for each image
17
- /label_length stores the length of the sequence for each of the M sequences
18
-
19
- The json file has a dict that contains:
20
- - an 'ix_to_word' field storing the vocab in form {ix:'word'}, where ix is 1-indexed
21
- - an 'images' field that is a list holding auxiliary information for each image,
22
- such as in particular the 'split' it was assigned to.
23
- """
24
-
25
- from __future__ import absolute_import
26
- from __future__ import division
27
- from __future__ import print_function
28
-
29
- import os
30
- import json
31
- import argparse
32
- from random import shuffle, seed
33
- import string
34
- # non-standard dependencies:
35
- import h5py
36
- import numpy as np
37
- import torch
38
- import torchvision.models as models
39
- import skimage.io
40
- from PIL import Image
41
-
42
- import codecs
43
- import tempfile
44
- from subword_nmt import learn_bpe, apply_bpe
45
-
46
- # python scripts/build_bpe_subword_nmt.py --input_json data/dataset_coco.json --output_json data/cocotalkbpe.json --output_h5 data/cocotalkbpe
47
-
48
- def build_vocab(imgs, params):
49
- # count up the number of words
50
- captions = []
51
- for img in imgs:
52
- for sent in img['sentences']:
53
- captions.append(' '.join(sent['tokens']))
54
- captions='\n'.join(captions)
55
- all_captions = tempfile.NamedTemporaryFile(delete=False)
56
- all_captions.close()
57
- with open(all_captions.name, 'w') as txt_file:
58
- txt_file.write(captions)
59
-
60
- #
61
- codecs_output = tempfile.NamedTemporaryFile(delete=False)
62
- codecs_output.close()
63
- with codecs.open(codecs_output.name, 'w', encoding='UTF-8') as output:
64
- learn_bpe.learn_bpe(codecs.open(all_captions.name, encoding='UTF-8'), output, params['symbol_count'])
65
-
66
- with codecs.open(codecs_output.name, encoding='UTF-8') as codes:
67
- bpe = apply_bpe.BPE(codes)
68
-
69
- tmp = tempfile.NamedTemporaryFile(delete=False)
70
- tmp.close()
71
-
72
- tmpout = codecs.open(tmp.name, 'w', encoding='UTF-8')
73
-
74
- for _, img in enumerate(imgs):
75
- img['final_captions'] = []
76
- for sent in img['sentences']:
77
- txt = ' '.join(sent['tokens'])
78
- txt = bpe.segment(txt).strip()
79
- img['final_captions'].append(txt.split(' '))
80
- tmpout.write(txt)
81
- tmpout.write('\n')
82
- if _ < 20:
83
- print(txt)
84
-
85
- tmpout.close()
86
- tmpin = codecs.open(tmp.name, encoding='UTF-8')
87
-
88
- vocab = learn_bpe.get_vocabulary(tmpin)
89
- vocab = sorted(vocab.keys(), key=lambda x: vocab[x], reverse=True)
90
-
91
- # Always insert UNK
92
- print('inserting the special UNK token')
93
- vocab.append('UNK')
94
-
95
- print('Vocab size:', len(vocab))
96
-
97
- os.remove(all_captions.name)
98
- with open(codecs_output.name, 'r') as codes:
99
- bpe = codes.read()
100
- os.remove(codecs_output.name)
101
- os.remove(tmp.name)
102
-
103
- return vocab, bpe
104
-
105
- def encode_captions(imgs, params, wtoi):
106
- """
107
- encode all captions into one large array, which will be 1-indexed.
108
- also produces label_start_ix and label_end_ix which store 1-indexed
109
- and inclusive (Lua-style) pointers to the first and last caption for
110
- each image in the dataset.
111
- """
112
-
113
- max_length = params['max_length']
114
- N = len(imgs)
115
- M = sum(len(img['final_captions']) for img in imgs) # total number of captions
116
-
117
- label_arrays = []
118
- label_start_ix = np.zeros(N, dtype='uint32') # note: these will be one-indexed
119
- label_end_ix = np.zeros(N, dtype='uint32')
120
- label_length = np.zeros(M, dtype='uint32')
121
- caption_counter = 0
122
- counter = 1
123
- for i,img in enumerate(imgs):
124
- n = len(img['final_captions'])
125
- assert n > 0, 'error: some image has no captions'
126
-
127
- Li = np.zeros((n, max_length), dtype='uint32')
128
- for j,s in enumerate(img['final_captions']):
129
- label_length[caption_counter] = min(max_length, len(s)) # record the length of this sequence
130
- caption_counter += 1
131
- for k,w in enumerate(s):
132
- if k < max_length:
133
- Li[j,k] = wtoi[w]
134
-
135
- # note: word indices are 1-indexed, and captions are padded with zeros
136
- label_arrays.append(Li)
137
- label_start_ix[i] = counter
138
- label_end_ix[i] = counter + n - 1
139
-
140
- counter += n
141
-
142
- L = np.concatenate(label_arrays, axis=0) # put all the labels together
143
- assert L.shape[0] == M, 'lengths don\'t match? that\'s weird'
144
- assert np.all(label_length > 0), 'error: some caption had no words?'
145
-
146
- print('encoded captions to array of size ', L.shape)
147
- return L, label_start_ix, label_end_ix, label_length
148
-
149
- def main(params):
150
-
151
- imgs = json.load(open(params['input_json'], 'r'))
152
- imgs = imgs['images']
153
-
154
- seed(123) # make reproducible
155
-
156
- # create the vocab
157
- vocab, bpe = build_vocab(imgs, params)
158
- itow = {i+1:w for i,w in enumerate(vocab)} # a 1-indexed vocab translation table
159
- wtoi = {w:i+1 for i,w in enumerate(vocab)} # inverse table
160
-
161
- # encode captions in large arrays, ready to ship to hdf5 file
162
- L, label_start_ix, label_end_ix, label_length = encode_captions(imgs, params, wtoi)
163
-
164
- # create output h5 file
165
- N = len(imgs)
166
- f_lb = h5py.File(params['output_h5']+'_label.h5', "w")
167
- f_lb.create_dataset("labels", dtype='uint32', data=L)
168
- f_lb.create_dataset("label_start_ix", dtype='uint32', data=label_start_ix)
169
- f_lb.create_dataset("label_end_ix", dtype='uint32', data=label_end_ix)
170
- f_lb.create_dataset("label_length", dtype='uint32', data=label_length)
171
- f_lb.close()
172
-
173
- # create output json file
174
- out = {}
175
- out['ix_to_word'] = itow # encode the (1-indexed) vocab
176
- out['images'] = []
177
- out['bpe'] = bpe
178
- for i,img in enumerate(imgs):
179
-
180
- jimg = {}
181
- jimg['split'] = img['split']
182
- if 'filename' in img: jimg['file_path'] = os.path.join(img['filepath'], img['filename']) # copy it over, might need
183
- if 'cocoid' in img: jimg['id'] = img['cocoid'] # copy over & maintain an id, if present (e.g. coco ids, useful)
184
-
185
- if params['images_root'] != '':
186
- with Image.open(os.path.join(params['images_root'], img['filepath'], img['filename'])) as _img:
187
- jimg['width'], jimg['height'] = _img.size
188
-
189
- out['images'].append(jimg)
190
-
191
- json.dump(out, open(params['output_json'], 'w'))
192
- print('wrote ', params['output_json'])
193
-
194
- if __name__ == "__main__":
195
-
196
- parser = argparse.ArgumentParser()
197
-
198
- # input json
199
- parser.add_argument('--input_json', required=True, help='input json file to process into hdf5')
200
- parser.add_argument('--output_json', default='data.json', help='output json file')
201
- parser.add_argument('--output_h5', default='data', help='output h5 file')
202
- parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json')
203
-
204
- # options
205
- parser.add_argument('--max_length', default=16, type=int, help='max length of a caption, in number of words. captions longer than this get clipped.')
206
- parser.add_argument('--symbol_count', default=10000, type=int, help='only words that occur more than this number of times will be put in vocab')
207
-
208
- args = parser.parse_args()
209
- params = vars(args) # convert to ordinary dict
210
- print('parsed input parameters:')
211
- print(json.dumps(params, indent = 2))
212
- main(params)
213
-
214
-
 
scripts/clip_prepro_feats.py DELETED
@@ -1,170 +0,0 @@
1
- """
2
- Preprocess a raw json dataset into features files for use in data_loader.py
3
-
4
- Input: json file that has the form
5
- [{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...]
6
- example element in this list would look like
7
- {'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895}
8
-
9
- This script reads this json, does some basic preprocessing on the captions
10
- (e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays
11
-
12
- Output: two folders of features
13
- """
14
-
15
- from __future__ import absolute_import
16
- from __future__ import division
17
- from __future__ import print_function
18
-
19
- import os
20
- import json
21
- import argparse
22
- from random import shuffle, seed
23
- import string
24
- # non-standard dependencies:
25
- import h5py
26
- from six.moves import cPickle
27
- import numpy as np
28
- import torch
29
- import torchvision.models as models
30
- import skimage.io
31
-
32
- from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
33
- from PIL import Image
34
- from torch import nn
35
-
36
- preprocess = Compose([
37
- Resize((448, 448), interpolation=Image.BICUBIC),
38
- CenterCrop((448, 448)),
39
- ToTensor()
40
- ])
41
-
42
-
43
- from clip.clip import load
44
- from timm.models.vision_transformer import resize_pos_embed
45
- import timm
46
-
47
- from captioning.utils.resnet_utils import myResnet
48
- import captioning.utils.resnet as resnet
49
-
50
- from tqdm import tqdm
51
-
52
-
53
- def main(params):
54
- if params["model_type"] != 'vit_base_patch32_224_in21k':
55
- model, transform = load(params["model_type"], jit=False)
56
- else:
57
- model = timm.create_model(params["model_type"], pretrained=True)
58
- model = model.cuda()
59
-
60
- if params["model_type"] != 'vit_base_patch32_224_in21k':
61
- save_model_type = params["model_type"].split("-")[0]
62
- mean = torch.Tensor([0.48145466, 0.4578275, 0.40821073]).to("cuda").reshape(3, 1, 1)
63
- std = torch.Tensor([0.26862954, 0.26130258, 0.27577711]).to("cuda").reshape(3, 1, 1)
64
-
65
- if "RN" in params["model_type"]:
66
- num_patches = 196 #600 * 1000 // 32 // 32
67
- pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, model.visual.attnpool.positional_embedding.shape[-1], device='cuda'),)
68
- pos_embed.weight = resize_pos_embed(model.visual.attnpool.positional_embedding.unsqueeze(0), pos_embed)
69
- model.visual.attnpool.positional_embedding = pos_embed
70
-
71
- else:
72
- save_model_type = 'vit_base'
73
- mean = torch.Tensor([0.5, 0.5, 0.5]).to("cuda").reshape(3, 1, 1)
74
- std = torch.Tensor([0.5, 0.5, 0.5]).to("cuda").reshape(3, 1, 1)
75
-
76
- num_patches = 196 #600 * 1000 // 32 // 32
77
- pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, 768, device='cuda'),)
78
- pos_embed.weight = resize_pos_embed(model.pos_embed, pos_embed)
79
- model.pos_embed = pos_embed
80
-
81
- if params["model_type"] == "ViT-B/32":
82
- num_patches = 196 #600 * 1000 // 32 // 32
83
- pos_embed = nn.Parameter(torch.zeros(num_patches + 1, 768, device='cuda'),)
84
- pos_embed.weight = resize_pos_embed(model.visual.positional_embedding.unsqueeze(0), pos_embed.unsqueeze(0))
85
- model.visual.positional_embedding = pos_embed
86
- imgs = json.load(open(params['input_json'], 'r'))
87
-
88
- imgs = imgs['images']
89
-
90
- if args.n_jobs > 1:
91
- print('Total imgs:', len(imgs))
92
- print('Using {} jobs'.format(args.n_jobs))
93
- print('job id:', args.job_id)
94
- imgs = imgs[args.job_id::args.n_jobs]
95
-
96
- N = len(imgs)
97
-
98
- seed(123) # make reproducible
99
-
100
- dir_fc = params['output_dir']+'_clip_'+save_model_type+'_fc'
101
- dir_att = params['output_dir']+'_clip_'+save_model_type+'_att'
102
- if not os.path.isdir(dir_fc):
103
- os.mkdir(dir_fc)
104
- if not os.path.isdir(dir_att):
105
- os.mkdir(dir_att)
106
-
107
- for i,img in enumerate(tqdm(imgs)):
108
- # load the image
109
- with torch.no_grad():
110
-
111
- image = preprocess(Image.open(os.path.join(params['images_root'], img['filepath'], img['filename']) ).convert("RGB"))
112
- image = torch.tensor(np.stack([image])).cuda()
113
- image -= mean
114
- image /= std
115
- if "RN" in params["model_type"]:
116
- tmp_att, tmp_fc = model.encode_image(image)
117
- tmp_att = tmp_att[0].permute(1, 2, 0)
118
- tmp_fc = tmp_fc[0]
119
- elif params["model_type"] == 'vit_base_patch32_224_in21k':
120
- x = model(image)
121
- tmp_fc = x[0, 0, :]
122
- tmp_att = x[0, 1:, :].reshape( 14, 14, 768 )
123
- else:
124
- x = model.visual.conv1(image.half()) # shape = [*, width, grid, grid]
125
- x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
126
- x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
127
- x = torch.cat([model.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
128
- x = x + model.visual.positional_embedding.to(x.dtype)[:x.shape[1], :]
129
- x = model.visual.ln_pre(x)
130
-
131
- x = x.permute(1, 0, 2) # NLD -> LND
132
-
133
- for layer_idx, layer in enumerate(model.visual.transformer.resblocks):
134
- x = layer(x)
135
-
136
- x = x.permute(1, 0, 2)
137
- tmp_fc = x[0, 0, :]
138
- tmp_att = x[0, 1:, :].reshape( 14, 14, 768 )
139
-
140
- np.save(os.path.join(dir_fc, str(img['cocoid'])), tmp_fc.data.cpu().float().numpy())
141
- np.savez_compressed(os.path.join(dir_att, str(img['cocoid'])), feat=tmp_att.data.cpu().float().numpy())
142
-
143
-
144
- # if i % 1000 == 0:
145
- # print('processing %d/%d (%.2f%% done)' % (i, N, i*100.0/N))
146
- print('wrote ', dir_fc, dir_att)
147
-
148
- if __name__ == "__main__":
149
-
150
- parser = argparse.ArgumentParser()
151
-
152
- # input json
153
- parser.add_argument('--input_json', required=True, help='input json file to process into hdf5')
154
- parser.add_argument('--output_dir', default='data', help='output h5 file')
155
-
156
- # options
157
- parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json')
158
- parser.add_argument('--att_size', default=14, type=int, help='14x14 or 7x7')
159
- parser.add_argument('--model_type', default='RN50', type=str, help='RN50, RN101, RN50x4, ViT-B/32, vit_base_patch32_224_in21k')
160
-
161
- parser.add_argument('--n_jobs', default=-1, type=int, help='number of jobs to run in parallel')
162
- parser.add_argument('--job_id', default=0, type=int, help='job id')
163
- parser.add_argument('--batch_size', default=1, type=int, help='batch size')
164
-
165
-
166
- args = parser.parse_args()
167
- params = vars(args) # convert to ordinary dict
168
- print('parsed input parameters:')
169
- print(json.dumps(params, indent = 2))
170
- main(params)
 
scripts/clipscore_prepro_feats.py DELETED
@@ -1,162 +0,0 @@
1
- """
2
- Preprocess a raw json dataset into features files for use in data_loader.py
3
-
4
- Input: json file that has the form
5
- [{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...]
6
- example element in this list would look like
7
- {'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895}
8
-
9
- This script reads this json, does some basic preprocessing on the captions
10
- (e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays
11
-
12
- Output: two folders of features
13
- """
14
-
15
- from __future__ import absolute_import
16
- from __future__ import division
17
- from __future__ import print_function
18
-
19
- import os
20
- import json
21
- import argparse
22
- from random import shuffle, seed
23
- import string
24
- # non-standard dependencies:
25
- import h5py
26
- from six.moves import cPickle
27
- import numpy as np
28
- import torch
29
- import torchvision.models as models
30
- import skimage.io
31
-
32
- from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
33
- from PIL import Image
34
- from torch import nn
35
-
36
- # preprocess = Compose([
37
- # Resize((448, 448), interpolation=Image.BICUBIC),
38
- # CenterCrop((448, 448)),
39
- # ToTensor()
40
- # ])
41
-
42
-
43
- # from clip.clip import load
44
- # from timm.models.vision_transformer import resize_pos_embed
45
- # import timm
46
-
47
- # from captioning.utils.resnet_utils import myResnet
48
- # import captioning.utils.resnet as resnet
49
-
50
- from captioning.utils.clipscore import CLIPScore
51
-
52
- from tqdm import tqdm
53
-
54
-
55
-
56
- def main(params):
57
-
58
- clipscore_model = CLIPScore()
59
- clipscore_model.to('cuda')
60
-
61
- imgs = json.load(open(params['input_json'], 'r'))
62
- imgs = imgs['images']
63
-
64
- if args.n_jobs > 1:
65
- print('Total imgs:', len(imgs))
66
- print('Using {} jobs'.format(args.n_jobs))
67
- print('job id:', args.job_id)
68
- imgs = imgs[args.job_id::args.n_jobs]
69
-
70
- N = len(imgs)
71
-
72
- seed(123) # make reproducible
73
-
74
- # dir_fc = params['output_dir']+'_clip_'+save_model_type+'_fc'
75
- # dir_att = params['output_dir']+'_clip_'+save_model_type+'_att'
76
-
77
- vis_dir_fc = params['output_dir']+'_clipscore_vis'
78
- if not os.path.isdir(vis_dir_fc):
79
- os.mkdir(vis_dir_fc)
80
-
81
- # text_dir_fc = params['output_dir']+'_clipscore_text'
82
- # if not os.path.isdir(text_dir_fc):
83
- # os.mkdir(text_dir_fc)
84
-
85
- # if not os.path.isdir(dir_att):
86
- # os.mkdir(dir_att)
87
-
88
- for i, img in enumerate(tqdm(imgs)):
89
- # load the image
90
-
91
- img_path = os.path.join(params['images_root'], img['filepath'], img['filename'])
92
- img_feat = clipscore_model.image_extract(img_path)
93
- img_feat = img_feat.view(512)
94
-
95
- # for d in img['sentences']:
96
- # text = d['raw'].strip()
97
- # text_feat = clipscore_model.text_extract(text)
98
-
99
-
100
- # with torch.no_grad():
101
-
102
- # image = preprocess(Image.open(os.path.join(params['images_root'], img['filepath'], img['filename']) ).convert("RGB"))
103
- # image = torch.tensor(np.stack([image])).cuda()
104
- # image -= mean
105
- # image /= std
106
- # if "RN" in params["model_type"]:
107
- # tmp_att, tmp_fc = model.encode_image(image)
108
- # tmp_att = tmp_att[0].permute(1, 2, 0)
109
- # tmp_fc = tmp_fc[0]
110
- # elif params["model_type"] == 'vit_base_patch32_224_in21k':
111
- # x = model(image)
112
- # tmp_fc = x[0, 0, :]
113
- # tmp_att = x[0, 1:, :].reshape( 14, 14, 768 )
114
- # else:
115
- # x = model.visual.conv1(image.half()) # shape = [*, width, grid, grid]
116
- # x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
117
- # x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
118
- # x = torch.cat([model.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
119
- # x = x + model.visual.positional_embedding.to(x.dtype)[:x.shape[1], :]
120
- # x = model.visual.ln_pre(x)
121
-
122
- # x = x.permute(1, 0, 2) # NLD -> LND
123
-
124
- # for layer_idx, layer in enumerate(model.visual.transformer.resblocks):
125
- # x = layer(x)
126
-
127
- # x = x.permute(1, 0, 2)
128
- # tmp_fc = x[0, 0, :]
129
- # tmp_att = x[0, 1:, :].reshape( 14, 14, 768 )
130
-
131
- np.save(os.path.join(vis_dir_fc, str(img['cocoid'])), img_feat.data.cpu().float().numpy())
132
- # np.save(os.path.join(text_dir_fc, str(img['cocoid'])), tmp_fc.data.cpu().float().numpy())
133
-
134
-
135
- # np.savez_compressed(os.path.join(dir_att, str(img['cocoid'])), feat=tmp_att.data.cpu().float().numpy())
136
-
137
- if i % 1000 == 0:
138
- print('processing %d/%d (%.2f%% done)' % (i, N, i*100.0/N))
139
- print('wrote ', vis_dir_fc)
140
-
141
- if __name__ == "__main__":
142
-
143
- parser = argparse.ArgumentParser()
144
-
145
- # input json
146
- # dataset_coco.json
147
- parser.add_argument('--input_json', required=True, help='input json file to process into hdf5')
148
- parser.add_argument('--output_dir', default='data', help='output h5 file')
149
-
150
- # options
151
- parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json')
152
- # parser.add_argument('--att_size', default=14, type=int, help='14x14 or 7x7')
153
- # parser.add_argument('--model_type', default='RN50', type=str, help='RN50, RN101, RN50x4, ViT-B/32, vit_base_patch32_224_in21k')
154
-
155
- parser.add_argument('--n_jobs', default=-1, type=int, help='number of jobs to run in parallel')
156
- parser.add_argument('--job_id', default=0, type=int, help='job id')
157
-
158
- args = parser.parse_args()
159
- params = vars(args) # convert to ordinary dict
160
- print('parsed input parameters:')
161
- print(json.dumps(params, indent = 2))
162
- main(params)
 
scripts/copy_model.sh DELETED
@@ -1,9 +0,0 @@
1
- #!/bin/sh
2
-
3
- if [ ! -d log_$2 ]; then
4
- cp -r log_$1 log_$2
5
- cd log_$2
6
- mv infos_$1-best.pkl infos_$2-best.pkl
7
- mv infos_$1.pkl infos_$2.pkl
8
- cd ../
9
- fi
 
scripts/dump_to_h5df.py DELETED
@@ -1,56 +0,0 @@
1
- import argparse
2
- import h5py
3
- import os
4
- import numpy as np
5
- import json
6
- from tqdm import tqdm
7
-
8
-
9
- def main(params):
10
-
11
- imgs = json.load(open(params['input_json'], 'r'))
12
- imgs = imgs['images']
13
- N = len(imgs)
14
-
15
- if params['fc_input_dir'] is not None:
16
- print('processing fc')
17
- with h5py.File(params['fc_output']) as file_fc:
18
- for i, img in enumerate(tqdm(imgs)):
19
- npy_fc_path = os.path.join(
20
- params['fc_input_dir'],
21
- str(img['cocoid']) + '.npy')
22
-
23
- d_set_fc = file_fc.create_dataset(
24
- str(img['cocoid']), data=np.load(npy_fc_path))
25
- file_fc.close()
26
-
27
- if params['att_input_dir'] is not None:
28
- print('processing att')
29
- with h5py.File(params['att_output']) as file_att:
30
- for i, img in enumerate(tqdm(imgs)):
31
- npy_att_path = os.path.join(
32
- params['att_input_dir'],
33
- str(img['cocoid']) + '.npz')
34
-
35
- d_set_att = file_att.create_dataset(
36
- str(img['cocoid']),
37
- data=np.load(npy_att_path)['feat'])
38
- file_att.close()
39
-
40
-
41
- if __name__ == "__main__":
42
-
43
- parser = argparse.ArgumentParser()
44
-
45
- parser.add_argument('--input_json', required=True, help='input json file to process into hdf5')
46
- parser.add_argument('--fc_output', default='data', help='output h5 filename for fc')
47
- parser.add_argument('--att_output', default='data', help='output h5 file for att')
48
- parser.add_argument('--fc_input_dir', default=None, help='input directory for numpy fc files')
49
- parser.add_argument('--att_input_dir', default=None, help='input directory for numpy att files')
50
-
51
- args = parser.parse_args()
52
- params = vars(args) # convert to ordinary dict
53
- print('parsed input parameters:')
54
- print(json.dumps(params, indent=2))
55
-
56
- main(params)
 
scripts/dump_to_lmdb.py DELETED
@@ -1,241 +0,0 @@
1
- # copy from https://github.com/Lyken17/Efficient-PyTorch/tools
2
-
3
- from __future__ import absolute_import
4
- from __future__ import division
5
- from __future__ import print_function
6
-
7
- import os
8
- import os.path as osp
9
- import os, sys
10
- import os.path as osp
11
- from PIL import Image
12
- import six
13
- import string
14
-
15
- from lmdbdict import lmdbdict
16
- from lmdbdict.methods import DUMPS_FUNC, LOADS_FUNC
17
- import pickle
18
- import tqdm
19
- import numpy as np
20
- import argparse
21
- import json
22
-
23
- import torch
24
- import torch.utils.data as data
25
- from torch.utils.data import DataLoader
26
-
27
- import csv
28
- csv.field_size_limit(sys.maxsize)
29
- FIELDNAMES = ['image_id', 'status']
30
-
31
- class FolderLMDB(data.Dataset):
32
- def __init__(self, db_path, fn_list=None):
33
- self.db_path = db_path
34
- self.lmdb = lmdbdict(db_path, unsafe=True)
35
- self.lmdb._key_dumps = DUMPS_FUNC['ascii']
36
- self.lmdb._value_loads = LOADS_FUNC['identity']
37
- if fn_list is not None:
38
- self.length = len(fn_list)
39
- self.keys = fn_list
40
- else:
41
- raise Error
42
-
43
- def __getitem__(self, index):
44
- byteflow = self.lmdb[self.keys[index]]
45
-
46
- # load image
47
- imgbuf = byteflow
48
- buf = six.BytesIO()
49
- buf.write(imgbuf)
50
- buf.seek(0)
51
- try:
52
- if args.extension == '.npz':
53
- feat = np.load(buf)['feat']
54
- else:
55
- feat = np.load(buf)
56
- except Exception as e:
57
- print(self.keys[index], e)
58
- return None
59
-
60
- return feat
61
-
62
- def __len__(self):
63
- return self.length
64
-
65
- def __repr__(self):
66
- return self.__class__.__name__ + ' (' + self.db_path + ')'
67
-
68
-
69
- def make_dataset(dir, extension):
70
- images = []
71
- dir = os.path.expanduser(dir)
72
- for root, _, fnames in sorted(os.walk(dir)):
73
- for fname in sorted(fnames):
74
- if has_file_allowed_extension(fname, [extension]):
75
- path = os.path.join(root, fname)
76
- images.append(path)
77
-
78
- return images
79
-
80
-
81
- def raw_reader(path):
82
- with open(path, 'rb') as f:
83
- bin_data = f.read()
84
- return bin_data
85
-
86
-
87
- def raw_npz_reader(path):
88
- with open(path, 'rb') as f:
89
- bin_data = f.read()
90
- try:
91
- npz_data = np.load(six.BytesIO(bin_data))['feat']
92
- except Exception as e:
93
- print(path)
94
- npz_data = None
95
- print(e)
96
- return bin_data, npz_data
97
-
98
-
99
- def raw_npy_reader(path):
100
- with open(path, 'rb') as f:
101
- bin_data = f.read()
102
- try:
103
- npy_data = np.load(six.BytesIO(bin_data))
104
- except Exception as e:
105
- print(path)
106
- npy_data = None
107
- print(e)
108
- return bin_data, npy_data
109
-
110
-
111
- class Folder(data.Dataset):
112
-
113
- def __init__(self, root, loader, extension, fn_list=None):
114
- super(Folder, self).__init__()
115
- self.root = root
116
- if fn_list:
117
- samples = [os.path.join(root, str(_)+extension) for _ in fn_list]
118
- else:
119
- samples = make_dataset(self.root, extension)
120
-
121
- self.loader = loader
122
- self.extension = extension
123
- self.samples = samples
124
-
125
- def __getitem__(self, index):
126
- """
127
- Args:
128
- index (int): Index
129
- Returns:
130
- tuple: (sample, target) where target is class_index of the target class.
131
- """
132
- path = self.samples[index]
133
- sample = self.loader(path)
134
-
135
- return (path.split('/')[-1].split('.')[0],) + sample
136
-
137
- def __len__(self):
138
- return len(self.samples)
139
-
140
-
141
- def folder2lmdb(dpath, fn_list, write_frequency=5000):
142
- directory = osp.expanduser(osp.join(dpath))
143
- print("Loading dataset from %s" % directory)
144
- if args.extension == '.npz':
145
- dataset = Folder(directory, loader=raw_npz_reader, extension='.npz',
146
- fn_list=fn_list)
147
- else:
148
- dataset = Folder(directory, loader=raw_npy_reader, extension='.npy',
149
- fn_list=fn_list)
150
- data_loader = DataLoader(dataset, num_workers=16, collate_fn=lambda x: x)
151
-
152
- # lmdb_path = osp.join(dpath, "%s.lmdb" % (directory.split('/')[-1]))
153
- lmdb_path = osp.join("%s.lmdb" % (directory))
154
- isdir = os.path.isdir(lmdb_path)
155
-
156
- print("Generate LMDB to %s" % lmdb_path)
157
- db = lmdbdict(lmdb_path, mode='w', key_method='ascii', value_method='identity')
158
-
159
- tsvfile = open(args.output_file, 'a')
160
- writer = csv.DictWriter(tsvfile, delimiter='\t', fieldnames=FIELDNAMES)
161
- names = []
162
- all_keys = []
163
- for idx, data in enumerate(tqdm.tqdm(data_loader)):
164
- # print(type(data), data)
165
- name, byte, npz = data[0]
166
- if npz is not None:
167
- db[name] = byte
168
- all_keys.append(name)
169
- names.append({'image_id': name, 'status': str(npz is not None)})
170
- if idx % write_frequency == 0:
171
- print("[%d/%d]" % (idx, len(data_loader)))
172
- print('writing')
173
- db.flush()
174
- # write in tsv
175
- for name in names:
176
- writer.writerow(name)
177
- names = []
178
- tsvfile.flush()
179
- print('writing finished')
180
- # write all keys
181
- # txn.put("keys".encode(), pickle.dumps(all_keys))
182
- # # finish iterating through dataset
183
- # txn.commit()
184
- for name in names:
185
- writer.writerow(name)
186
- tsvfile.flush()
187
- tsvfile.close()
188
-
189
- print("Flushing database ...")
190
- db.flush()
191
- del db
192
-
193
- def parse_args():
194
- """
195
- Parse input arguments
196
- """
197
- parser = argparse.ArgumentParser(description='Generate bbox output from a Fast R-CNN network')
198
- # parser.add_argument('--json)
199
- parser.add_argument('--input_json', default='./data/dataset_coco.json', type=str)
200
- parser.add_argument('--output_file', default='.dump_cache.tsv', type=str)
201
- parser.add_argument('--folder', default='./data/cocobu_att', type=str)
202
- parser.add_argument('--extension', default='.npz', type=str)
203
-
204
- args = parser.parse_args()
205
- return args
206
-
207
- if __name__ == "__main__":
208
- global args
209
- args = parse_args()
210
-
211
- args.output_file += args.folder.split('/')[-1]
212
- if args.folder.find('/') > 0:
213
- args.output_file = args.folder[:args.folder.rfind('/')+1]+args.output_file
214
- print(args.output_file)
215
-
216
- img_list = json.load(open(args.input_json, 'r'))['images']
217
- fn_list = [str(_['cocoid']) for _ in img_list]
218
- found_ids = set()
219
- try:
220
- with open(args.output_file, 'r') as tsvfile:
221
- reader = csv.DictReader(tsvfile, delimiter='\t', fieldnames=FIELDNAMES)
222
- for item in reader:
223
- if item['status'] == 'True':
224
- found_ids.add(item['image_id'])
225
- except:
226
- pass
227
- fn_list = [_ for _ in fn_list if _ not in found_ids]
228
- folder2lmdb(args.folder, fn_list)
229
-
230
- # Test existing.
231
- found_ids = set()
232
- with open(args.output_file, 'r') as tsvfile:
233
- reader = csv.DictReader(tsvfile, delimiter='\t', fieldnames=FIELDNAMES)
234
- for item in reader:
235
- if item['status'] == 'True':
236
- found_ids.add(item['image_id'])
237
-
238
- folder_dataset = FolderLMDB(args.folder+'.lmdb', list(found_ids))
239
- data_loader = DataLoader(folder_dataset, num_workers=16, collate_fn=lambda x: x)
240
- for data in tqdm.tqdm(data_loader):
241
- assert data[0] is not None
 
scripts/make_bu_data.py DELETED
@@ -1,52 +0,0 @@
1
- from __future__ import absolute_import
2
- from __future__ import division
3
- from __future__ import print_function
4
-
5
- import os
6
- import base64
7
- import numpy as np
8
- import csv
9
- import sys
10
- import zlib
11
- import time
12
- import mmap
13
- import argparse
14
-
15
- parser = argparse.ArgumentParser()
16
-
17
- # output_dir
18
- parser.add_argument('--downloaded_feats', default='data/bu_data', help='downloaded feature directory')
19
- parser.add_argument('--output_dir', default='data/cocobu', help='output feature files')
20
-
21
- args = parser.parse_args()
22
-
23
- csv.field_size_limit(sys.maxsize)
24
-
25
-
26
- FIELDNAMES = ['image_id', 'image_w','image_h','num_boxes', 'boxes', 'features']
27
- infiles = ['trainval/karpathy_test_resnet101_faster_rcnn_genome.tsv',
28
- 'trainval/karpathy_val_resnet101_faster_rcnn_genome.tsv',\
29
- 'trainval/karpathy_train_resnet101_faster_rcnn_genome.tsv.0', \
30
- 'trainval/karpathy_train_resnet101_faster_rcnn_genome.tsv.1']
31
-
32
- os.makedirs(args.output_dir+'_att')
33
- os.makedirs(args.output_dir+'_fc')
34
- os.makedirs(args.output_dir+'_box')
35
-
36
- for infile in infiles:
37
- print('Reading ' + infile)
38
- with open(os.path.join(args.downloaded_feats, infile), "r") as tsv_in_file:
39
- reader = csv.DictReader(tsv_in_file, delimiter='\t', fieldnames = FIELDNAMES)
40
- for item in reader:
41
- item['image_id'] = int(item['image_id'])
42
- item['num_boxes'] = int(item['num_boxes'])
43
- for field in ['boxes', 'features']:
44
- item[field] = np.frombuffer(base64.decodestring(item[field].encode('ascii')),
45
- dtype=np.float32).reshape((item['num_boxes'],-1))
46
- np.savez_compressed(os.path.join(args.output_dir+'_att', str(item['image_id'])), feat=item['features'])
47
- np.save(os.path.join(args.output_dir+'_fc', str(item['image_id'])), item['features'].mean(0))
48
- np.save(os.path.join(args.output_dir+'_box', str(item['image_id'])), item['boxes'])
49
-
50
-
51
-
52
-
 
scripts/prepro_feats.py DELETED
@@ -1,103 +0,0 @@
1
- """
2
- Preprocess a raw json dataset into features files for use in data_loader.py
3
-
4
- Input: json file that has the form
5
- [{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...]
6
- example element in this list would look like
7
- {'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895}
8
-
9
- This script reads this json, does some basic preprocessing on the captions
10
- (e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays
11
-
12
- Output: two folders of features
13
- """
14
-
15
- from __future__ import absolute_import
16
- from __future__ import division
17
- from __future__ import print_function
18
-
19
- import os
20
- import json
21
- import argparse
22
- from random import shuffle, seed
23
- import string
24
- # non-standard dependencies:
25
- import h5py
26
- from six.moves import cPickle
27
- import numpy as np
28
- import torch
29
- import torchvision.models as models
30
- import skimage.io
31
-
32
- from torchvision import transforms as trn
33
- preprocess = trn.Compose([
34
- #trn.ToTensor(),
35
- trn.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
36
- ])
37
-
38
- from captioning.utils.resnet_utils import myResnet
39
- import captioning.utils.resnet as resnet
40
-
41
-
42
- def main(params):
43
- net = getattr(resnet, params['model'])()
44
- net.load_state_dict(torch.load(os.path.join(params['model_root'],params['model']+'.pth')))
45
- my_resnet = myResnet(net)
46
- my_resnet.cuda()
47
- my_resnet.eval()
48
-
49
- imgs = json.load(open(params['input_json'], 'r'))
50
- imgs = imgs['images']
51
- N = len(imgs)
52
-
53
- seed(123) # make reproducible
54
-
55
- dir_fc = params['output_dir']+'_fc'
56
- dir_att = params['output_dir']+'_att'
57
- if not os.path.isdir(dir_fc):
58
- os.mkdir(dir_fc)
59
- if not os.path.isdir(dir_att):
60
- os.mkdir(dir_att)
61
-
62
- for i,img in enumerate(imgs):
63
- # load the image
64
- I = skimage.io.imread(os.path.join(params['images_root'], img['filepath'], img['filename']))
65
- # handle grayscale input images
66
- if len(I.shape) == 2:
67
- I = I[:,:,np.newaxis]
68
- I = np.concatenate((I,I,I), axis=2)
69
-
70
- I = I.astype('float32')/255.0
71
- I = torch.from_numpy(I.transpose([2,0,1])).cuda()
72
- I = preprocess(I)
73
- with torch.no_grad():
74
- tmp_fc, tmp_att = my_resnet(I, params['att_size'])
75
- # write to pkl
76
- # print(dir_fc, str(img['cocoid']), tmp_fc.shape, tmp_att.shape, dir_att)
77
- # exit()
78
- np.save(os.path.join(dir_fc, str(img['cocoid'])), tmp_fc.data.cpu().float().numpy())
79
- np.savez_compressed(os.path.join(dir_att, str(img['cocoid'])), feat=tmp_att.data.cpu().float().numpy())
80
-
81
- if i % 1000 == 0:
82
- print('processing %d/%d (%.2f%% done)' % (i, N, i*100.0/N))
83
- print('wrote ', params['output_dir'])
84
-
85
- if __name__ == "__main__":
86
-
87
- parser = argparse.ArgumentParser()
88
-
89
- # input json
90
- parser.add_argument('--input_json', required=True, help='input json file to process into hdf5')
91
- parser.add_argument('--output_dir', default='data', help='output h5 file')
92
-
93
- # options
94
- parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json')
95
- parser.add_argument('--att_size', default=14, type=int, help='14x14 or 7x7')
96
- parser.add_argument('--model', default='resnet101', type=str, help='resnet101, resnet152')
97
- parser.add_argument('--model_root', default='./data/imagenet_weights', type=str, help='model root')
98
-
99
- args = parser.parse_args()
100
- params = vars(args) # convert to ordinary dict
101
- print('parsed input parameters:')
102
- print(json.dumps(params, indent = 2))
103
- main(params)
 
scripts/prepro_labels.py DELETED
@@ -1,206 +0,0 @@
1
- """
2
- Preprocess a raw json dataset into hdf5/json files for use in data_loader.py
3
-
4
- Input: json file that has the form
5
- [{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...]
6
- example element in this list would look like
7
- {'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895}
8
-
9
- This script reads this json, does some basic preprocessing on the captions
10
- (e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays
11
-
12
- Output: a json file and an hdf5 file
13
- The hdf5 file contains several fields:
14
- /labels is (M,max_length) uint32 array of encoded labels, zero padded
15
- /label_start_ix and /label_end_ix are (N,) uint32 arrays of pointers to the
16
- first and last indices (in range 1..M) of labels for each image
17
- /label_length stores the length of the sequence for each of the M sequences
18
-
19
- The json file has a dict that contains:
20
- - an 'ix_to_word' field storing the vocab in form {ix:'word'}, where ix is 1-indexed
21
- - an 'images' field that is a list holding auxiliary information for each image,
22
- such as in particular the 'split' it was assigned to.
23
- """
24
-
25
- from __future__ import absolute_import
26
- from __future__ import division
27
- from __future__ import print_function
28
-
29
- import os
30
- import json
31
- import argparse
32
- from random import shuffle, seed
33
- import string
34
- # non-standard dependencies:
35
- import h5py
36
- import numpy as np
37
- import torch
38
- import torchvision.models as models
39
- import skimage.io
40
- from PIL import Image
41
-
42
-
43
- def build_vocab(imgs, params):
44
- count_thr = params['word_count_threshold']
45
-
46
- # count up the number of words
47
- counts = {}
48
- for img in imgs:
49
- for sent in img['sentences']:
50
- for w in sent['tokens']:
51
- counts[w] = counts.get(w, 0) + 1
52
- cw = sorted([(count,w) for w,count in counts.items()], reverse=True)
53
- print('top words and their counts:')
54
- print('\n'.join(map(str,cw[:20])))
55
-
56
- # print some stats
57
- total_words = sum(counts.values())
58
- print('total words:', total_words)
59
- bad_words = [w for w,n in counts.items() if n <= count_thr]
60
- vocab = [w for w,n in counts.items() if n > count_thr]
61
- bad_count = sum(counts[w] for w in bad_words)
62
- print('number of bad words: %d/%d = %.2f%%' % (len(bad_words), len(counts), len(bad_words)*100.0/len(counts)))
63
- print('number of words in vocab would be %d' % (len(vocab), ))
64
- print('number of UNKs: %d/%d = %.2f%%' % (bad_count, total_words, bad_count*100.0/total_words))
65
-
66
- # lets look at the distribution of lengths as well
67
- sent_lengths = {}
68
- for img in imgs:
69
- for sent in img['sentences']:
70
- txt = sent['tokens']
71
- nw = len(txt)
72
- sent_lengths[nw] = sent_lengths.get(nw, 0) + 1
73
- max_len = max(sent_lengths.keys())
74
- print('max length sentence in raw data: ', max_len)
75
- print('sentence length distribution (count, number of words):')
76
- sum_len = sum(sent_lengths.values())
77
- for i in range(max_len+1):
78
- print('%2d: %10d %f%%' % (i, sent_lengths.get(i,0), sent_lengths.get(i,0)*100.0/sum_len))
79
-
80
- # lets now produce the final annotations
81
- if bad_count > 0:
82
- # additional special UNK token we will use below to map infrequent words to
83
- print('inserting the special UNK token')
84
- vocab.append('UNK')
85
-
86
- for img in imgs:
87
- img['final_captions'] = []
88
- for sent in img['sentences']:
89
- txt = sent['tokens']
90
- caption = [w if counts.get(w,0) > count_thr else 'UNK' for w in txt]
91
- img['final_captions'].append(caption)
92
-
93
- return vocab
94
-
95
-
96
- def encode_captions(imgs, params, wtoi):
97
- """
98
- encode all captions into one large array, which will be 1-indexed.
99
- also produces label_start_ix and label_end_ix which store 1-indexed
100
- and inclusive (Lua-style) pointers to the first and last caption for
101
- each image in the dataset.
102
- """
103
-
104
- max_length = params['max_length']
105
- N = len(imgs)
106
- M = sum(len(img['final_captions']) for img in imgs) # total number of captions
107
-
108
- label_arrays = []
109
- label_start_ix = np.zeros(N, dtype='uint32') # note: these will be one-indexed
110
- label_end_ix = np.zeros(N, dtype='uint32')
111
- label_length = np.zeros(M, dtype='uint32')
112
- caption_counter = 0
113
- counter = 1
114
- for i,img in enumerate(imgs):
115
- n = len(img['final_captions'])
116
- assert n > 0, 'error: some image has no captions'
117
-
118
- Li = np.zeros((n, max_length), dtype='uint32')
119
- for j,s in enumerate(img['final_captions']):
120
- label_length[caption_counter] = min(max_length, len(s)) # record the length of this sequence
121
- caption_counter += 1
122
- for k,w in enumerate(s):
123
- if k < max_length:
124
- Li[j,k] = wtoi[w]
125
-
126
- # note: word indices are 1-indexed, and captions are padded with zeros
127
- label_arrays.append(Li)
128
- label_start_ix[i] = counter
129
- label_end_ix[i] = counter + n - 1
130
-
131
- counter += n
132
-
133
- L = np.concatenate(label_arrays, axis=0) # put all the labels together
134
- assert L.shape[0] == M, 'lengths don\'t match? that\'s weird'
135
- assert np.all(label_length > 0), 'error: some caption had no words?'
136
-
137
- print('encoded captions to array of size ', L.shape)
138
- return L, label_start_ix, label_end_ix, label_length
139
-
140
-
141
- def main(params):
142
-
143
- imgs = json.load(open(params['input_json'], 'r'))
144
- imgs = imgs['images']
145
-
146
- seed(123) # make reproducible
147
-
148
- # create the vocab
149
- vocab = build_vocab(imgs, params)
150
- itow = {i+1:w for i,w in enumerate(vocab)} # a 1-indexed vocab translation table
151
- wtoi = {w:i+1 for i,w in enumerate(vocab)} # inverse table
152
-
153
- # encode captions in large arrays, ready to ship to hdf5 file
154
- L, label_start_ix, label_end_ix, label_length = encode_captions(imgs, params, wtoi)
155
-
156
- # create output h5 file
157
- N = len(imgs)
158
- f_lb = h5py.File(params['output_h5']+'_label.h5', "w")
159
- f_lb.create_dataset("labels", dtype='uint32', data=L)
160
- f_lb.create_dataset("label_start_ix", dtype='uint32', data=label_start_ix)
161
- f_lb.create_dataset("label_end_ix", dtype='uint32', data=label_end_ix)
162
- f_lb.create_dataset("label_length", dtype='uint32', data=label_length)
163
- f_lb.close()
164
-
165
- # create output json file
166
- out = {}
167
- out['ix_to_word'] = itow # encode the (1-indexed) vocab
168
- out['images'] = []
169
- for i,img in enumerate(imgs):
170
-
171
- jimg = {}
172
- jimg['split'] = img['split']
173
- if 'filename' in img: jimg['file_path'] = os.path.join(img.get('filepath', ''), img['filename']) # copy it over, might need
174
- if 'cocoid' in img:
175
- jimg['id'] = img['cocoid'] # copy over & maintain an id, if present (e.g. coco ids, useful)
176
- elif 'imgid' in img:
177
- jimg['id'] = img['imgid']
178
-
179
- if params['images_root'] != '':
180
- with Image.open(os.path.join(params['images_root'], img['filepath'], img['filename'])) as _img:
181
- jimg['width'], jimg['height'] = _img.size
182
-
183
- out['images'].append(jimg)
184
-
185
- json.dump(out, open(params['output_json'], 'w'))
186
- print('wrote ', params['output_json'])
187
-
188
- if __name__ == "__main__":
189
-
190
- parser = argparse.ArgumentParser()
191
-
192
- # input json
193
- parser.add_argument('--input_json', required=True, help='input json file to process into hdf5')
194
- parser.add_argument('--output_json', default='data.json', help='output json file')
195
- parser.add_argument('--output_h5', default='data', help='output h5 file')
196
- parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json')
197
-
198
- # options
199
- parser.add_argument('--max_length', default=16, type=int, help='max length of a caption, in number of words. captions longer than this get clipped.')
200
- parser.add_argument('--word_count_threshold', default=5, type=int, help='only words that occur more than this number of times will be put in vocab')
201
-
202
- args = parser.parse_args()
203
- params = vars(args) # convert to ordinary dict
204
- print('parsed input parameters:')
205
- print(json.dumps(params, indent = 2))
206
- main(params)
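The docstring above fixes the on-disk layout: /labels is zero padded, and /label_start_ix, /label_end_ix are 1-indexed, inclusive pointers into it. A minimal sketch of reading one image's captions back; the file names data/cocotalk_label.h5 and data/cocotalk.json are assumptions that follow the --output_h5 / --output_json conventions shown above.

import json
import h5py

# Hedged sketch: decode the stored captions for image i from the files written above.
lbl = h5py.File('data/cocotalk_label.h5', 'r')       # <output_h5>_label.h5
info = json.load(open('data/cocotalk.json'))         # <output_json>
ix_to_word = info['ix_to_word']                      # 1-indexed vocab; keys are strings after the json round-trip

i = 0                                                # any image index
start = int(lbl['label_start_ix'][i]) - 1            # pointers are 1-indexed and inclusive
end = int(lbl['label_end_ix'][i])
for row in lbl['labels'][start:end]:
    print(' '.join(ix_to_word[str(ix)] for ix in row if ix > 0))   # zeros are padding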
scripts/prepro_ngrams.py DELETED
@@ -1,94 +0,0 @@
1
- """
2
- Precompute ngram counts of captions, to accelerate cider computation during training time.
3
- """
4
-
5
- import os
6
- import json
7
- import argparse
8
- from six.moves import cPickle
9
- import captioning.utils.misc as utils
10
- from collections import defaultdict
11
-
12
- import sys
13
- sys.path.append("cider")
14
- from pyciderevalcap.ciderD.ciderD_scorer import CiderScorer
15
-
16
-
17
- def get_doc_freq(refs, params):
18
- tmp = CiderScorer(df_mode="corpus")
19
- for ref in refs:
20
- tmp.cook_append(None, ref)
21
- tmp.compute_doc_freq()
22
- return tmp.document_frequency, len(tmp.crefs)
23
-
24
-
25
- def build_dict(imgs, wtoi, params):
26
- wtoi['<eos>'] = 0
27
-
28
- count_imgs = 0
29
-
30
- refs_words = []
31
- refs_idxs = []
32
- for img in imgs:
33
- if (params['split'] == img['split']) or \
34
- (params['split'] == 'train' and img['split'] == 'restval') or \
35
- (params['split'] == 'all'):
36
- #(params['split'] == 'val' and img['split'] == 'restval') or \
37
- ref_words = []
38
- ref_idxs = []
39
- for sent in img['sentences']:
40
- if hasattr(params, 'bpe'):
41
- sent['tokens'] = params.bpe.segment(' '.join(sent['tokens'])).strip().split(' ')
42
- tmp_tokens = sent['tokens'] + ['<eos>']
43
- tmp_tokens = [_ if _ in wtoi else 'UNK' for _ in tmp_tokens]
44
- ref_words.append(' '.join(tmp_tokens))
45
- ref_idxs.append(' '.join([str(wtoi[_]) for _ in tmp_tokens]))
46
- refs_words.append(ref_words)
47
- refs_idxs.append(ref_idxs)
48
- count_imgs += 1
49
- print('total imgs:', count_imgs)
50
-
51
- ngram_words, count_refs = get_doc_freq(refs_words, params)
52
- ngram_idxs, count_refs = get_doc_freq(refs_idxs, params)
53
- print('count_refs:', count_refs)
54
- return ngram_words, ngram_idxs, count_refs
55
-
56
- def main(params):
57
-
58
- imgs = json.load(open(params['input_json'], 'r'))
59
- dict_json = json.load(open(params['dict_json'], 'r'))
60
- itow = dict_json['ix_to_word']
61
- wtoi = {w:i for i,w in itow.items()}
62
-
63
- # Load bpe
64
- if 'bpe' in dict_json:
65
- import tempfile
66
- import codecs
67
- codes_f = tempfile.NamedTemporaryFile(delete=False)
68
- codes_f.close()
69
- with open(codes_f.name, 'w') as f:
70
- f.write(dict_json['bpe'])
71
- with codecs.open(codes_f.name, encoding='UTF-8') as codes:
72
- bpe = apply_bpe.BPE(codes)
73
- params.bpe = bpe
74
-
75
- imgs = imgs['images']
76
-
77
- ngram_words, ngram_idxs, ref_len = build_dict(imgs, wtoi, params)
78
-
79
- utils.pickle_dump({'document_frequency': ngram_words, 'ref_len': ref_len}, open(params['output_pkl']+'-words.p','wb'))
80
- utils.pickle_dump({'document_frequency': ngram_idxs, 'ref_len': ref_len}, open(params['output_pkl']+'-idxs.p','wb'))
81
-
82
- if __name__ == "__main__":
83
-
84
- parser = argparse.ArgumentParser()
85
-
86
- # input json
87
- parser.add_argument('--input_json', default='data/dataset_coco.json', help='input json file to process into hdf5')
88
- parser.add_argument('--dict_json', default='data/cocotalk.json', help='output json file')
89
- parser.add_argument('--output_pkl', default='data/coco-all', help='output pickle file')
90
- parser.add_argument('--split', default='all', help='test, val, train, all')
91
- args = parser.parse_args()
92
- params = vars(args) # convert to ordinary dict
93
-
94
- main(params)
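Two caveats about the BPE branch above as written: apply_bpe is referenced without an import (something like from subword_nmt import apply_bpe would be needed), and params is the plain dict returned by vars(args), so params.bpe and hasattr(params, 'bpe') cannot work without switching to dictionary access (params['bpe'], 'bpe' in params). Separately, a minimal sketch of consuming the document-frequency pickles the script writes; the default --output_pkl prefix is assumed and the files are treated as ordinary pickles:

from six.moves import cPickle

# Hedged sketch: load the n-gram document frequencies written above
# (default output_pkl prefix 'data/coco-all' assumed).
with open('data/coco-all-words.p', 'rb') as f:
    stats = cPickle.load(f)
document_frequency = stats['document_frequency']     # ngram -> number of references containing it
ref_len = stats['ref_len']                           # reference-set count used to normalize CIDEr
print(len(document_frequency), 'distinct ngrams over', ref_len, 'reference sets')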
scripts/prepro_reference_json.py DELETED
@@ -1,69 +0,0 @@
1
- # coding: utf-8
2
- """
3
- Create a reference json file used for evaluation with `coco-caption` repo.
4
- Used when a reference json is not provided (e.g., flickr30k, or you have your own train/val/test split)
5
- """
6
-
7
- from __future__ import absolute_import
8
- from __future__ import division
9
- from __future__ import print_function
10
-
11
- import os
12
- import json
13
- import argparse
14
- import sys
15
- import hashlib
16
- from random import shuffle, seed
17
-
18
-
19
- def main(params):
20
-
21
- imgs = json.load(open(params['input_json'][0], 'r'))['images']
22
- # tmp = []
23
- # for k in imgs.keys():
24
- # for img in imgs[k]:
25
- # img['filename'] = img['image_id'] # k+'/'+img['image_id']
26
- # img['image_id'] = int(
27
- # int(hashlib.sha256(img['image_id']).hexdigest(), 16) % sys.maxint)
28
- # tmp.append(img)
29
- # imgs = tmp
30
-
31
- # create output json file
32
- out = {'info': {'description': 'This is stable 1.0 version of the 2014 MS COCO dataset.', 'url': 'http://mscoco.org', 'version': '1.0', 'year': 2014, 'contributor': 'Microsoft COCO group', 'date_created': '2015-01-27 09:11:52.357475'}, 'licenses': [{'url': 'http://creativecommons.org/licenses/by-nc-sa/2.0/', 'id': 1, 'name': 'Attribution-NonCommercial-ShareAlike License'}, {'url': 'http://creativecommons.org/licenses/by-nc/2.0/', 'id': 2, 'name': 'Attribution-NonCommercial License'}, {'url': 'http://creativecommons.org/licenses/by-nc-nd/2.0/', 'id': 3, 'name': 'Attribution-NonCommercial-NoDerivs License'}, {'url': 'http://creativecommons.org/licenses/by/2.0/', 'id': 4, 'name': 'Attribution License'}, {'url': 'http://creativecommons.org/licenses/by-sa/2.0/', 'id': 5, 'name': 'Attribution-ShareAlike License'}, {'url': 'http://creativecommons.org/licenses/by-nd/2.0/', 'id': 6, 'name': 'Attribution-NoDerivs License'}, {'url': 'http://flickr.com/commons/usage/', 'id': 7, 'name': 'No known copyright restrictions'}, {'url': 'http://www.usa.gov/copyright.shtml', 'id': 8, 'name': 'United States Government Work'}], 'type': 'captions'}
33
- out.update({'images': [], 'annotations': []})
34
-
35
- cnt = 0
36
- empty_cnt = 0
37
- for i, img in enumerate(imgs):
38
- if img['split'] == 'train':
39
- continue
40
- out['images'].append(
41
- {'id': img.get('cocoid', img['imgid'])})
42
- for j, s in enumerate(img['sentences']):
43
- if len(s) == 0:
44
- continue
45
- s = ' '.join(s['tokens'])
46
- out['annotations'].append(
47
- {'image_id': out['images'][-1]['id'], 'caption': s, 'id': cnt})
48
- cnt += 1
49
-
50
- json.dump(out, open(params['output_json'], 'w'))
51
- print('wrote ', params['output_json'])
52
-
53
-
54
- if __name__ == "__main__":
55
-
56
- parser = argparse.ArgumentParser()
57
-
58
- # input json
59
- parser.add_argument('--input_json', nargs='+', required=True,
60
- help='input json file to process into hdf5')
61
- parser.add_argument('--output_json', default='data.json',
62
- help='output json file')
63
-
64
- args = parser.parse_args()
65
- params = vars(args) # convert to ordinary dict
66
- print('parsed input parameters:')
67
- print(json.dumps(params, indent=2))
68
- main(params)
69
-
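A quick sanity check of the reference file produced above before handing it to coco-caption: every annotation should point at a listed image id, and the 'type' field must be 'captions'. The file name below is the default --output_json.

import json

# Hedged sketch: validate the reference json written above.
ref = json.load(open('data.json'))                   # default --output_json
assert ref['type'] == 'captions'
img_ids = {img['id'] for img in ref['images']}
assert all(ann['image_id'] in img_ids for ann in ref['annotations'])
print(len(ref['images']), 'images,', len(ref['annotations']), 'reference captions')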
scripts_FineCapEval/clip_prepro_feats.py DELETED
@@ -1,163 +0,0 @@
1
- """
2
- Preprocess a raw json dataset into features files for use in data_loader.py
3
-
4
- Input: json file that has the form
5
- [{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...]
6
- example element in this list would look like
7
- {'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895}
8
-
9
- This script reads this json, does some basic preprocessing on the captions
10
- (e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays
11
-
12
- Output: two folders of features
13
- """
14
-
15
- from __future__ import absolute_import
16
- from __future__ import division
17
- from __future__ import print_function
18
-
19
- import os
20
- import json
21
- import argparse
22
- from random import shuffle, seed
23
- import string
24
- # non-standard dependencies:
25
- import h5py
26
- from six.moves import cPickle
27
- import numpy as np
28
- import torch
29
- import torchvision.models as models
30
- import skimage.io
31
-
32
- from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
33
- from PIL import Image
34
- from torch import nn
35
-
36
- preprocess = Compose([
37
- Resize((448, 448), interpolation=Image.BICUBIC),
38
- CenterCrop((448, 448)),
39
- ToTensor()
40
- ])
41
-
42
-
43
- from clip.clip import load
44
- from timm.models.vision_transformer import resize_pos_embed
45
- import timm
46
-
47
- from captioning.utils.resnet_utils import myResnet
48
- import captioning.utils.resnet as resnet
49
-
50
- from tqdm import tqdm
51
-
52
-
53
- def main(params):
54
- if params["model_type"] != 'vit_base_patch32_224_in21k':
55
- model, transform = load(params["model_type"], jit=False)
56
- else:
57
- model = timm.create_model(params["model_type"], pretrained=True)
58
- model = model.cuda()
59
-
60
- if params["model_type"] != 'vit_base_patch32_224_in21k':
61
- save_model_type = params["model_type"].split("-")[0]
62
- mean = torch.Tensor([0.48145466, 0.4578275, 0.40821073]).to("cuda").reshape(3, 1, 1)
63
- std = torch.Tensor([0.26862954, 0.26130258, 0.27577711]).to("cuda").reshape(3, 1, 1)
64
-
65
- if "RN" in params["model_type"]:
66
- num_patches = 196 #600 * 1000 // 32 // 32
67
- pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, model.visual.attnpool.positional_embedding.shape[-1], device='cuda'),)
68
- pos_embed.weight = resize_pos_embed(model.visual.attnpool.positional_embedding.unsqueeze(0), pos_embed)
69
- model.visual.attnpool.positional_embedding = pos_embed
70
-
71
- else:
72
- save_model_type = 'vit_base'
73
- mean = torch.Tensor([0.5, 0.5, 0.5]).to("cuda").reshape(3, 1, 1)
74
- std = torch.Tensor([0.5, 0.5, 0.5]).to("cuda").reshape(3, 1, 1)
75
-
76
- num_patches = 196 #600 * 1000 // 32 // 32
77
- pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, 768, device='cuda'),)
78
- pos_embed.weight = resize_pos_embed(model.pos_embed, pos_embed)
79
- model.pos_embed = pos_embed
80
-
81
- if params["model_type"] == "ViT-B/32":
82
- num_patches = 196 #600 * 1000 // 32 // 32
83
- pos_embed = nn.Parameter(torch.zeros(num_patches + 1, 768, device='cuda'),)
84
- pos_embed.weight = resize_pos_embed(model.visual.positional_embedding.unsqueeze(0), pos_embed.unsqueeze(0))
85
- model.visual.positional_embedding = pos_embed
86
- imgs = json.load(open(params['input_json'], 'r'))
87
- imgs = imgs['images']
88
- N = len(imgs)
89
-
90
- seed(123) # make reproducible
91
-
92
- dir_fc = params['output_dir']+'_clip_'+save_model_type+'_fc'
93
- dir_att = params['output_dir']+'_clip_'+save_model_type+'_att'
94
- if not os.path.isdir(dir_fc):
95
- os.mkdir(dir_fc)
96
- if not os.path.isdir(dir_att):
97
- os.mkdir(dir_att)
98
-
99
- for i, img in enumerate(tqdm(imgs)):
100
- with torch.no_grad():
101
-
102
- # img_path = os.path.join(params['images_root'], img['filepath'], img['filename'])
103
- # img_path = os.path.join(params['images_root'], img['file_name'])
104
-
105
- img_path = os.path.join(params['images_root'], img['file_path'])
106
-
107
- image = preprocess(Image.open( img_path ).convert("RGB"))
108
- image = torch.tensor(np.stack([image])).cuda()
109
- image -= mean
110
- image /= std
111
- if "RN" in params["model_type"]:
112
- tmp_att, tmp_fc = model.encode_image(image)
113
- tmp_att = tmp_att[0].permute(1, 2, 0)
114
- tmp_fc = tmp_fc[0]
115
- elif params["model_type"] == 'vit_base_patch32_224_in21k':
116
- x = model(image)
117
- tmp_fc = x[0, 0, :]
118
- tmp_att = x[0, 1:, :].reshape( 14, 14, 768 )
119
- else:
120
- x = model.visual.conv1(image.half()) # shape = [*, width, grid, grid]
121
- x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
122
- x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
123
- x = torch.cat([model.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
124
- x = x + model.visual.positional_embedding.to(x.dtype)[:x.shape[1], :]
125
- x = model.visual.ln_pre(x)
126
-
127
- x = x.permute(1, 0, 2) # NLD -> LND
128
-
129
- for layer_idx, layer in enumerate(model.visual.transformer.resblocks):
130
- x = layer(x)
131
-
132
- x = x.permute(1, 0, 2)
133
- tmp_fc = x[0, 0, :]
134
- tmp_att = x[0, 1:, :].reshape( 14, 14, 768 )
135
-
136
- # np.save(os.path.join(dir_fc, str(img['cocoid'])), tmp_fc.data.cpu().float().numpy())
137
- # np.savez_compressed(os.path.join(dir_att, str(img['cocoid'])), feat=tmp_att.data.cpu().float().numpy())
138
- np.save(os.path.join(dir_fc, str(img['id'])), tmp_fc.data.cpu().float().numpy())
139
- np.savez_compressed(os.path.join(dir_att, str(img['id'])), feat=tmp_att.data.cpu().float().numpy())
140
-
141
-
142
- # if i % 1000 == 0:
143
- # print('processing %d/%d (%.2f%% done)' % (i, N, i*100.0/N))
144
- print('wrote ', dir_fc, dir_att)
145
-
146
- if __name__ == "__main__":
147
-
148
- parser = argparse.ArgumentParser()
149
-
150
- # input json
151
- parser.add_argument('--input_json', required=True, help='input json file to process into hdf5')
152
- parser.add_argument('--output_dir', default='data', help='output h5 file')
153
-
154
- # options
155
- parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json')
156
- parser.add_argument('--att_size', default=14, type=int, help='14x14 or 7x7')
157
- parser.add_argument('--model_type', default='RN50', type=str, help='RN50, RN101, RN50x4, ViT-B/32, vit_base_patch32_224_in21k')
158
-
159
- args = parser.parse_args()
160
- params = vars(args) # convert to ordinary dict
161
- print('parsed input parameters:')
162
- print(json.dumps(params, indent = 2))
163
- main(params)
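The script above writes one pooled vector per image id under <output_dir>_clip_<model>_fc and one spatial grid under <output_dir>_clip_<model>_att. A minimal sketch of reading them back, where output_dir 'data', model_type 'RN50', and the image id are all assumptions:

import numpy as np

# Hedged sketch: load the per-image features saved by the loop above.
img_id = 391895                                              # any id present in the input json
fc = np.load('data_clip_RN50_fc/%d.npy' % img_id)            # pooled feature vector
att = np.load('data_clip_RN50_att/%d.npz' % img_id)['feat']  # spatial grid, e.g. 14x14xC
print(fc.shape, att.shape)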
scripts_FineCapEval/clipscore_prepro_feats.py DELETED
@@ -1,154 +0,0 @@
1
- """
2
- Preprocess a raw json dataset into features files for use in data_loader.py
3
-
4
- Input: json file that has the form
5
- [{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...]
6
- example element in this list would look like
7
- {'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895}
8
-
9
- This script reads this json, does some basic preprocessing on the captions
10
- (e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays
11
-
12
- Output: two folders of features
13
- """
14
-
15
- from __future__ import absolute_import
16
- from __future__ import division
17
- from __future__ import print_function
18
-
19
- import os
20
- import json
21
- import argparse
22
- from random import shuffle, seed
23
- import string
24
- # non-standard dependencies:
25
- import h5py
26
- from six.moves import cPickle
27
- import numpy as np
28
- import torch
29
- import torchvision.models as models
30
- import skimage.io
31
-
32
- from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
33
- from PIL import Image
34
- from torch import nn
35
-
36
- # preprocess = Compose([
37
- # Resize((448, 448), interpolation=Image.BICUBIC),
38
- # CenterCrop((448, 448)),
39
- # ToTensor()
40
- # ])
41
-
42
-
43
- # from clip.clip import load
44
- # from timm.models.vision_transformer import resize_pos_embed
45
- # import timm
46
-
47
- # from captioning.utils.resnet_utils import myResnet
48
- # import captioning.utils.resnet as resnet
49
-
50
- from captioning.utils.clipscore import CLIPScore
51
-
52
- from tqdm import tqdm
53
-
54
-
55
- def main(params):
56
-
57
- clipscore_model = CLIPScore()
58
- clipscore_model.to('cuda')
59
-
60
- imgs = json.load(open(params['input_json'], 'r'))
61
- imgs = imgs['images']
62
- N = len(imgs)
63
-
64
- seed(123) # make reproducible
65
-
66
- # dir_fc = params['output_dir']+'_clip_'+save_model_type+'_fc'
67
- # dir_att = params['output_dir']+'_clip_'+save_model_type+'_att'
68
-
69
- vis_dir_fc = params['output_dir']+'_clipscore_vis'
70
- if not os.path.isdir(vis_dir_fc):
71
- os.mkdir(vis_dir_fc)
72
-
73
- # text_dir_fc = params['output_dir']+'_clipscore_text'
74
- # if not os.path.isdir(text_dir_fc):
75
- # os.mkdir(text_dir_fc)
76
-
77
- # if not os.path.isdir(dir_att):
78
- # os.mkdir(dir_att)
79
-
80
- for i,img in enumerate(tqdm(imgs)):
81
- # load the image
82
-
83
- # img_path = os.path.join(params['images_root'], img['filepath'], img['filename'])
84
- # img_path = os.path.join(params['images_root'], img['file_name'])
85
- img_path = os.path.join(params['images_root'], img['file_path'])
86
-
87
- img_feat = clipscore_model.image_extract(img_path)
88
- img_feat = img_feat.view(512)
89
-
90
- # for d in img['sentences']:
91
- # text = d['raw'].strip()
92
- # text_feat = clipscore_model.text_extract(text)
93
-
94
-
95
- # with torch.no_grad():
96
-
97
- # image = preprocess(Image.open(os.path.join(params['images_root'], img['filepath'], img['filename']) ).convert("RGB"))
98
- # image = torch.tensor(np.stack([image])).cuda()
99
- # image -= mean
100
- # image /= std
101
- # if "RN" in params["model_type"]:
102
- # tmp_att, tmp_fc = model.encode_image(image)
103
- # tmp_att = tmp_att[0].permute(1, 2, 0)
104
- # tmp_fc = tmp_fc[0]
105
- # elif params["model_type"] == 'vit_base_patch32_224_in21k':
106
- # x = model(image)
107
- # tmp_fc = x[0, 0, :]
108
- # tmp_att = x[0, 1:, :].reshape( 14, 14, 768 )
109
- # else:
110
- # x = model.visual.conv1(image.half()) # shape = [*, width, grid, grid]
111
- # x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
112
- # x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
113
- # x = torch.cat([model.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
114
- # x = x + model.visual.positional_embedding.to(x.dtype)[:x.shape[1], :]
115
- # x = model.visual.ln_pre(x)
116
-
117
- # x = x.permute(1, 0, 2) # NLD -> LND
118
-
119
- # for layer_idx, layer in enumerate(model.visual.transformer.resblocks):
120
- # x = layer(x)
121
-
122
- # x = x.permute(1, 0, 2)
123
- # tmp_fc = x[0, 0, :]
124
- # tmp_att = x[0, 1:, :].reshape( 14, 14, 768 )
125
-
126
- np.save(os.path.join(vis_dir_fc, str(img['id'])), img_feat.data.cpu().float().numpy())
127
- # np.save(os.path.join(text_dir_fc, str(img['cocoid'])), tmp_fc.data.cpu().float().numpy())
128
-
129
-
130
- # np.savez_compressed(os.path.join(dir_att, str(img['cocoid'])), feat=tmp_att.data.cpu().float().numpy())
131
-
132
- # if i % 1000 == 0:
133
- # print('processing %d/%d (%.2f%% done)' % (i, N, i*100.0/N))
134
- print('wrote ', vis_dir_fc)
135
-
136
- if __name__ == "__main__":
137
-
138
- parser = argparse.ArgumentParser()
139
-
140
- # input json
141
- # dataset_coco.json
142
- parser.add_argument('--input_json', required=True, help='input json file to process into hdf5')
143
- parser.add_argument('--output_dir', default='data', help='output h5 file')
144
-
145
- # options
146
- parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json')
147
- # parser.add_argument('--att_size', default=14, type=int, help='14x14 or 7x7')
148
- # parser.add_argument('--model_type', default='RN50', type=str, help='RN50, RN101, RN50x4, ViT-B/32, vit_base_patch32_224_in21k')
149
-
150
- args = parser.parse_args()
151
- params = vars(args) # convert to ordinary dict
152
- print('parsed input parameters:')
153
- print(json.dumps(params, indent = 2))
154
- main(params)
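Here only one 512-dimensional CLIP image embedding per id is stored, under <output_dir>_clipscore_vis; a sketch of loading it back, with output_dir 'data' and the id assumed:

import numpy as np

# Hedged sketch: read back a CLIPScore visual feature written above.
img_id = 391895
vis_feat = np.load('data_clipscore_vis/%d.npy' % img_id)
assert vis_feat.shape == (512,)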
scripts_FineCapEval/prepro_labels.py DELETED
@@ -1,209 +0,0 @@
1
- """
2
- Preprocess a raw json dataset into hdf5/json files for use in data_loader.py
3
-
4
- Input: json file that has the form
5
- [{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...]
6
- example element in this list would look like
7
- {'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895}
8
-
9
- This script reads this json, does some basic preprocessing on the captions
10
- (e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays
11
-
12
- Output: a json file and an hdf5 file
13
- The hdf5 file contains several fields:
14
- /labels is (M,max_length) uint32 array of encoded labels, zero padded
15
- /label_start_ix and /label_end_ix are (N,) uint32 arrays of pointers to the
16
- first and last indices (in range 1..M) of labels for each image
17
- /label_length stores the length of the sequence for each of the M sequences
18
-
19
- The json file has a dict that contains:
20
- - an 'ix_to_word' field storing the vocab in form {ix:'word'}, where ix is 1-indexed
21
- - an 'images' field that is a list holding auxiliary information for each image,
22
- such as in particular the 'split' it was assigned to.
23
- """
24
-
25
- from __future__ import absolute_import
26
- from __future__ import division
27
- from __future__ import print_function
28
-
29
- import os
30
- import json
31
- import argparse
32
- from random import shuffle, seed
33
- import string
34
- # non-standard dependencies:
35
- import h5py
36
- import numpy as np
37
- import torch
38
- import torchvision.models as models
39
- import skimage.io
40
- from PIL import Image
41
-
42
-
43
- def build_vocab(imgs, params):
44
- count_thr = params['word_count_threshold']
45
-
46
- # count up the number of words
47
- counts = {}
48
- for img in imgs:
49
- for sent in img['sentences']:
50
- for w in sent['tokens']:
51
- counts[w] = counts.get(w, 0) + 1
52
- cw = sorted([(count,w) for w,count in counts.items()], reverse=True)
53
- print('top words and their counts:')
54
- print('\n'.join(map(str,cw[:20])))
55
-
56
- # print some stats
57
- total_words = sum(counts.values())
58
- print('total words:', total_words)
59
- bad_words = [w for w,n in counts.items() if n <= count_thr]
60
- vocab = [w for w,n in counts.items() if n > count_thr]
61
- bad_count = sum(counts[w] for w in bad_words)
62
- print('number of bad words: %d/%d = %.2f%%' % (len(bad_words), len(counts), len(bad_words)*100.0/len(counts)))
63
- print('number of words in vocab would be %d' % (len(vocab), ))
64
- print('number of UNKs: %d/%d = %.2f%%' % (bad_count, total_words, bad_count*100.0/total_words))
65
-
66
- # lets look at the distribution of lengths as well
67
- sent_lengths = {}
68
- for img in imgs:
69
- for sent in img['sentences']:
70
- txt = sent['tokens']
71
- nw = len(txt)
72
- sent_lengths[nw] = sent_lengths.get(nw, 0) + 1
73
- max_len = max(sent_lengths.keys())
74
- print('max length sentence in raw data: ', max_len)
75
- print('sentence length distribution (count, number of words):')
76
- sum_len = sum(sent_lengths.values())
77
- for i in range(max_len+1):
78
- print('%2d: %10d %f%%' % (i, sent_lengths.get(i,0), sent_lengths.get(i,0)*100.0/sum_len))
79
-
80
- # lets now produce the final annotations
81
- if bad_count > 0:
82
- # additional special UNK token we will use below to map infrequent words to
83
- print('inserting the special UNK token')
84
- vocab.append('UNK')
85
-
86
- for img in imgs:
87
- img['final_captions'] = []
88
- for sent in img['sentences']:
89
- txt = sent['tokens']
90
- caption = [w if counts.get(w,0) > count_thr else 'UNK' for w in txt]
91
- img['final_captions'].append(caption)
92
-
93
- return vocab
94
-
95
-
96
- def encode_captions(imgs, params, wtoi):
97
- """
98
- encode all captions into one large array, which will be 1-indexed.
99
- also produces label_start_ix and label_end_ix which store 1-indexed
100
- and inclusive (Lua-style) pointers to the first and last caption for
101
- each image in the dataset.
102
- """
103
-
104
- max_length = params['max_length']
105
- N = len(imgs)
106
- M = sum(len(img['final_captions']) for img in imgs) # total number of captions
107
-
108
- label_arrays = []
109
- label_start_ix = np.zeros(N, dtype='uint32') # note: these will be one-indexed
110
- label_end_ix = np.zeros(N, dtype='uint32')
111
- label_length = np.zeros(M, dtype='uint32')
112
- caption_counter = 0
113
- counter = 1
114
- for i,img in enumerate(imgs):
115
- n = len(img['final_captions'])
116
- assert n > 0, 'error: some image has no captions'
117
-
118
- Li = np.zeros((n, max_length), dtype='uint32')
119
- for j,s in enumerate(img['final_captions']):
120
- label_length[caption_counter] = min(max_length, len(s)) # record the length of this sequence
121
- caption_counter += 1
122
- for k,w in enumerate(s):
123
- if k < max_length:
124
- Li[j,k] = wtoi[w]
125
-
126
- # note: word indices are 1-indexed, and captions are padded with zeros
127
- label_arrays.append(Li)
128
- label_start_ix[i] = counter
129
- label_end_ix[i] = counter + n - 1
130
-
131
- counter += n
132
-
133
- L = np.concatenate(label_arrays, axis=0) # put all the labels together
134
- assert L.shape[0] == M, 'lengths don\'t match? that\'s weird'
135
- assert np.all(label_length > 0), 'error: some caption had no words?'
136
-
137
- print('encoded captions to array of size ', L.shape)
138
- return L, label_start_ix, label_end_ix, label_length
139
-
140
-
141
- def main(params):
142
-
143
- imgs = json.load(open(params['input_json'], 'r'))
144
- imgs = imgs['images']
145
-
146
- seed(123) # make reproducible
147
-
148
- # # create the vocab
149
- # vocab = build_vocab(imgs, params)
150
- # itow = {i+1:w for i,w in enumerate(vocab)} # a 1-indexed vocab translation table
151
- # wtoi = {w:i+1 for i,w in enumerate(vocab)} # inverse table
152
-
153
- itow = imgs['ix_to_word']
154
- wtoi = {w:i for i, w in itow.items()}
155
-
156
- # encode captions in large arrays, ready to ship to hdf5 file
157
- L, label_start_ix, label_end_ix, label_length = encode_captions(imgs, params, wtoi)
158
-
159
- # create output h5 file
160
- N = len(imgs)
161
- f_lb = h5py.File(params['output_h5']+'_label.h5', "w")
162
- f_lb.create_dataset("labels", dtype='uint32', data=L)
163
- f_lb.create_dataset("label_start_ix", dtype='uint32', data=label_start_ix)
164
- f_lb.create_dataset("label_end_ix", dtype='uint32', data=label_end_ix)
165
- f_lb.create_dataset("label_length", dtype='uint32', data=label_length)
166
- f_lb.close()
167
-
168
- # create output json file
169
- out = {}
170
- out['ix_to_word'] = itow # encode the (1-indexed) vocab
171
- out['images'] = []
172
- for i,img in enumerate(imgs):
173
-
174
- jimg = {}
175
- jimg['split'] = img['split']
176
- if 'filename' in img: jimg['file_path'] = os.path.join(img.get('filepath', ''), img['filename']) # copy it over, might need
177
- if 'cocoid' in img:
178
- jimg['id'] = img['cocoid'] # copy over & maintain an id, if present (e.g. coco ids, useful)
179
- elif 'imgid' in img:
180
- jimg['id'] = img['imgid']
181
-
182
- if params['images_root'] != '':
183
- with Image.open(os.path.join(params['images_root'], img['filepath'], img['filename'])) as _img:
184
- jimg['width'], jimg['height'] = _img.size
185
-
186
- out['images'].append(jimg)
187
-
188
- json.dump(out, open(params['output_json'], 'w'))
189
- print('wrote ', params['output_json'])
190
-
191
- if __name__ == "__main__":
192
-
193
- parser = argparse.ArgumentParser()
194
-
195
- # input json
196
- parser.add_argument('--input_json', required=True, help='input json file to process into hdf5')
197
- parser.add_argument('--output_json', default='data.json', help='output json file')
198
- parser.add_argument('--output_h5', default='data', help='output h5 file')
199
- parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json')
200
-
201
- # options
202
- parser.add_argument('--max_length', default=16, type=int, help='max length of a caption, in number of words. captions longer than this get clipped.')
203
- parser.add_argument('--word_count_threshold', default=5, type=int, help='only words that occur more than this number of times will be put in vocab')
204
-
205
- args = parser.parse_args()
206
- params = vars(args) # convert to ordinary dict
207
- print('parsed input parameters:')
208
- print(json.dumps(params, indent = 2))
209
- main(params)
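One fix worth flagging in the script above: main() reassigns imgs = imgs['images'] and then reads imgs['ix_to_word'], which fails because imgs is a list at that point (and since the build_vocab call is commented out, 'final_captions' is never populated either). The vocabulary is evidently meant to come from an existing cocotalk-style json instead; a hedged sketch of that intended loading, with the path assumed:

import json

# Hedged sketch: take the vocab from a previously built cocotalk-style json
# rather than indexing the image list.
info = json.load(open('data/cocotalk.json', 'r'))    # assumed vocab source
itow = info['ix_to_word']                            # 1-indexed, string keys
wtoi = {w: int(i) for i, w in itow.items()}          # word -> integer index for encode_captions()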
tools/eval.py DELETED
@@ -1,125 +0,0 @@
1
- from __future__ import absolute_import
2
- from __future__ import division
3
- from __future__ import print_function
4
-
5
- import json
6
- import numpy as np
7
-
8
- import time
9
- import os
10
- from six.moves import cPickle
11
-
12
- import captioning.utils.opts as opts
13
- import captioning.models as models
14
- from captioning.data.dataloader import *
15
- # from captioning.data.dataloaderraw import *
16
- import captioning.utils.eval_utils as eval_utils
17
- import argparse
18
- import captioning.utils.misc as utils
19
- import captioning.modules.losses as losses
20
- import torch
21
-
22
- # Input arguments and options
23
- parser = argparse.ArgumentParser()
24
- # Input paths
25
- parser.add_argument('--model', type=str, default='',
26
- help='path to model to evaluate')
27
- parser.add_argument('--cnn_model', type=str, default='resnet101',
28
- help='resnet101, resnet152')
29
- parser.add_argument('--infos_path', type=str, default='',
30
- help='path to infos to evaluate')
31
- parser.add_argument('--only_lang_eval', type=int, default=0,
32
- help='lang eval on saved results')
33
- parser.add_argument('--force', type=int, default=0,
34
- help='force to evaluate no matter if there are results available')
35
- parser.add_argument('--device', type=str, default='cuda',
36
- help='cpu or cuda')
37
- opts.add_eval_options(parser)
38
- opts.add_diversity_opts(parser)
39
- opt = parser.parse_args()
40
-
41
- # Load infos
42
- with open(opt.infos_path, 'rb') as f:
43
- infos = utils.pickle_load(f)
44
-
45
- # override and collect parameters
46
- replace = ['input_fc_dir', 'input_att_dir', 'input_box_dir', 'input_label_h5', 'input_json', 'batch_size', 'id']
47
- ignore = ['start_from']
48
-
49
- for k in vars(infos['opt']).keys():
50
- if k in replace:
51
- setattr(opt, k, getattr(opt, k) or getattr(infos['opt'], k, ''))
52
- elif k not in ignore:
53
- if not k in vars(opt):
54
- vars(opt).update({k: vars(infos['opt'])[k]}) # copy over options from model
55
-
56
- vocab = infos['vocab'] # ix -> word mapping
57
-
58
- pred_fn = os.path.join('eval_results/', '.saved_pred_'+ opt.id + '_' + opt.split + '.pth')
59
- result_fn = os.path.join('eval_results/', opt.id + '_' + opt.split + '.json')
60
-
61
- if opt.only_lang_eval == 1 or (not opt.force and os.path.isfile(pred_fn)):
62
- # if results existed, then skip, unless force is on
63
- if not opt.force:
64
- try:
65
- if os.path.isfile(result_fn):
66
- print(result_fn)
67
- json.load(open(result_fn, 'r'))
68
- print('already evaluated')
69
- os._exit(0)
70
- except:
71
- pass
72
-
73
- predictions, n_predictions = torch.load(pred_fn)
74
- lang_stats = eval_utils.language_eval(opt.input_json, predictions, n_predictions, vars(opt), opt.split)
75
- print(lang_stats)
76
- os._exit(0)
77
-
78
- # At this point only_lang_eval is 0
79
- if not opt.force:
80
- # Check whether cached predictions (and, if language_eval is on, results) already exist
81
- try:
82
- # if no pred exists, then continue
83
- tmp = torch.load(pred_fn)
84
- # if language_eval == 1, and no pred exists, then continue
85
- if opt.language_eval == 1:
86
- json.load(open(result_fn, 'r'))
87
- print('Result is already there')
88
- os._exit(0)
89
- except:
90
- pass
91
-
92
- # Setup the model
93
- opt.vocab = vocab
94
- model = models.setup(opt)
95
- del opt.vocab
96
- model.load_state_dict(torch.load(opt.model, map_location='cpu'))
97
- model.to(opt.device)
98
- model.eval()
99
- crit = losses.LanguageModelCriterion()
100
-
101
- # Create the Data Loader instance
102
- if len(opt.image_folder) == 0:
103
- loader = DataLoader(opt)
104
- else:
105
- loader = DataLoaderRaw({'folder_path': opt.image_folder,
106
- 'coco_json': opt.coco_json,
107
- 'batch_size': opt.batch_size,
108
- 'cnn_model': opt.cnn_model})
109
- # When eval using provided pretrained model, the vocab may be different from what you have in your cocotalk.json
110
- # So make sure to use the vocab in infos file.
111
- loader.dataset.ix_to_word = infos['vocab']
112
-
113
-
114
- # Set sample options
115
- opt.dataset = opt.input_json
116
- loss, split_predictions, lang_stats = eval_utils.eval_split(model, crit, loader,
117
- vars(opt))
118
-
119
- print('loss: ', loss)
120
- if lang_stats:
121
- print(lang_stats)
122
-
123
- if opt.dump_json == 1:
124
- # dump the json
125
- json.dump(split_predictions, open('vis/vis.json', 'w'))
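The script caches raw predictions under eval_results/.saved_pred_<id>_<split>.pth and the language metrics under eval_results/<id>_<split>.json. A minimal sketch of inspecting those artifacts offline; the id/split pair (clipRN50_cider, test) is an assumption that matches the default result path used elsewhere in this repo:

import json
import torch

# Hedged sketch: reload the cached evaluation outputs written by the script above.
pred_fn = 'eval_results/.saved_pred_clipRN50_cider_test.pth'
result_fn = 'eval_results/clipRN50_cider_test.json'

predictions, n_predictions = torch.load(pred_fn)     # same tuple the script itself reloads
print(predictions[0])                                # per-image dict with the generated caption
print(len(json.load(open(result_fn))['imgToEval']), 'images scored')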
tools/eval_clip_retrieval.py DELETED
@@ -1,231 +0,0 @@
1
-
2
- from PIL import Image
3
- # import requests
4
-
5
- from transformers import CLIPProcessor, CLIPModel
6
-
7
- import torch
8
- from torch.utils.data import DataLoader, Dataset
9
-
10
- from pathlib import Path
11
- from tqdm import tqdm
12
- import json
13
- import argparse
14
- import numpy as np
15
-
16
- class COCODataset(Dataset):
17
- def __init__(self,
18
- coco_root="/nas-ssd/jmincho/datasets/COCO/",
19
- gen_caption_path=None,
20
- is_gt=True):
21
- super().__init__()
22
-
23
- self.coco_root = Path(coco_root)
24
-
25
- self.image_dir = self.coco_root.joinpath('images/val2014')
26
-
27
- if is_gt:
28
- print("Loading karpathy splits")
29
- data_info_path = self.coco_root.joinpath('dataset_coco.json')
30
- with open(data_info_path) as f:
31
- karpathy_data = json.load(f)
32
-
33
- data = []
34
- for datum in karpathy_data['images']:
35
- # karpathy test split
36
- if datum['split'] == 'test':
37
- img_id = datum['filename'].split('.')[0]
38
- new_datum = {
39
- 'img_id': img_id,
40
- 'captions': [d['raw'].strip() for d in datum['sentences']],
41
- }
42
- data.append(new_datum)
43
- else:
44
- print("Loading generated captions")
45
- gen_caption_path = Path(gen_caption_path)
46
- with open(gen_caption_path) as f:
47
- # karpathy_data = json.load(f)
48
- imgTogen_results = json.load(f)['imgToEval']
49
- data = []
50
- for img_id, img_data in imgTogen_results.items():
51
- new_datum = {
52
- 'img_id': img_id,
53
- 'captions': [img_data['caption']],
54
- }
55
- data.append(new_datum)
56
-
57
- self.data = data
58
- print('# images:', len(self.data))
59
-
60
- self.img_transform = processor.feature_extractor
61
- self.tokenizer = processor.tokenizer
62
-
63
- def __len__(self):
64
- return len(self.data)
65
-
66
- def __getitem__(self, idx):
67
- datum = self.data[idx]
68
- img_id = datum['img_id']
69
- if 'COCO' not in img_id:
70
- img_id = f'COCO_val2014_{str(img_id).zfill(12)}'
71
- img_fname = f"{img_id}.jpg"
72
- # COCO_val2014_000000522418.jpg
73
- img_path = self.image_dir.joinpath(img_fname)
74
- img = Image.open(img_path).convert("RGB")
75
-
76
- # take first caption
77
- caption = datum['captions'][0]
78
-
79
- return {
80
- "img": img,
81
- "caption": caption,
82
- }
83
-
84
- def collate_fn(self, datum_list):
85
- B = len(datum_list)
86
- imgs = [datum['img'] for datum in datum_list]
87
- images = self.img_transform(imgs, return_tensors="pt")
88
-
89
- captions = [datum['caption'] for datum in datum_list]
90
-
91
- text_tokens = self.tokenizer(captions, return_tensors="pt", padding=True)
92
- batch = {
93
- 'images': images,
94
- 'captions': text_tokens,
95
- }
96
- return batch
97
-
98
-
99
- def compute_similarity(image_features, text_features, bs = 1000):
100
- # compute similarity
101
- max_pairs = image_features.shape[0]
102
- similarity_scores = torch.zeros(max_pairs, max_pairs)
103
- for v in range(0, max_pairs, bs):
104
- for t in range(0, max_pairs, bs):
105
- # print('Processing Visual '+str(v)+' Text '+str(t), end='\r')
106
- batch_visual_emb = image_features[v:v+bs]
107
- batch_caption_emb = text_features[t:t+bs]
108
-
109
- logits = batch_visual_emb @ batch_caption_emb.t()
110
- similarity_scores[v:v+bs,t:t+bs] = logits
111
-
112
- print('Done similarity')
113
- return similarity_scores
114
-
115
- def compute_retrieval(a2b_sims, return_ranks=True):
116
- """
117
- Args:
118
- a2b_sims: Result of computing similarity between two sets of embeddings (emb1 @ emb2.T)
119
- with shape (num_datapoints, num_datapoints).
120
-
121
- Returns:
122
- Retrieval metrics for that similarity.
123
- """
124
- npts = a2b_sims.shape[0]
125
- ranks = np.zeros(npts)
126
- top1 = np.zeros(npts)
127
- # loop source embedding indices
128
- for index in range(npts):
129
- # get order of similarities to target embeddings
130
- inds = np.argsort(a2b_sims[index])[::-1]
131
- # find where the correct embedding is ranked
132
- where = np.where(inds == index)
133
- rank = where[0][0]
134
- ranks[index] = rank
135
- # save the top1 result as well
136
- top1[index] = inds[0]
137
-
138
- # Compute metrics
139
- r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
140
- r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
141
- r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
142
- r50 = 100.0 * len(np.where(ranks < 50)[0]) / len(ranks)
143
- medr = np.floor(np.median(ranks)) + 1
144
- meanr = ranks.mean() + 1
145
-
146
- report_dict = {"r1": r1, "r5": r5, "r10": r10, "r50": r50, "medr": medr, "meanr": meanr, "sum": r1 + r5 + r10}
147
-
148
- if return_ranks:
149
- return report_dict, (ranks, top1)
150
- else:
151
- return report_dict
152
-
153
-
154
- if __name__ == '__main__':
155
-
156
- parser = argparse.ArgumentParser()
157
- parser.add_argument('--coco_root', type=str, default="/nas-ssd/jmincho/datasets/COCO/")
158
- parser.add_argument('--gt', action='store_true')
159
- parser.add_argument('--gen_caption_path', type=str, default="./eval_results/clipRN50_cider_test.json")
160
- args = parser.parse_args()
161
-
162
- model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
163
- processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
164
-
165
- device = "cuda"
166
- model = model.to(device)
167
- model.eval()
168
- print(f"Loaded CLIP at {device}")
169
-
170
- batch_size = 1000
171
-
172
- dataset = COCODataset(
173
- coco_root="/nas-ssd/jmincho/datasets/COCO/",
174
- gen_caption_path=args.gen_caption_path,
175
- is_gt=args.gt
176
- )
177
- data_loader = DataLoader(
178
- dataset,
179
- batch_size=batch_size,
180
- collate_fn=dataset.collate_fn,
181
- shuffle=False,
182
- num_workers=8)
183
-
184
- # fwd all samples
185
- image_features = []
186
- text_features = []
187
- for batch_idx, batch in enumerate(tqdm(data_loader)):
188
- # print('Evaluating batch {}/{}'.format(batch_idx, len(data_loader)), end="\r")
189
- # images, texts = batch
190
-
191
- with torch.no_grad():
192
- images = batch["images"].to(device)
193
- texts = batch["captions"].to(device)
194
-
195
- vision_outputs = model.vision_model(**batch['images'])
196
- text_outputs = model.text_model(**batch['captions'])
197
-
198
- image_embeds = vision_outputs[1]
199
- image_embeds = model.visual_projection(image_embeds)
200
-
201
- text_embeds = text_outputs[1]
202
- text_embeds = model.text_projection(text_embeds)
203
-
204
- # normalized features
205
- image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
206
- text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
207
-
208
- text_features.append(text_embeds.detach().cpu())
209
- image_features.append(image_embeds.detach().cpu())
210
-
211
- image_features = torch.cat(image_features, 0)
212
- text_features = torch.cat(text_features, 0)
213
- print('Done forward')
214
-
215
- # normalized features
216
- image_features = image_features / image_features.norm(dim=-1, keepdim=True)
217
- text_features = text_features / text_features.norm(dim=-1, keepdim=True)
218
-
219
- # if not single_caption:
220
- # for cap_idx in range(text_features.shape[1]):
221
- # similarity_scores = compute_similarity(image_features, text_features[:,cap_idx,:])
222
- # i2t_dict = compute_retrieval(similarity_scores.numpy())
223
- # t2i_dict = compute_retrieval(similarity_scores.t().numpy())
224
- # print(cap_idx, 'i2t', i2t_dict)
225
- # print(cap_idx, 't2i', t2i_dict)
226
- # else:
227
- similarity_scores = compute_similarity(image_features, text_features)
228
- i2t_dict = compute_retrieval(similarity_scores.numpy())
229
- t2i_dict = compute_retrieval(similarity_scores.t().numpy())
230
- print('i2t', i2t_dict)
231
- print('t2i', t2i_dict)
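compute_retrieval() above only needs a square similarity matrix, so it is easy to sanity-check in isolation: with the identity matrix every query ranks its own pair first. A minimal sketch, assuming the function defined above is in scope:

import numpy as np

# Hedged sketch: identity similarities -> perfect retrieval.
sims = np.eye(5)
report, (ranks, top1) = compute_retrieval(sims, return_ranks=True)
print(report['r1'], report['medr'])                  # expected: 100.0 1.0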
tools/eval_finecapeval.py DELETED
@@ -1,204 +0,0 @@
1
-
2
- from tqdm import tqdm
3
- from pprint import pprint
4
- import pandas as pd
5
- import argparse
6
- import re
7
- import json
8
- import nltk
9
- from nltk.tokenize import word_tokenize
10
- from nltk.stem.porter import PorterStemmer
11
- p_stemmer = PorterStemmer()
12
-
13
- # nltk.download('punkt')
14
- # nltk.download('wordnet')
15
- # nltk.download('stopwords')
16
-
17
- import language_evaluation
18
- evaluator = language_evaluation.CocoEvaluator()
19
-
20
-
21
- def nltk_process(text):
22
- # Tokenization
23
- nltk_tokenList = word_tokenize(text)
24
-
25
- # Stemming
26
- nltk_stemedList = []
27
- for word in nltk_tokenList:
28
- nltk_stemedList.append(p_stemmer.stem(word))
29
-
30
- filtered_sentence = nltk_stemedList
31
-
32
- # Removing Punctuation
33
-
34
- tokens = [re.sub(r'[^a-zA-Z0-9]', '', tok) for tok in filtered_sentence]
35
-
36
- text = " ".join(tokens)
37
-
38
- return text
39
-
40
-
41
- def calculate_finegrained_scores(pred_id2sent, id2caption, use_coco_eval=False):
42
- if use_coco_eval:
43
- n_total = 0
44
- refs = []
45
- hyps = []
46
- for id, gt_captions in id2caption.items():
47
- pred_sent = pred_id2sent[id]
48
-
49
- refs.append(gt_captions)
50
- hyps.append(pred_sent)
51
-
52
- n_total += 1
53
-
54
- print('caption')
55
- results = evaluator.run_evaluation(hyps, refs)
56
- pprint(results)
57
-
58
- n_total = 0
59
- total_score = 0
60
- for id, gt_phrases in id2background.items():
61
- pred_sent = pred_id2sent[id]
62
-
63
- score = 0
64
- n_phrases = len(gt_phrases)
65
-
66
- for gt_phrase in gt_phrases:
67
- word_score = 0
68
- for gt_word in gt_phrase.split():
69
- if gt_word in pred_sent:
70
- word_score += 1
71
- if len(gt_phrase.split()) > 0:
72
- score += word_score / len(gt_phrase.split())
73
-
74
- if n_phrases > 0:
75
- score /= n_phrases
76
-
77
- total_score += score
78
- n_total += 1
79
- print('background')
80
- # print('# retrieved words:', n_retrieved)
81
- print(f'Acc: {total_score / n_total * 100:.2f}')
82
-
83
- n_total = 0
84
- total_score = 0
85
- for id, gt_phrases in id2object.items():
86
- pred_sent = pred_id2sent[id]
87
-
88
- score = 0
89
- n_phrases = len(gt_phrases)
90
-
91
- for gt_phrase in gt_phrases:
92
- word_score = 0
93
- for gt_word in gt_phrase.split():
94
- if gt_word in pred_sent:
95
- word_score += 1
96
- if len(gt_phrase.split()) > 0:
97
- score += word_score / len(gt_phrase.split())
98
-
99
- if n_phrases > 0:
100
- score /= n_phrases
101
-
102
- total_score += score
103
- n_total += 1
104
- print('object')
105
- # print('# retrieved words:', n_retrieved)
106
- print(f'Acc: {total_score / n_total * 100:.2f}')
107
-
108
- n_total = 0
109
- total_score = 0
110
- for id, gt_phrases in id2relation.items():
111
- pred_sent = pred_id2sent[id]
112
-
113
- score = 0
114
- n_phrases = len(gt_phrases)
115
-
116
- for gt_phrase in gt_phrases:
117
- word_score = 0
118
- for gt_word in gt_phrase.split():
119
- if gt_word in pred_sent:
120
- word_score += 1
121
- if len(gt_phrase.split()) > 0:
122
- score += word_score / len(gt_phrase.split())
123
-
124
- if n_phrases > 0:
125
- score /= n_phrases
126
-
127
- total_score += score
128
- n_total += 1
129
- print('relation')
130
- # print('# retrieved words:', n_retrieved)
131
- print(f'Acc: {total_score / n_total * 100:.2f}')
132
-
133
-
134
- if __name__ == '__main__':
135
- parser = argparse.ArgumentParser()
136
- parser.add_argument('--finecapeval_path', type=str, default="data/FineCapEval.csv")
137
- parser.add_argument('--generated_id2caption', type=str, default="FineCapEval_results/mle.json")
138
- args = parser.parse_args()
139
-
140
- df = pd.read_csv(args.finecapeval_path)
141
- assert df.shape == (5000, 5)
142
-
143
- generated_id2caption = json.load(open(args.generated_id2caption, 'r'))
144
-
145
- print("Preprocessing GT FineCapEval data...")
146
- id2caption = {}
147
- id2background = {}
148
- id2object = {}
149
- id2relation = {}
150
-
151
- for row in tqdm(df.itertuples(), total=len(df)):
152
-
153
- id = row.image.split('.')[0]
154
- caption = row.caption
155
- background = row.background
156
- object = row.object
157
- relation = row.relation
158
-
159
- if not isinstance(caption, str):
160
- continue
161
- if not isinstance(background, str):
162
- continue
163
- if not isinstance(object, str):
164
- continue
165
- if not isinstance(relation, str):
166
- continue
167
-
168
- if id not in id2caption:
169
- id2caption[id] = []
170
- id2background[id] = []
171
- id2object[id] = []
172
- id2relation[id] = []
173
-
174
- id2caption[id].append(caption)
175
-
176
- phrases = []
177
- for phrase in background.lower().split('\;'):
178
- if len(phrase) > 1:
179
- phrase = nltk_process(phrase)
180
- phrases.append(phrase)
181
- id2background[id].extend(phrases)
182
-
183
- phrases = []
184
- for phrase in object.lower().split('\;'):
185
- if len(phrase) > 1:
186
- phrase = nltk_process(phrase)
187
- phrases.append(phrase)
188
- id2object[id].extend(phrases)
189
-
190
- phrases = []
191
- for phrase in relation.lower().split('\;'):
192
- if len(phrase) > 1:
193
- phrase = nltk_process(phrase)
194
- phrases.append(phrase)
195
- id2relation[id].extend(phrases)
196
-
197
- print("Calculating scores...")
198
- calculate_finegrained_scores(
199
- generated_id2caption,
200
- id2caption,
201
- use_coco_eval=True)
202
-
203
-
204
-
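The background/object/relation scores above are plain word recall: each ground-truth phrase contributes the fraction of its words found (as substrings) in the prediction, averaged over the phrases of that image. Note that the ground-truth phrases are stemmed via nltk_process while the predictions are not, which can undercount inflected matches. A toy example of the scoring rule:

# Hedged toy example of the per-image phrase scoring used above.
pred_sent = 'a red helmet on a dirt road'
gt_phrases = ['red helmet', 'dirt road', 'verdant hillside']

score = 0.0
for gt_phrase in gt_phrases:
    words = gt_phrase.split()
    score += sum(1 for w in words if w in pred_sent) / len(words)
score /= len(gt_phrases)
print(round(score, 3))                               # (1.0 + 1.0 + 0.0) / 3 = 0.667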
tools/finecapeval_inference.py DELETED
@@ -1,186 +0,0 @@
1
- import sys
2
- import torch
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
- import torch.optim as optim
6
-
7
- import numpy as np
8
-
9
- import time
10
- import os
11
- from collections import defaultdict
12
- import json
13
-
14
- import captioning.utils.opts as opts
15
- import captioning.models as models
16
- from captioning.data.pth_loader import CaptionDataset
17
- import captioning.utils.eval_utils as eval_utils
18
- # import captioning.utils.vizwiz_eval_utils as vizwiz_eval_utils
19
- import captioning.utils.misc as utils
20
- from captioning.utils.rewards import init_scorer, get_self_critical_reward
21
- from captioning.modules.loss_wrapper import LossWrapper
22
-
23
- import pytorch_lightning as pl
24
-
25
-
26
- class ModelCheckpoint(pl.callbacks.ModelCheckpoint):
27
-
28
- def on_keyboard_interrupt(self, trainer, pl_module):
29
- # Save model when keyboard interrupt
30
- filepath = os.path.join(self.dirpath, self.prefix + 'interrupt.ckpt')
31
- self._save_model(filepath)
32
-
33
-
34
- if __name__ == '__main__':
35
-
36
- device = 'cuda'
37
-
38
- import argparse
39
- parser = argparse.ArgumentParser()
40
- parser.add_argument('--reward', type=str, default='mle')
41
- args = parser.parse_args()
42
-
43
- if args.reward == 'mle':
44
- cfg = f'configs/phase1/fg_clipRN50_{args.reward}.yml'
45
- else:
46
- cfg = f'configs/phase2/fg_clipRN50_{args.reward}.yml'
47
-
48
- print("Loading cfg from", cfg)
49
-
50
- opt = opts.parse_opt(parse=False, cfg=cfg)
51
-
52
- dataset = CaptionDataset(opt)
53
-
54
- opt.vocab_size = dataset.vocab_size
55
- opt.seq_length = dataset.seq_length
56
-
57
- opt.batch_size = 40
58
-
59
- opt.vocab = dataset.get_vocab()
60
-
61
- model = models.setup(opt)
62
- del opt.vocab
63
-
64
- ckpt_path = opt.checkpoint_path + '-last.ckpt'
65
-
66
- print("Loading checkpoint from", ckpt_path)
67
- raw_state_dict = torch.load(
68
- ckpt_path,
69
- map_location=device)
70
-
71
- strict = True
72
-
73
- state_dict = raw_state_dict['state_dict']
74
-
75
- if '_vocab' in state_dict:
76
- model.vocab = utils.deserialize(state_dict['_vocab'])
77
- del state_dict['_vocab']
78
- elif strict:
79
- raise KeyError
80
- if '_opt' in state_dict:
81
- saved_model_opt = utils.deserialize(state_dict['_opt'])
82
- del state_dict['_opt']
83
- # Make sure the saved opt is compatible with the current opt
84
- need_be_same = ["caption_model",
85
- "rnn_type", "rnn_size", "num_layers"]
86
- for checkme in need_be_same:
87
- if getattr(saved_model_opt, checkme) in ['updown', 'topdown'] and \
88
- getattr(opt, checkme) in ['updown', 'topdown']:
89
- continue
90
- assert getattr(saved_model_opt, checkme) == getattr(
91
- opt, checkme), "Command line argument and saved model disagree on '%s' " % checkme
92
- elif strict:
93
- raise KeyError
94
- res = model.load_state_dict(state_dict, strict)
95
- print(res)
96
-
97
- opt.use_grammar = False
98
-
99
- lw_model = LossWrapper(model, opt)
100
-
101
- split = 'test'
102
-
103
- print("Building dataloader...")
104
-
105
- test_dataset = torch.utils.data.Subset(
106
- dataset,
107
- dataset.split_ix[split]
108
- )
109
- test_loader = torch.utils.data.DataLoader(
110
- test_dataset,
111
- batch_size=opt.batch_size,
112
- shuffle=False,
113
- num_workers=4,
114
- drop_last=False,
115
- collate_fn=dataset.collate_func
116
- )
117
-
118
- eval_kwargs = {'dataset': opt.input_json}
119
- eval_kwargs.update(vars(opt))
120
-
121
- verbose = eval_kwargs.get('verbose', True)
122
- verbose_beam = eval_kwargs.get('verbose_beam', 0)
123
- verbose_loss = eval_kwargs.get('verbose_loss', 1)
124
- # num_images = eval_kwargs.get('num_images', eval_kwargs.get('val_images_use', -1))
125
- # lang_eval = eval_kwargs.get('language_eval', 0)
126
- dataset = eval_kwargs.get('dataset', 'coco')
127
- beam_size = eval_kwargs.get('beam_size', 1)
128
- sample_n = eval_kwargs.get('sample_n', 1)
129
- remove_bad_endings = eval_kwargs.get('remove_bad_endings', 0)
130
-
131
- crit = lw_model.crit
132
-
133
- model = model.to(device)
134
-
135
- from tqdm import tqdm
136
-
137
- test_id2sent = {}
138
-
139
- model.eval()
140
-
141
- print("running inference...")
142
-
143
- for data in tqdm(test_loader):
144
- with torch.no_grad():
145
- # forward the model to get loss
146
- tmp = [data['fc_feats'], data['att_feats'],
147
- data['labels'], data['masks'], data['att_masks']]
148
- tmp = [d.to(device) if isinstance(d, torch.Tensor) else d for d in tmp]
149
-
150
- fc_feats, att_feats, labels, masks, att_masks = tmp
151
-
152
- loss = crit(model(fc_feats, att_feats,
153
- labels[..., :-1], att_masks), labels[..., 1:], masks[..., 1:])
154
-
155
- # forward the model to also get generated samples for each image
156
- # Only leave one feature for each image, in case of duplicate samples
157
- tmp_eval_kwargs = eval_kwargs.copy()
158
- tmp_eval_kwargs.update({'sample_n': 1})
159
- seq, seq_logprobs = model(
160
- fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
161
- seq = seq.data
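- # Average per-token entropy and negative log-likelihood over the sampled sequence; (seq > 0) counts non-pad tokens and the +1 presumably accounts for the EOS token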
162
- entropy = - (F.softmax(seq_logprobs, dim=2) *
163
- seq_logprobs).sum(2).sum(1) / ((seq > 0).to(seq_logprobs).sum(1)+1)
164
- perplexity = - \
165
- seq_logprobs.gather(2, seq.unsqueeze(2)).squeeze(
166
- 2).sum(1) / ((seq > 0).to(seq_logprobs).sum(1)+1)
167
-
168
- # Print beam search
169
- if beam_size > 1 and verbose_beam:
170
- for i in range(fc_feats.shape[0]):
171
- print('\n'.join([utils.decode_sequence(model.vocab, _[
172
- 'seq'].unsqueeze(0))[0] for _ in model.done_beams[i]]))
173
- print('--' * 10)
174
- sents = utils.decode_sequence(model.vocab, seq)
175
-
176
- for d, sent in zip(data['infos'], sents):
177
- test_id2sent[d['id']] = sent
178
-
179
- res_path = f'FineCapEval_results/clipRN50_{args.reward}.json'
180
-
181
- print("Results saved to {}".format(res_path))
182
-
183
- with open(res_path, 'w') as f:
184
- json.dump(test_id2sent, f)
185
-
186
-
 
tools/train_pl.py DELETED
@@ -1,709 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- import torch.optim as optim
5
-
6
- import numpy as np
7
-
8
- import time
9
- import os
10
- from collections import defaultdict
11
-
12
- import captioning.utils.opts as opts
13
- import captioning.models as models
14
- from captioning.data.pth_loader import CaptionDataset
15
- import captioning.utils.eval_utils as eval_utils
16
- import captioning.utils.misc as utils
17
- from captioning.utils.rewards import init_scorer, get_self_critical_reward
18
- from captioning.modules.loss_wrapper import LossWrapper
19
-
20
- import pytorch_lightning as pl
21
-
22
- import detectron2.utils.comm as d2comm
23
- from detectron2.utils.env import seed_all_rng
24
- seed_all_rng(1234)
25
-
26
-
27
- class LitModel(pl.LightningModule):
28
- def __init__(self, opt):
29
- super().__init__()
30
- self.opt = opt
31
- # Initialize dataset
32
- self.dataset = CaptionDataset(opt)
33
- opt.vocab_size = self.dataset.vocab_size
34
- opt.seq_length = self.dataset.seq_length
35
- self.batch_size = opt.batch_size
36
-
37
- # Build model
38
- opt.vocab = self.dataset.get_vocab()
39
- model = models.setup(opt)
40
- # print(model)
41
- del opt.vocab
42
-
43
- # wrapper with loss in it.
44
- lw_model = LossWrapper(model, opt)
45
-
46
- self.model = model
47
- self.lw_model = lw_model
48
-
49
- self.struc_flag = None
50
- self.sc_flag = None
51
-
52
- # if self.opt.use_clipscore:
53
- # if self.opt.use_clipscore or os.getenv('EVALUATE', '0') == '1':
54
- # if CLIP-S+Grammar is used in the reward -> launch a separate CLIP-S whose parameters stay frozen
55
- if getattr(self.opt, 'use_grammar', False):
56
- from captioning.utils.clipscore import CLIPScore
57
- self.val_clipscore_model = CLIPScore(
58
- mode=opt.clipscore_mode, use_grammar=False)
59
- for p in self.val_clipscore_model.parameters():
60
- p.requires_grad = False
61
- else:
62
- if self.lw_model.clipscore_model is not None:
63
- self.val_clipscore_model = self.lw_model.clipscore_model
64
- else:
65
- from captioning.utils.clipscore import CLIPScore
66
- self.val_clipscore_model = CLIPScore(
67
- mode=opt.clipscore_mode, use_grammar=False)
68
- for p in self.val_clipscore_model.parameters():
69
- p.requires_grad = False
70
- self.val_clipscore_model.eval()
71
-
72
- # BERTSCORE
73
- from bert_score import BERTScorer
74
- self.bert_scorer = BERTScorer(
75
- lang="en",
76
- # rescale_with_baseline=True,
77
- rescale_with_baseline=False,
78
- device='cpu'
79
- )
80
-
81
- def forward(self, *args, **kwargs):
82
- """
83
- I hate this design. Never treat it as a regular nn.Module
84
- """
85
- raise NotImplementedError
86
-
87
- def train_dataloader(self):
88
- train_dataset = torch.utils.data.Subset(
89
- self.dataset,
90
- self.dataset.split_ix['train']
91
- )
92
-
93
- train_loader = torch.utils.data.DataLoader(
94
- dataset=train_dataset,
95
- batch_size=self.batch_size,
96
- shuffle=True,
97
- num_workers=4,
98
- collate_fn=self.dataset.collate_func
99
- )
100
- return train_loader
101
-
102
- def val_dataloader(self, split='val'):
103
- val_dataset = torch.utils.data.Subset(
104
- self.dataset,
105
- self.dataset.split_ix[split]
106
- )
107
- val_loader = torch.utils.data.DataLoader(
108
- val_dataset,
109
- batch_size=self.batch_size,
110
- shuffle=False,
111
- num_workers=4,
112
- drop_last=False,
113
- collate_fn=self.dataset.collate_func
114
- )
115
- return val_loader
116
-
117
- def test_dataloader(self):
118
- return self.val_dataloader('test')
119
-
120
- def training_step(self, data, batch_idx):
121
- sc_flag, struc_flag = self.sc_flag, self.struc_flag
122
-
123
- tmp = [data['fc_feats'], data['att_feats'],
124
- data['labels'], data['masks'], data['att_masks']]
125
- fc_feats, att_feats, labels, masks, att_masks = tmp
126
- if int(os.getenv('M2_cider', '0')) != 0:
127
- data['gts'] = data['rawgts']
128
-
129
- if self.opt.use_clipscore:
130
- clip_vis_feats = data['clip_vis_feats']
131
- model_out = self.lw_model(fc_feats, att_feats, labels, masks, att_masks,
132
- data['gts'], torch.arange(0, len(data['gts'])), sc_flag, struc_flag,
133
- clip_vis_feats=clip_vis_feats)
134
- else:
135
- model_out = self.lw_model(fc_feats, att_feats, labels, masks, att_masks,
136
- data['gts'], torch.arange(0, len(data['gts'])), sc_flag, struc_flag)
137
- loss = model_out['loss']
138
-
139
- data_time = self.trainer.profiler.recorded_durations["get_train_batch"][-1]
140
- data_time = torch.tensor(data_time)
141
-
142
- logger_logs = model_out.copy()
143
- # if struc_flag or sc_flag:
144
- # logger_logs['reward'] = model_out['reward'].mean()
145
- # logger_logs['reward_var'] = model_out['reward'].var(1).mean()
146
- if struc_flag or sc_flag:
147
- logger_logs['reward'] = model_out['reward'].mean()
148
- for k in ['CLIP-S', 'RefCLIP-S', 'CIDEr', 'grammar_reward']:
149
- if k in model_out:
150
- logger_logs[k] = model_out[k]
151
- if struc_flag:
152
- logger_logs['reward_var'] = model_out['reward'].var(1).mean()
153
-
154
- logger_logs['scheduled_sampling_prob'] = torch.tensor(
155
- self.model.ss_prob)
156
- # logger_logs['training_loss'] = loss
157
- logger_logs['loss'] = loss
158
- logger_logs['data_time'] = data_time
159
-
160
- # UserWarning: The {progress_bar:dict keyword} was deprecated in 0.9.1 and will be removed in 1.0.0
161
- # Please use self.log(...) inside the lightningModule instead.
162
-
163
- # # log on a step or aggregate epoch metric to the logger and/or progress bar
164
- # # (inside LightningModule)
165
- # self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
166
- # warnings.warn(*args, **kwargs)
167
- # UserWarning: The {log:dict keyword} was deprecated in 0.9.1 and will be removed in 1.0.0
168
- # Please use self.log(...) inside the lightningModule instead.
169
-
170
- # output = {
171
- # 'loss': loss,
172
- # 'log': logger_logs,
173
- # 'progress_bar': {'data_time': data_time}
174
- # }
175
-
176
- for k, v in logger_logs.items():
177
- if k in ['reward', 'reward_var', 'data_time', 'CLIP-S', 'RefCLIP-S', 'CIDEr', 'grammar_reward']:
178
- self.log('train/'+k, v, prog_bar=True)
179
- else:
180
- self.log('train/'+k, v)
181
-
182
- return loss
183
-
184
- def validation_step(self, data, batch_idx):
185
- model = self.model
186
- crit = self.lw_model.crit
187
-
188
- opt = self.opt
189
- eval_kwargs = {'dataset': opt.input_json}
190
- eval_kwargs.update(vars(opt))
191
-
192
- # CLIPScore
193
- use_grammar = getattr(self.opt, 'use_grammar', False)
194
- joint_out = getattr(self.opt, 'joint_out', False)
195
-
196
- verbose = eval_kwargs.get('verbose', True)
197
- verbose_beam = eval_kwargs.get('verbose_beam', 0)
198
- verbose_loss = eval_kwargs.get('verbose_loss', 1)
199
- # num_images = eval_kwargs.get('num_images', eval_kwargs.get('val_images_use', -1))
200
- # lang_eval = eval_kwargs.get('language_eval', 0)
201
- dataset = eval_kwargs.get('dataset', 'coco')
202
- beam_size = eval_kwargs.get('beam_size', 1)
203
- sample_n = eval_kwargs.get('sample_n', 1)
204
- remove_bad_endings = eval_kwargs.get('remove_bad_endings', 0)
205
- # Use this nasty way to make other code clean since it's a global configuration
206
- os.environ["REMOVE_BAD_ENDINGS"] = str(remove_bad_endings)
207
-
208
- predictions = []
209
- n_predictions = []
210
-
211
- loss = torch.tensor(0)
212
- if data.get('labels', None) is not None and verbose_loss:
213
- # forward the model to get loss
214
- tmp = [data['fc_feats'], data['att_feats'],
215
- data['labels'], data['masks'], data['att_masks']]
216
- fc_feats, att_feats, labels, masks, att_masks = tmp
217
-
218
- loss = crit(model(fc_feats, att_feats,
219
- labels[..., :-1], att_masks), labels[..., 1:], masks[..., 1:])
220
-
221
- # forward the model to also get generated samples for each image
222
- # Only leave one feature for each image, in case of duplicate samples
223
- tmp_eval_kwargs = eval_kwargs.copy()
224
- tmp_eval_kwargs.update({'sample_n': 1})
225
- seq, seq_logprobs = model(
226
- fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
227
- seq = seq.data
228
- entropy = - (F.softmax(seq_logprobs, dim=2) *
229
- seq_logprobs).sum(2).sum(1) / ((seq > 0).to(seq_logprobs).sum(1)+1)
230
- perplexity = - \
231
- seq_logprobs.gather(2, seq.unsqueeze(2)).squeeze(
232
- 2).sum(1) / ((seq > 0).to(seq_logprobs).sum(1)+1)
233
-
234
- # Print beam search
235
- if beam_size > 1 and verbose_beam:
236
- for i in range(fc_feats.shape[0]):
237
- print('\n'.join([utils.decode_sequence(model.vocab, _[
238
- 'seq'].unsqueeze(0))[0] for _ in model.done_beams[i]]))
239
- print('--' * 10)
240
- sents = utils.decode_sequence(model.vocab, seq)
241
-
242
- # if self.opt.use_clipscore or os.getenv('EVALUATE', '0') == '1':
243
- # text_feat = self.lw_model.clipscore_model.text_extract(sents)
244
- text_feat = self.val_clipscore_model.text_extract(sents, proj_norm=False)
245
-
246
- text_cont_feat = self.val_clipscore_model.clip_model.text_projection(text_feat)
247
- text_cont_feat = text_cont_feat / text_cont_feat.norm(dim=-1, keepdim=True)
248
-
249
- vis_feat = data['clip_vis_feats']
250
- # if self.opt.clipscore_mode == 'clip_s':
251
- # clip_s = self.val_clipscore_model(text_feat=text_cont_feat, img_feat=vis_feat, mode='clip_s')
252
-
253
- # elif self.opt.clipscore_mode == 'refclip_s':
254
- clip_s = self.val_clipscore_model(text_feat=text_cont_feat, img_feat=vis_feat, mode='clip_s')
255
- # ref_text = utils.decode_sequence(model.vocab, data['gts'])
256
-
257
- gt_indices = torch.arange(0, len(data['gts']))
258
- data_gts = [data['gts'][_] for _ in gt_indices.tolist()]
259
-
260
- B = len(data_gts)
261
-
262
- gts = []
263
- gts_valid_mask = []
264
- max_n_refs = max([len(_gts) for _gts in data_gts])
265
- for i in range(len(data_gts)):
266
- _gts = utils.decode_sequence(model.vocab, data_gts[i])
267
- # pad references
268
- n_ref = len(_gts)
269
- _gts.extend([''] * (max_n_refs - n_ref))
270
- gts.extend(_gts)
271
- gts_valid_mask.extend([1] * n_ref + [0] * (max_n_refs - n_ref))
272
- assert len(gts) == B * max_n_refs
273
- assert len(gts_valid_mask) == B * max_n_refs
274
-
275
- ref_text = gts
276
- ref_text_mask = gts_valid_mask
277
-
278
- refclip_s = self.val_clipscore_model(
279
- text_feat=text_cont_feat, img_feat=vis_feat,
280
- ref_text=ref_text, ref_text_mask=ref_text_mask, mode='refclip_s')
281
-
282
- # use_grammar = getattr(self.opt, 'use_grammar', False)
283
- # joint_out = getattr(self.opt, 'joint_out', False)
284
- if use_grammar and not joint_out:
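- # Score each generated caption with the grammar head; grammar_prob is the softmax probability of class 1, presumably the "grammatical" class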
285
- with torch.no_grad():
286
- # grammar_logit = self.val_clipscore_model.grammar_score_head(text_feat.view(-1, 512))
287
- grammar_logit = self.lw_model.clipscore_model.grammar_score_head(text_feat.view(-1, 512))
288
- grammar_prob = torch.softmax(grammar_logit, dim=-1)[:, 1]
289
-
290
-
291
- # BERTScore
292
- if next(self.bert_scorer._model.parameters()).device != self.device:
293
- self.bert_scorer._model.to(self.device)
294
- self.bert_scorer.device = self.device
295
-
296
-
297
- # [B*K] -> [B, K]
298
- ref_text_per_example = []
299
- for i in range(B):
300
- ref_text_list_example = []
301
- for k in range(max_n_refs):
302
- ref = ref_text[i * max_n_refs + k]
303
- if len(ref) > 0:
304
- ref_text_list_example.append(ref)
305
- # assert len(ref_text_list_example) == max_n_refs
306
- ref_text_per_example.append(ref_text_list_example)
307
- assert len(ref_text_per_example) == B
308
-
309
- P, R, F1 = self.bert_scorer.score(
310
- sents,
311
- ref_text_per_example,
312
- )
313
- bertscore_f1 = F1
314
- # print('Example 5:')
315
- # for i in range(5):
316
- # print('Generated:', sents[i])
317
- # print('ref_text:', ref_text_per_example[i])
318
- # print('BERT-Score:', F1[i].item())
319
-
320
-
321
- for k, sent in enumerate(sents):
322
- entry = {'image_id': data['infos'][k]['id'], 'caption': sent,
323
- 'perplexity': perplexity[k].item(), 'entropy': entropy[k].item()}
324
- if self.opt.use_clipscore or os.getenv('EVALUATE', '0') == '1':
325
- # if self.opt.clipscore_mode == 'clip_s':
326
- # entry['clipscore'] = clipscore[k].item()
327
- # entry['CLIP-S'] = clip_s[k].item()
328
- # elif self.opt.clipscore_mode == 'refclip_s':
329
- entry['CLIP-S'] = clip_s[k].item()
330
- entry['RefCLIP-S'] = refclip_s[k].item()
331
-
332
- if use_grammar and not joint_out:
333
- entry['grammar_prob'] = grammar_prob[k].item()
334
-
335
- # BERT-S
336
- entry['BERT-S'] = bertscore_f1[k].item()
337
-
338
- if eval_kwargs.get('dump_path', 0) == 1:
339
- entry['file_name'] = data['infos'][k]['file_path']
340
- predictions.append(entry)
341
- if eval_kwargs.get('dump_images', 0) == 1:
342
- # dump the raw image to vis/ folder
343
- cmd = 'cp "' + os.path.join(eval_kwargs['image_root'], data['infos'][k]['file_path']) + \
344
- '" vis/imgs/img' + \
345
- str(len(predictions)) + '.jpg' # bit gross
346
- print(cmd)
347
- os.system(cmd)
348
-
349
- if verbose:
350
- print('image %s: %s' %
351
- (entry['image_id'], entry['caption']))
352
-
353
- if sample_n > 1:
354
- eval_utils.eval_split_n(model, n_predictions, [
355
- fc_feats, att_feats, att_masks, data], eval_kwargs)
356
-
357
- output = {
358
- # 'val_loss': loss,
359
- 'loss': loss,
360
- 'predictions': predictions,
361
- 'n_predictions': n_predictions,
362
- }
363
- return output
364
-
365
- def test_step(self, *args, **kwargs):
366
- return self.validation_step(*args, **kwargs)
367
-
368
- def validation_epoch_end(self, outputs, split='val'):
369
- outputs = d2comm.gather(outputs)
370
- # master node
371
- if d2comm.is_main_process():
372
- assert self.trainer.node_rank == 0 and self.trainer.local_rank == 0
373
- outputs = sum(outputs, [])
374
-
375
- opt = self.opt
376
- # val_loss_mean = sum([_['val_loss']
377
- # val_loss_mean = sum([_['val_loss'].cpu()
378
- val_loss_mean = sum([_['loss'].cpu()
379
- for _ in outputs]) / len(outputs)
380
-
381
- predictions = sum([_['predictions'] for _ in outputs], [])
382
- if len(outputs[0]['n_predictions']) != 0:
383
- n_predictions = sum([_['n_predictions'] for _ in outputs], [])
384
- else:
385
- n_predictions = []
386
-
387
- lang_stats = None
388
- if len(n_predictions) > 0 and 'perplexity' in n_predictions[0]:
389
- n_predictions = sorted(
390
- n_predictions, key=lambda x: x['perplexity'])
391
-
392
- if not os.path.isdir('eval_results'):
393
- os.mkdir('eval_results')
394
- torch.save((predictions, n_predictions), os.path.join(
395
- 'eval_results/', '.saved_pred_' + opt.id + '_' + split + '.pth'))
396
-
397
- if opt.language_eval:
398
- lang_stats = eval_utils.language_eval(
399
- opt.input_json, predictions, n_predictions, vars(opt), split)
400
-
401
- if opt.reduce_on_plateau:
402
- optimizer = self.trainer.optimizers[0]
403
- if lang_stats is not None and 'CIDEr' in lang_stats:
404
- optimizer.scheduler_step(-lang_stats['CIDEr'])
405
- else:
406
- optimizer.scheduler_step(val_loss_mean)
407
-
408
- # out = {
409
- # 'val_loss': val_loss_mean
410
- # }
411
- out = {
412
- 'loss': val_loss_mean
413
- }
414
- out.update(lang_stats or {})
415
- # out['to_monitor'] = lang_stats['CIDEr'] if lang_stats is not None else -val_loss_mean
416
- if self.opt.use_clipscore or os.getenv('EVALUATE', '0') == '1':
417
- # if self.opt.clipscore_mode == 'clip_s':
418
- # out['clipscore'] = sum([p['clipscore'] for p in predictions]) / len(predictions)
419
- # print('CLIPScore', out['clipscore'])
420
- # out['CLIP-S'] = sum([p['CLIP-S'] for p in predictions]) / len(predictions)
421
- # print('CLIP-S', out['CLIP-S'])
422
- # elif self.opt.clipscore_mode == 'refclip_s':
423
- out['CLIP-S'] = sum([p['CLIP-S'] for p in predictions]) / len(predictions)
424
- print('CLIP-S', out['CLIP-S'])
425
-
426
- out['RefCLIP-S'] = sum([p['RefCLIP-S'] for p in predictions]) / len(predictions)
427
- print('RefCLIP-S', out['RefCLIP-S'])
428
-
429
- if getattr(self.opt, 'use_grammar', False) and not getattr(self.opt, 'joint_out', False):
430
- out['grammar_prob'] = sum([p['grammar_prob'] for p in predictions]) / len(predictions)
431
- print('grammar_prob', out['grammar_prob'])
432
-
433
- out['BERT-S'] = sum([p['BERT-S'] for p in predictions]) / len(predictions)
434
- print('BERT-S', out['BERT-S'])
435
- else:
436
- out = {}
437
-
438
- out = d2comm.all_gather(out)[0] # Only the one from master node
439
- assert len(out) > 0 # make sure the head has index 0
440
-
441
- # must all be tensors
442
- out = {k: torch.tensor(v) if not torch.is_tensor(
443
- v) else v for k, v in out.items()}
444
-
445
- # return {
446
- # 'progress_bar': {'val_loss': out['val_loss']},
447
- # 'log': out,
448
- # }
449
- for k, v in out.items():
450
- # if k in ['loss', 'clipscore', 'RefCLIP-S', 'CIDEr']:
451
- # if split != 'test':
452
- # self.log(f'{split}/{k}', v, prog_bar=True)
453
- # elif k == 'to_monitor':
454
- # if split != 'test':
455
- # self.log(f'{split}/{k}', v)
456
- # else:
457
- self.log(f'{split}/{k}', v)
458
-
459
- def test_epoch_end(self, outputs):
460
- # out = self.validation_epoch_end(outputs, 'test')
461
- # out['progress_bar'] = {
462
- # # 'test_loss': out['progress_bar']['val_loss']
463
- # 'test_loss': out['progress_bar']['loss']
464
- # }
465
- # out['log']['test_loss'] = out['log']['val_loss']
466
- # del out['log']['val_loss']
467
- # del out['log']['to_monitor']
468
-
469
- # out['log'] = {'test_'+k if 'test' not in k else k:v \
470
- # for k,v in out['log'].items()}
471
-
472
- # return out
473
- self.validation_epoch_end(outputs, 'test')
474
-
475
- def configure_optimizers(self):
476
- opt = self.opt
477
- model = self.model
478
-
479
- parameters = [p for p in model.parameters() if p.requires_grad]
480
-
481
- if opt.noamopt:
482
- # assert opt.caption_model in ['transformer', 'bert', 'm2transformer'], 'noamopt can only work with transformer'
483
- optimizer = utils.get_std_opt(
484
- model, optim_func=opt.optim, factor=opt.noamopt_factor, warmup=opt.noamopt_warmup)
485
- elif opt.reduce_on_plateau:
486
- # optimizer = utils.build_optimizer(model.parameters(), opt)
487
- optimizer = utils.build_optimizer(parameters, opt)
488
- optimizer = utils.ReduceLROnPlateau(optimizer,
489
- factor=opt.reduce_on_plateau_factor,
490
- patience=opt.reduce_on_plateau_patience)
491
- else:
492
- # optimizer = utils.build_optimizer(model.parameters(), opt)
493
- optimizer = utils.build_optimizer(parameters, opt)
494
- return [optimizer], []
495
-
496
- def optimizer_step(self, epoch, batch_idx, optimizer,
497
- optimizer_idx, *args, **kwargs):
498
- # warm up lr
499
- opt = self.opt
500
- iteration = self.trainer.global_step
501
- if opt.use_warmup and (iteration < opt.noamopt_warmup):
502
- opt.current_lr = opt.learning_rate * \
503
- (iteration+1) / opt.noamopt_warmup
504
- utils.set_lr(optimizer, opt.current_lr)
505
-
506
- super().optimizer_step(epoch, batch_idx, optimizer,
507
- optimizer_idx, *args, **kwargs)
508
-
509
- def state_dict(self):
510
- """
511
- Save the model state dict as well as opt and vocab
512
- """
513
- state_dict = self.model.state_dict()
514
- device = next(iter(state_dict.values())).device
515
- assert '_vocab' not in state_dict and '_opt' not in state_dict, 'Just in case'
516
- state_dict.update({
517
- '_vocab': utils.serialize_to_tensor(self.model.vocab).to(device),
518
- '_opt': utils.serialize_to_tensor(self.opt).to(device)
519
- })
520
- return state_dict
521
-
522
- def load_state_dict(self, state_dict=None, strict=True):
523
- if '_vocab' in state_dict:
524
- self.model.vocab = utils.deserialize(state_dict['_vocab'])
525
- del state_dict['_vocab']
526
- # elif strict:
527
- # raise KeyError
528
- if '_opt' in state_dict:
529
- saved_model_opt = utils.deserialize(state_dict['_opt'])
530
- del state_dict['_opt']
531
- opt = self.opt
532
- # Make sure the saved opt is compatible with the current opt
533
- need_be_same = ["caption_model",
534
- "rnn_type", "rnn_size", "num_layers"]
535
- for checkme in need_be_same:
536
- if getattr(saved_model_opt, checkme) in ['updown', 'topdown'] and \
537
- getattr(opt, checkme) in ['updown', 'topdown']:
538
- continue
539
- assert getattr(saved_model_opt, checkme) == getattr(
540
- opt, checkme), "Command line argument and saved model disagree on '%s' " % checkme
541
- # elif strict:
542
- # raise KeyError
543
- self.model.load_state_dict(state_dict, strict)
544
-
545
-
546
- class OnEpochStartCallback(pl.Callback):
547
-
548
- def on_epoch_start(self, trainer, pl_module):
549
- # Update lr/training stage/scheduled sampling prob etc.
550
- opt = pl_module.opt
551
- model = pl_module.model
552
- epoch = trainer.current_epoch
553
- optimizer = trainer.optimizers[0]
554
-
555
- if not opt.noamopt and not opt.reduce_on_plateau:
556
- # Assign the learning rate
557
- if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
558
- frac = (
559
- epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
560
- decay_factor = opt.learning_rate_decay_rate ** frac
561
- opt.current_lr = opt.learning_rate * decay_factor
562
- else:
563
- opt.current_lr = opt.learning_rate
564
- utils.set_lr(optimizer, opt.current_lr) # set the decayed rate
565
- # Assign the scheduled sampling prob
566
- if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
567
- frac = (
568
- epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
569
- opt.ss_prob = min(opt.scheduled_sampling_increase_prob *
570
- frac, opt.scheduled_sampling_max_prob)
571
- model.ss_prob = opt.ss_prob
572
-
573
- # If start self critical training
574
- if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
575
- sc_flag = True
576
- init_scorer(opt.cached_tokens)
577
- else:
578
- sc_flag = False
579
-
580
- # If start structure loss training
581
- if opt.structure_after != -1 and epoch >= opt.structure_after:
582
- struc_flag = True
583
- init_scorer(opt.cached_tokens)
584
- else:
585
- struc_flag = False
586
-
587
- pl_module.struc_flag = struc_flag
588
- pl_module.sc_flag = sc_flag
589
-
590
-
591
- class ModelCheckpoint(pl.callbacks.ModelCheckpoint):
592
-
593
- def on_keyboard_interrupt(self, trainer, pl_module):
594
- # Save the model on keyboard interrupt
595
- filepath = os.path.join(self.dirpath, self.prefix + 'interrupt.ckpt')
596
- self._save_model(filepath)
597
-
598
-
599
- opt = opts.parse_opt()
600
-
601
- checkpoint_callback = ModelCheckpoint(
602
- filepath=opt.checkpoint_path,
603
- # dirpath=opt.checkpoint_path,
604
- save_last=True,
605
- save_top_k=1,
606
- verbose=True,
607
- # monitor='to_monitor',
608
- # monitor='val/to_monitor',
609
- monitor='val/CIDEr',
610
- mode='max',
611
- # prefix=opt.id+'_',
612
- prefix=opt.id,
613
- # filename=f'{opt.id}_',
614
- )
615
-
616
- verbose = True
617
- # import torch
618
- # if torch.cuda.current_device() in [0, -1]:
619
- if 'LOCAL_RANK' in os.environ and os.environ['LOCAL_RANK'] != '0':
620
- verbose = False
621
-
622
- if verbose:
623
- print(opt)
624
- print("""
625
- val_images_use,
626
- save_checkpoint_every,
627
- save_every_epoch,
628
- save_history_ckpt will be ignored.
629
- """)
630
-
631
- # Lightning defines batch size as batch size per gpu
632
- assert opt.batch_size % torch.cuda.device_count() == 0
633
- opt.batch_size = opt.batch_size // torch.cuda.device_count()
634
-
635
- # If resume from last checkpoint
636
- # if opt.start_from is not None and os.path.isfile(os.path.join(opt.start_from, f'{opt.id}_last.ckpt')):
637
- # resume_from = os.path.join(opt.start_from, f'{opt.id}_last.ckpt')
638
- if opt.start_from is not None:
639
- resume_from = os.path.join(opt.start_from, f'{opt.id}-last.ckpt')
640
- if os.path.isfile(resume_from):
641
- if verbose:
642
- print('Loading checkpoint from', resume_from)
643
- else:
644
- print("Checkpoint not found:", resume_from)
645
- resume_from = None
646
- else:
647
- resume_from = None
648
-
649
- from pytorch_lightning.loggers import WandbLogger
650
- wandb_logger = WandbLogger(
651
- project='CLIP-ViL-COCOCaption',
652
- name=opt.id,
653
- )
654
-
655
- if verbose:
656
- wandb_logger.experiment.config.update(opt)
657
- from pathlib import Path
658
- import glob
659
- import wandb
660
- # src_dir = Path(__file__).resolve().parent.parent
661
- glob_str = "**/*.py"
662
- base_path = './'
663
- wandb.save(glob_str=glob_str, base_path=base_path)
664
-
665
- # code = wandb.Artifact('project-source', type='code')
666
- # for path in glob.glob('**/*.py', recursive=True):
667
- # code.add_file(path, name='source/'+path)
668
- # print(path)
669
- # wandb.run.use_artifact(code)
670
-
671
-
672
-
673
-
674
- lit = LitModel(opt)
675
- # warning grad_clip_mode is ignored.
676
- trainer = pl.Trainer(
677
- callbacks=[
678
- OnEpochStartCallback(),
679
- # pl.callbacks.lr_logger.LearningRateLogger()
680
- pl.callbacks.LearningRateMonitor()
681
- ],
682
- default_root_dir=opt.checkpoint_path,
683
- resume_from_checkpoint=resume_from,
684
- distributed_backend='ddp',
685
- check_val_every_n_epoch=1,
686
- max_epochs=opt.max_epochs,
687
- gradient_clip_val=opt.grad_clip_value,
688
- gpus=torch.cuda.device_count(),
689
- checkpoint_callback=checkpoint_callback,
690
- log_gpu_memory='min_max',
691
- # log_save_interval=opt.losses_log_every,
692
- log_every_n_steps=opt.losses_log_every,
693
- profiler=True,
694
- # profiler='simple',
695
- # row_log_interval=10, # what is it?
696
- flush_logs_every_n_steps=10,
697
- num_sanity_val_steps=0,
698
- # val_check_interval=0.01,
699
- # limit_train_batches=500,
700
- # progress_bar_refresh_rate=0,
701
- # fast_dev_run=True,
702
- precision=opt.precision,
703
- logger=wandb_logger
704
- )
705
-
706
- if os.getenv('EVALUATE', '0') == '1':
707
- trainer.test(lit)
708
- else:
709
- trainer.fit(lit)