Spaces: Runtime error
Commit 23d1bbb · committed by Alberto Carmona
1 Parent(s): ebd4e51
Track error cloning the repo
Browse files
- clip/__init__.py +0 -1
- clip/bpe_simple_vocab_16e6.txt.gz +0 -3
- clip/clip.py +0 -193
- clip/model.py +0 -437
- clip/simple_tokenizer.py +0 -132
- configs/phase1/FineCapEval_clipRN50_mle.yml +0 -60
- configs/phase1/clipRN50_mle.yml +0 -52
- configs/phase1/transformer.yml +0 -41
- configs/phase2/FineCapEval_clipRN50_cider.yml +0 -61
- configs/phase2/FineCapEval_clipRN50_cider_clips.yml +0 -65
- configs/phase2/FineCapEval_clipRN50_clips.yml +0 -64
- configs/phase2/FineCapEval_clipRN50_clips_grammar.yml +0 -64
- configs/phase2/clipRN50_cider.yml +0 -58
- configs/phase2/clipRN50_cider_clips.yml +0 -61
- configs/phase2/clipRN50_clips.yml +0 -58
- configs/phase2/clipRN50_clips_grammar.yml +0 -64
- configs/phase2/transformer.yml +0 -41
- data/README.md +0 -1
- retrieval/README.md +0 -5
- retrieval/caption_data.py +0 -500
- retrieval/clip_model.py +0 -350
- retrieval/configs/clip_negative_text.yaml +0 -14
- retrieval/param.py +0 -209
- retrieval/pth_loader.py +0 -334
- retrieval/text_utils.py +0 -74
- retrieval/train_pl.py +0 -661
- save/README.md +0 -1
- scripts/build_bpe_subword_nmt.py +0 -214
- scripts/clip_prepro_feats.py +0 -170
- scripts/clipscore_prepro_feats.py +0 -162
- scripts/copy_model.sh +0 -9
- scripts/dump_to_h5df.py +0 -56
- scripts/dump_to_lmdb.py +0 -241
- scripts/make_bu_data.py +0 -52
- scripts/prepro_feats.py +0 -103
- scripts/prepro_labels.py +0 -206
- scripts/prepro_ngrams.py +0 -94
- scripts/prepro_reference_json.py +0 -69
- scripts_FineCapEval/clip_prepro_feats.py +0 -163
- scripts_FineCapEval/clipscore_prepro_feats.py +0 -154
- scripts_FineCapEval/prepro_labels.py +0 -209
- tools/eval.py +0 -125
- tools/eval_clip_retrieval.py +0 -231
- tools/eval_finecapeval.py +0 -204
- tools/finecapeval_inference.py +0 -186
- tools/train_pl.py +0 -709
clip/__init__.py
DELETED
@@ -1 +0,0 @@
from .clip import *
clip/bpe_simple_vocab_16e6.txt.gz
DELETED
@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
size 1356917
clip/clip.py
DELETED
@@ -1,193 +0,0 @@
import hashlib
import os
import urllib
import warnings
from typing import Union, List

import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from tqdm import tqdm

from .model import build_model
from .simple_tokenizer import SimpleTokenizer as _Tokenizer

__all__ = ["available_models", "load", "tokenize"]
_tokenizer = _Tokenizer()

_MODELS = {
    "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
    "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
    "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
    "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
}


def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
            return download_target
        else:
            warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
        raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")

    return download_target


def _transform(n_px):
    return Compose([
        Resize(n_px, interpolation=Image.BICUBIC),
        CenterCrop(n_px),
        lambda image: image.convert("RGB"),
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])


def available_models() -> List[str]:
    """Returns the names of available CLIP models"""
    return list(_MODELS.keys())


def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True):
    """Load a CLIP model

    Parameters
    ----------
    name : str
        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict

    device : Union[str, torch.device]
        The device to put the loaded model

    jit : bool
        Whether to load the optimized JIT model (default) or more hackable non-JIT model.

    Returns
    -------
    model : torch.nn.Module
        The CLIP model

    preprocess : Callable[[PIL.Image], torch.Tensor]
        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
    """
    if name in _MODELS:
        model_path = _download(_MODELS[name])
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    try:
        # loading JIT archive
        model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
        state_dict = None
    except RuntimeError:
        # loading saved state dict
        if jit:
            warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
            jit = False
        state_dict = torch.load(model_path, map_location="cpu")

    if not jit:
        model = build_model(state_dict or model.state_dict()).to(device)
        if str(device) == "cpu":
            model.float()
        return model, _transform(model.visual.input_resolution)

    # patch the device names
    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
    device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]

    def patch_device(module):
        graphs = [module.graph] if hasattr(module, "graph") else []
        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)

        for graph in graphs:
            for node in graph.findAllNodes("prim::Constant"):
                if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
                    node.copyAttributes(device_node)

    model.apply(patch_device)
    patch_device(model.encode_image)
    patch_device(model.encode_text)

    # patch dtype to float32 on CPU
    if str(device) == "cpu":
        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            graphs = [module.graph] if hasattr(module, "graph") else []
            if hasattr(module, "forward1"):
                graphs.append(module.forward1.graph)

            for graph in graphs:
                for node in graph.findAllNodes("aten::to"):
                    inputs = list(node.inputs())
                    for i in [1, 2]:  # dtype can be the second or third argument to aten::to()
                        if inputs[i].node()["value"] == 5:
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_image)
        patch_float(model.encode_text)

        model.float()

    return model, _transform(model.input_resolution.item())


def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
    """
    Returns the tokenized representation of given input string(s)

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize

    context_length : int
        The context length to use; all CLIP models use 77 as the context length

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result
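For reference, here is a minimal sketch of how the API defined in the deleted clip/clip.py was typically called. It assumes the deleted package is importable as `clip` from the repository root; the image path and prompt strings are placeholders. Note that in this fork the RN50 visual backbone (ModifiedResNet in the deleted clip/model.py) returns both the spatial grid and the attention-pooled feature.

import torch
from PIL import Image

import clip  # the deleted package: available_models(), load(), tokenize()

device = "cuda" if torch.cuda.is_available() else "cpu"
# jit=False takes the build_model() branch of load() in the deleted clip.py
model, preprocess = clip.load("RN50", device=device, jit=False)

image = preprocess(Image.open("example.jpg")).unsqueeze(0).to(device)  # placeholder path
text = clip.tokenize(["a photo of a dog", "a photo of a cat"]).to(device)

with torch.no_grad():
    # In this fork, ModifiedResNet.forward returns (spatial grid, attention-pooled vector)
    grid_feats, pooled_feats = model.encode_image(image)
    text_feats = model.encode_text(text)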
clip/model.py
DELETED
@@ -1,437 +0,0 @@
from collections import OrderedDict
from typing import Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()

        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)

        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = None
        self.stride = stride

        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
            self.downsample = nn.Sequential(OrderedDict([
                ("-1", nn.AvgPool2d(stride)),
                ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
                ("1", nn.BatchNorm2d(planes * self.expansion))
            ]))

    def forward(self, x: torch.Tensor):
        identity = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out


class AttentionPool2d(nn.Module):
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)  # NCHW -> (HW)NC
        # print(x.shape, self.positional_embedding.shape)
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
        x = x + self.positional_embedding[0, :, None, :].to(x.dtype)  # (HW+1)NC
        x, _ = F.multi_head_attention_forward(
            query=x, key=x, value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=torch.ones_like(self.q_proj.weight),
            out_proj_bias=torch.zeros_like(self.q_proj.bias),
            # out_proj_weight=self.c_proj.weight,
            # out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )

        return x[0]


class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """

    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
        super().__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution

        # the 3-layer stem
        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.avgpool = nn.AvgPool2d(2)
        self.relu = nn.ReLU(inplace=True)

        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        embed_dim = width * 32  # the ResNet feature dimension
        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)

    def _make_layer(self, planes, blocks, stride=1):
        layers = [Bottleneck(self._inplanes, planes, stride)]

        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.append(Bottleneck(self._inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        def stem(x):
            for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
                x = self.relu(bn(conv(x)))
            x = self.avgpool(x)
            return x

        x = x.type(self.conv1.weight.dtype)
        x = stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # print(x.shape)
        # x = self.attnpool(x)
        attnpool = self.attnpool(x)

        return (x, attnpool)


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):
    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)


class VisualTransformer(nn.Module):
    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        # x = self.ln_post(x[:, 0, :])

        x = self.ln_post(x)
        # if self.proj is not None:
        #     x = x @ self.proj

        return x


class CLIP(nn.Module):
    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int
                 ):
        super().__init__()

        self.context_length = context_length

        if isinstance(vision_layers, (tuple, list)):
            vision_heads = vision_width * 32 // 64
            self.visual = ModifiedResNet(
                layers=vision_layers,
                output_dim=embed_dim,
                heads=vision_heads,
                input_resolution=image_resolution,
                width=vision_width
            )
        else:
            vision_heads = vision_width // 64
            self.visual = VisualTransformer(
                input_resolution=image_resolution,
                patch_size=vision_patch_size,
                width=vision_width,
                layers=vision_layers,
                heads=vision_heads,
                output_dim=embed_dim
            )

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask()
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        self.logit_scale = nn.Parameter(torch.ones([]))

        self.initialize_parameters()

    def initialize_parameters(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        if isinstance(self.visual, ModifiedResNet):
            if self.visual.attnpool is not None:
                std = self.visual.attnpool.c_proj.in_features ** -0.5
                nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)

            for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
                for name, param in resnet_block.named_parameters():
                    if name.endswith("bn3.weight"):
                        nn.init.zeros_(param)

        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    @property
    def dtype(self):
        return self.visual.conv1.weight.dtype

    def encode_image(self, image):
        return self.visual(image.type(self.dtype))

    def encode_text(self, text):
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

        return x

    def forward(self, image, text):
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # normalized features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logit_scale * text_features @ image_features.t()

        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text


def convert_weights(model: nn.Module):
    """Convert applicable model parameters to fp16"""

    def _convert_weights_to_fp16(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.half()
            if l.bias is not None:
                l.bias.data = l.bias.data.half()

        if isinstance(l, nn.MultiheadAttention):
            for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
                tensor = getattr(l, attr)
                if tensor is not None:
                    tensor.data = tensor.data.half()

        for name in ["text_projection", "proj"]:
            if hasattr(l, name):
                attr = getattr(l, name)
                if attr is not None:
                    attr.data = attr.data.half()

    model.apply(_convert_weights_to_fp16)


def build_model(state_dict: dict):
    vit = "visual.proj" in state_dict

    if vit:
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_resolution = vision_patch_size * grid_size
    else:
        counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
        vision_layers = tuple(counts)
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
        image_resolution = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))

    model = CLIP(
        embed_dim,
        image_resolution, vision_layers, vision_width, vision_patch_size,
        context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
    )

    for key in ["input_resolution", "context_length", "vocab_size"]:
        if key in state_dict:
            del state_dict[key]

    convert_weights(model)
    model.load_state_dict(state_dict)
    return model.eval()
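A minimal sketch of how build_model in the deleted clip/model.py rebuilds an editable module from a checkpoint's weights, mirroring the non-JIT branch of load() in clip/clip.py; the checkpoint path below is a placeholder.

import torch
from clip.model import build_model

# Load an OpenAI CLIP JIT archive and rebuild a non-JIT module from its weights.
jit_model = torch.jit.load("RN50.pt", map_location="cpu")      # placeholder path
model = build_model(jit_model.state_dict())  # infers widths/layers from the state dict, returns model.eval()
model = model.float()                        # convert_weights() halved the weights; keep fp32 on CPU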
clip/simple_tokenizer.py
DELETED
@@ -1,132 +0,0 @@
import gzip
import html
import os
from functools import lru_cache

import ftfy
import regex as re


@lru_cache()
def default_bpe():
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")


@lru_cache()
def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a signficant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8+n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text


class SimpleTokenizer(object):
    def __init__(self, bpe_path: str = default_bpe()):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + ( token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token+'</w>'

        while True:
            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text
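A small sketch of a round trip through the deleted SimpleTokenizer (BPE encode, then decode), assuming the package layout above so that default_bpe() resolves the deleted bpe_simple_vocab_16e6.txt.gz next to the module; the input sentence is a placeholder. Note that clip.tokenize(), not the tokenizer itself, adds the start/end-of-text tokens.

from clip.simple_tokenizer import SimpleTokenizer

tokenizer = SimpleTokenizer()                        # uses default_bpe() -> bpe_simple_vocab_16e6.txt.gz
ids = tokenizer.encode("A dog plays in the snow!")   # cleaned, lowercased, then BPE-encoded
text = tokenizer.decode(ids)                         # byte-decoded; '</w>' markers become spaces
print(ids, text)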
configs/phase1/FineCapEval_clipRN50_mle.yml
DELETED
@@ -1,60 +0,0 @@
caption_model: transformer
noamopt: true
noamopt_warmup: 20000
label_smoothing: 0.0
input_json: data/FineCapEval.json
input_label_h5: none
input_fc_dir: data/FineCapEval_clip_RN50_fc
input_att_dir: data/FineCapEval_clip_RN50_att
input_clipscore_vis_dir: data/FineCapEval_clipscore_vis

seq_per_img: 5
batch_size: 200
learning_rate: 0.0005

checkpoint_path: ./save/clipRN50_mle/clipRN50_mle

# clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt'

# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
# N=num_layers
# d_model=input_encoding_size
# d_ff=rnn_size

# will be ignored
num_layers: 6
input_encoding_size: 512
rnn_size: 2048

# Transformer config
N_enc: 6
N_dec: 6
d_model: 512
d_ff: 2048
num_att_heads: 8
dropout: 0.1


learning_rate_decay_start: 0
scheduled_sampling_start: -1
save_checkpoint_every: 3000
language_eval: 1
val_images_use: 5000
max_epochs: 15
train_sample_n: 5

REFORWARD: false

# _BASE_: transformer.yml
reduce_on_plateau: false
noamopt: false
learning_rate: 0.000005
learning_rate_decay_start: -1

self_critical_after: 15
max_epochs: 50

verbose: false
precision: 32

use_clipscore: false
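The deleted training and evaluation scripts (e.g. tools/train_pl.py) consumed these YAML files as flat key-value option sets. As a hypothetical illustration only (not the repository's own config loader), such a file could be read with PyYAML; note that PyYAML keeps the last value when a key is repeated, which matters here because keys such as noamopt and max_epochs appear twice.

import yaml

with open("configs/phase1/FineCapEval_clipRN50_mle.yml") as f:
    opts = yaml.safe_load(f)   # flat dict; later duplicate keys override earlier ones

print(opts["caption_model"], opts["batch_size"], opts["use_clipscore"])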
configs/phase1/clipRN50_mle.yml
DELETED
@@ -1,52 +0,0 @@
caption_model: transformer
noamopt: true
# noamopt: false
noamopt_warmup: 20000
label_smoothing: 0.0
input_json: data/cocotalk.json
input_label_h5: data/cocotalk_label.h5
input_fc_dir: data/cocotalk_clip_RN50_fc
input_att_dir: data/cocotalk_clip_RN50_att
input_clipscore_vis_dir: data/cocotalk_clipscore_vis
seq_per_img: 5
# batch_size: 600
batch_size: 200

learning_rate: 0.0005

# checkpoint_path: ./save/trans_clip_rn50_sc_pl
checkpoint_path: save/clipRN50_mle/clipRN50_mle

# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
# N=num_layers
# d_model=input_encoding_size
# d_ff=rnn_size

# will be ignored
num_layers: 6
input_encoding_size: 512
rnn_size: 2048

# Transformer config
N_enc: 6
N_dec: 6
d_model: 512
d_ff: 2048
num_att_heads: 8
dropout: 0.1


learning_rate_decay_start: 0
scheduled_sampling_start: -1
save_checkpoint_every: 3000
language_eval: 1
val_images_use: 5000
# max_epochs: 15
max_epochs: 25
train_sample_n: 5

REFORWARD: false


verbose: false
precision: 16
configs/phase1/transformer.yml
DELETED
@@ -1,41 +0,0 @@
caption_model: transformer
noamopt: true
noamopt_warmup: 20000
label_smoothing: 0.0
input_json: data/cocotalk.json
input_label_h5: data/cocotalk_label.h5
input_att_dir: data/cocotalk_att
seq_per_img: 5
batch_size: 10
learning_rate: 0.0005

checkpoint_path: ./save/trans_rn50_sc

# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
# N=num_layers
# d_model=input_encoding_size
# d_ff=rnn_size

# will be ignored
num_layers: 6
input_encoding_size: 512
rnn_size: 2048

# Transformer config
N_enc: 6
N_dec: 6
d_model: 512
d_ff: 2048
num_att_heads: 8
dropout: 0.1


learning_rate_decay_start: 0
scheduled_sampling_start: -1
save_checkpoint_every: 3000
language_eval: 1
val_images_use: 5000
max_epochs: 15
train_sample_n: 5

REFORWARD: false
configs/phase2/FineCapEval_clipRN50_cider.yml
DELETED
@@ -1,61 +0,0 @@
caption_model: transformer
noamopt: true
noamopt_warmup: 20000
label_smoothing: 0.0
input_json: data/FineCapEval.json
input_label_h5: none
input_fc_dir: data/FineCapEval_clip_RN50_fc
input_att_dir: data/FineCapEval_clip_RN50_att
input_clipscore_vis_dir: data/FineCapEval_clipscore_vis

seq_per_img: 5
batch_size: 200
learning_rate: 0.0005

checkpoint_path: ./save/clipRN50_cider/clipRN50_cider

# clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt'

# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
# N=num_layers
# d_model=input_encoding_size
# d_ff=rnn_size

# will be ignored
num_layers: 6
input_encoding_size: 512
rnn_size: 2048

# Transformer config
N_enc: 6
N_dec: 6
d_model: 512
d_ff: 2048
num_att_heads: 8
dropout: 0.1


learning_rate_decay_start: 0
scheduled_sampling_start: -1
save_checkpoint_every: 3000
language_eval: 1
val_images_use: 5000
max_epochs: 15
train_sample_n: 5

REFORWARD: false

# _BASE_: transformer.yml
reduce_on_plateau: false
noamopt: false
learning_rate: 0.000005
learning_rate_decay_start: -1

self_critical_after: 15
max_epochs: 50

verbose: false
precision: 32

# use_clipscore: true
use_clipscore: false
configs/phase2/FineCapEval_clipRN50_cider_clips.yml
DELETED
@@ -1,65 +0,0 @@
caption_model: transformer
noamopt: true
noamopt_warmup: 20000
label_smoothing: 0.0
input_json: data/FineCapEval.json
input_label_h5: none
input_fc_dir: data/FineCapEval_clip_RN50_fc
input_att_dir: data/FineCapEval_clip_RN50_att
input_clipscore_vis_dir: data/FineCapEval_clipscore_vis

seq_per_img: 5
batch_size: 200
learning_rate: 0.0005

checkpoint_path: ./save/clipRN50_cider_clips/clipRN50_cider_clips

# clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt'

# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
# N=num_layers
# d_model=input_encoding_size
# d_ff=rnn_size

# will be ignored
num_layers: 6
input_encoding_size: 512
rnn_size: 2048

# Transformer config
N_enc: 6
N_dec: 6
d_model: 512
d_ff: 2048
num_att_heads: 8
dropout: 0.1


learning_rate_decay_start: 0
scheduled_sampling_start: -1
save_checkpoint_every: 3000
language_eval: 1
val_images_use: 5000
max_epochs: 15
train_sample_n: 5

REFORWARD: false

# _BASE_: transformer.yml
reduce_on_plateau: false
noamopt: false
learning_rate: 0.000005
learning_rate_decay_start: -1

self_critical_after: 15
max_epochs: 50

verbose: false
precision: 32

# use_clipscore: true
use_clipscore: false
clipscore_reward_weight: 2.0
clipscore_mode: clip_s

use_multi_rewards: true
configs/phase2/FineCapEval_clipRN50_clips.yml
DELETED
@@ -1,64 +0,0 @@
caption_model: transformer
noamopt: true
noamopt_warmup: 20000
label_smoothing: 0.0
input_json: data/FineCapEval.json
input_label_h5: none
input_fc_dir: data/FineCapEval_clip_RN50_fc
input_att_dir: data/FineCapEval_clip_RN50_att
input_clipscore_vis_dir: data/FineCapEval_clipscore_vis
seq_per_img: 5
batch_size: 160
learning_rate: 0.0005

checkpoint_path: ./save/clipRN50_clips/clipRN50_clips

use_multi_rewards: false
use_grammar: false
use_grammar_baseline: false
# clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt'

# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
# N=num_layers
# d_model=input_encoding_size
# d_ff=rnn_size

# will be ignored
num_layers: 6
input_encoding_size: 512
rnn_size: 2048

# Transformer config
N_enc: 6
N_dec: 6
d_model: 512
d_ff: 2048
num_att_heads: 8
dropout: 0.1


learning_rate_decay_start: 0
scheduled_sampling_start: -1
save_checkpoint_every: 3000
language_eval: 0
val_images_use: 5000
max_epochs: 15
train_sample_n: 5

REFORWARD: false

# _BASE_: transformer.yml
reduce_on_plateau: false
noamopt: false
learning_rate: 0.000005
learning_rate_decay_start: -1

self_critical_after: 15
max_epochs: 50

verbose: false
precision: 32

# use_clipscore: true
use_clipscore: false
clipscore_reward_weight: 2.0
configs/phase2/FineCapEval_clipRN50_clips_grammar.yml
DELETED
@@ -1,64 +0,0 @@
caption_model: transformer
noamopt: true
noamopt_warmup: 20000
label_smoothing: 0.0
input_json: data/FineCapEval.json
input_label_h5: none
input_fc_dir: data/FineCapEval_clip_RN50_fc
input_att_dir: data/FineCapEval_clip_RN50_att
input_clipscore_vis_dir: data/FineCapEval_clipscore_vis
seq_per_img: 5
batch_size: 160
learning_rate: 0.0005

checkpoint_path: ./save/clipRN50_clips_grammar/clipRN50_clips_grammar

use_multi_rewards: true
use_grammar: true
use_grammar_baseline: true
# clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt'

# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
# N=num_layers
# d_model=input_encoding_size
# d_ff=rnn_size

# will be ignored
num_layers: 6
input_encoding_size: 512
rnn_size: 2048

# Transformer config
N_enc: 6
N_dec: 6
d_model: 512
d_ff: 2048
num_att_heads: 8
dropout: 0.1


learning_rate_decay_start: 0
scheduled_sampling_start: -1
save_checkpoint_every: 3000
language_eval: 0
val_images_use: 5000
max_epochs: 15
train_sample_n: 5

REFORWARD: false

# _BASE_: transformer.yml
reduce_on_plateau: false
noamopt: false
learning_rate: 0.000005
learning_rate_decay_start: -1

self_critical_after: 15
max_epochs: 50

verbose: false
precision: 32

# use_clipscore: true
use_clipscore: false
clipscore_reward_weight: 2.0
configs/phase2/clipRN50_cider.yml
DELETED
@@ -1,58 +0,0 @@
caption_model: transformer
noamopt: true
noamopt_warmup: 20000
label_smoothing: 0.0
input_json: data/cocotalk.json
input_label_h5: data/cocotalk_label.h5
input_fc_dir: data/cocotalk_clip_RN50_fc
input_att_dir: data/cocotalk_clip_RN50_att
# used only for evaluation
input_clipscore_vis_dir: data/cocotalk_clipscore_vis

seq_per_img: 5
batch_size: 200
learning_rate: 0.0005

# checkpoint_path: ./save/trans_clip_rn50_sc_pl_scst_cider
checkpoint_path: save/clipRN50_cider/clipRN50_cider

# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
# N=num_layers
# d_model=input_encoding_size
# d_ff=rnn_size

# will be ignored
num_layers: 6
input_encoding_size: 512
rnn_size: 2048

# Transformer config
N_enc: 6
N_dec: 6
d_model: 512
d_ff: 2048
num_att_heads: 8
dropout: 0.1


learning_rate_decay_start: 0
scheduled_sampling_start: -1
save_checkpoint_every: 3000
language_eval: 1
val_images_use: 5000
max_epochs: 15
train_sample_n: 5

REFORWARD: false

# _BASE_: transformer.yml
reduce_on_plateau: false
noamopt: false
learning_rate: 0.000005
learning_rate_decay_start: -1

self_critical_after: 15
max_epochs: 40

verbose: false
precision: 32
configs/phase2/clipRN50_cider_clips.yml
DELETED
@@ -1,61 +0,0 @@
caption_model: transformer
noamopt: true
noamopt_warmup: 20000
label_smoothing: 0.0
input_json: data/cocotalk.json
input_label_h5: data/cocotalk_label.h5
input_fc_dir: data/cocotalk_clip_RN50_fc
input_att_dir: data/cocotalk_clip_RN50_att
input_clipscore_vis_dir: data/cocotalk_clipscore_vis
seq_per_img: 5
batch_size: 160
learning_rate: 0.0005

checkpoint_path: save/clipRN50_cider_clips/clipRN50_cider_clips

# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
# N=num_layers
# d_model=input_encoding_size
# d_ff=rnn_size

# will be ignored
num_layers: 6
input_encoding_size: 512
rnn_size: 2048

# Transformer config
N_enc: 6
N_dec: 6
d_model: 512
d_ff: 2048
num_att_heads: 8
dropout: 0.1


learning_rate_decay_start: 0
scheduled_sampling_start: -1
save_checkpoint_every: 3000
language_eval: 1
val_images_use: 5000
max_epochs: 15
train_sample_n: 5

REFORWARD: false

# _BASE_: transformer.yml
reduce_on_plateau: false
noamopt: false
learning_rate: 0.000005
learning_rate_decay_start: -1

self_critical_after: 15
max_epochs: 40

verbose: false
precision: 32

use_clipscore: true
clipscore_reward_weight: 2.0
clipscore_mode: clip_s

use_multi_rewards: true
configs/phase2/clipRN50_clips.yml
DELETED
@@ -1,58 +0,0 @@
caption_model: transformer
noamopt: true
noamopt_warmup: 20000
label_smoothing: 0.0
input_json: data/cocotalk.json
input_label_h5: data/cocotalk_label.h5
input_fc_dir: data/cocotalk_clip_RN50_fc
input_att_dir: data/cocotalk_clip_RN50_att
input_clipscore_vis_dir: data/cocotalk_clipscore_vis
seq_per_img: 5
batch_size: 160
learning_rate: 0.0005

checkpoint_path: save/clipRN50_clips/clipRN50_clips

# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
# N=num_layers
# d_model=input_encoding_size
# d_ff=rnn_size

# will be ignored
num_layers: 6
input_encoding_size: 512
rnn_size: 2048

# Transformer config
N_enc: 6
N_dec: 6
d_model: 512
d_ff: 2048
num_att_heads: 8
dropout: 0.1


learning_rate_decay_start: 0
scheduled_sampling_start: -1
save_checkpoint_every: 3000
language_eval: 1
val_images_use: 5000
max_epochs: 15
train_sample_n: 5

REFORWARD: false

# _BASE_: transformer.yml
reduce_on_plateau: false
noamopt: false
learning_rate: 0.000005
learning_rate_decay_start: -1

self_critical_after: 15
max_epochs: 40

verbose: false
precision: 32

use_clipscore: true
clipscore_reward_weight: 2.0
configs/phase2/clipRN50_clips_grammar.yml
DELETED
@@ -1,64 +0,0 @@
|
|
1 |
-
caption_model: transformer
|
2 |
-
noamopt: true
|
3 |
-
noamopt_warmup: 20000
|
4 |
-
label_smoothing: 0.0
|
5 |
-
input_json: data/cocotalk.json
|
6 |
-
input_label_h5: data/cocotalk_label.h5
|
7 |
-
input_fc_dir: data/cocotalk_clip_RN50_fc
|
8 |
-
input_att_dir: data/cocotalk_clip_RN50_att
|
9 |
-
input_clipscore_vis_dir: data/cocotalk_clipscore_vis
|
10 |
-
seq_per_img: 5
|
11 |
-
batch_size: 160
|
12 |
-
learning_rate: 0.0005
|
13 |
-
|
14 |
-
checkpoint_path: save/clipRN50_clips_grammar/clipRN50_clips_grammar
|
15 |
-
|
16 |
-
use_multi_rewards: true
|
17 |
-
use_grammar: true
|
18 |
-
use_grammar_baseline: true
|
19 |
-
# clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt'
|
20 |
-
clip_load_path: 'retrieval/save/clip_negative_text/clip_negative_text-epoch=12.ckpt'
|
21 |
-
|
22 |
-
# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
|
23 |
-
# N=num_layers
|
24 |
-
# d_model=input_encoding_size
|
25 |
-
# d_ff=rnn_size
|
26 |
-
|
27 |
-
# will be ignored
|
28 |
-
num_layers: 6
|
29 |
-
input_encoding_size: 512
|
30 |
-
rnn_size: 2048
|
31 |
-
|
32 |
-
# Transformer config
|
33 |
-
N_enc: 6
|
34 |
-
N_dec: 6
|
35 |
-
d_model: 512
|
36 |
-
d_ff: 2048
|
37 |
-
num_att_heads: 8
|
38 |
-
dropout: 0.1
|
39 |
-
|
40 |
-
|
41 |
-
learning_rate_decay_start: 0
|
42 |
-
scheduled_sampling_start: -1
|
43 |
-
save_checkpoint_every: 3000
|
44 |
-
language_eval: 1
|
45 |
-
val_images_use: 5000
|
46 |
-
max_epochs: 15
|
47 |
-
train_sample_n: 5
|
48 |
-
|
49 |
-
REFORWARD: false
|
50 |
-
|
51 |
-
# _BASE_: transformer.yml
|
52 |
-
reduce_on_plateau: false
|
53 |
-
noamopt: false
|
54 |
-
learning_rate: 0.000005
|
55 |
-
learning_rate_decay_start: -1
|
56 |
-
|
57 |
-
self_critical_after: 15
|
58 |
-
max_epochs: 40
|
59 |
-
|
60 |
-
verbose: false
|
61 |
-
precision: 32
|
62 |
-
|
63 |
-
use_clipscore: true
|
64 |
-
clipscore_reward_weight: 2.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
configs/phase2/transformer.yml
DELETED
@@ -1,41 +0,0 @@
|
|
1 |
-
caption_model: transformer
|
2 |
-
noamopt: true
|
3 |
-
noamopt_warmup: 20000
|
4 |
-
label_smoothing: 0.0
|
5 |
-
input_json: data/cocotalk.json
|
6 |
-
input_label_h5: data/cocotalk_label.h5
|
7 |
-
input_att_dir: data/cocotalk_att
|
8 |
-
seq_per_img: 5
|
9 |
-
batch_size: 10
|
10 |
-
learning_rate: 0.0005
|
11 |
-
|
12 |
-
checkpoint_path: ./save/trans_rn50_sc
|
13 |
-
|
14 |
-
# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
|
15 |
-
# N=num_layers
|
16 |
-
# d_model=input_encoding_size
|
17 |
-
# d_ff=rnn_size
|
18 |
-
|
19 |
-
# will be ignored
|
20 |
-
num_layers: 6
|
21 |
-
input_encoding_size: 512
|
22 |
-
rnn_size: 2048
|
23 |
-
|
24 |
-
# Transformer config
|
25 |
-
N_enc: 6
|
26 |
-
N_dec: 6
|
27 |
-
d_model: 512
|
28 |
-
d_ff: 2048
|
29 |
-
num_att_heads: 8
|
30 |
-
dropout: 0.1
|
31 |
-
|
32 |
-
|
33 |
-
learning_rate_decay_start: 0
|
34 |
-
scheduled_sampling_start: -1
|
35 |
-
save_checkpoint_every: 3000
|
36 |
-
language_eval: 1
|
37 |
-
val_images_use: 5000
|
38 |
-
max_epochs: 15
|
39 |
-
train_sample_n: 5
|
40 |
-
|
41 |
-
REFORWARD: false
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
data/README.md
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
directory to store preprocessed files
|
|
|
|
retrieval/README.md
DELETED
@@ -1,5 +0,0 @@
|
|
1 |
-
# Finetuning CLIP reward model
|
2 |
-
|
3 |
-
```bash
|
4 |
-
python train_pl.py --cfg clip_negative_text --id clip_negative_text
|
5 |
-
```
|
|
|
|
|
|
|
|
|
|
|
|
retrieval/caption_data.py
DELETED
@@ -1,500 +0,0 @@
|
|
1 |
-
from torch.utils.data import DataLoader, Dataset, Sampler
|
2 |
-
from pathlib import Path
|
3 |
-
import json
|
4 |
-
from multiprocessing import Pool
|
5 |
-
from tqdm import tqdm
|
6 |
-
from PIL import Image
|
7 |
-
import random
|
8 |
-
import numpy as np
|
9 |
-
import torch
|
10 |
-
import torchvision
|
11 |
-
import torchvision.transforms as T
|
12 |
-
|
13 |
-
from torch.utils.data.distributed import DistributedSampler
|
14 |
-
|
15 |
-
from transformers import T5Tokenizer, BertTokenizer, BertTokenizerFast, CLIPTokenizer
|
16 |
-
|
17 |
-
import text_utils
|
18 |
-
|
19 |
-
project_dir = Path(__file__).parent.resolve()
|
20 |
-
workspace_dir = project_dir.parent.parent
|
21 |
-
dataset_dir = workspace_dir.joinpath('datasets/').resolve()
|
22 |
-
# coco_dir = dataset_dir.joinpath('COCO')
|
23 |
-
# vg_dir = dataset_dir.joinpath('VG')
|
24 |
-
coco_img_dir = dataset_dir.joinpath('COCO/images/')
|
25 |
-
coco_data_dir = project_dir.parent.joinpath('CLIP-ViL/CLIP-ViL-Direct/caption/data/')
|
26 |
-
# coco_feature_dir = coco_dir.joinpath('features')
|
27 |
-
|
28 |
-
|
29 |
-
class COCORetrievalDataset(Dataset):
|
30 |
-
def __init__(self, split='karpathy_train', rank=-1, topk=-1, verbose=True, args=None, mode='train'):
|
31 |
-
super().__init__()
|
32 |
-
|
33 |
-
self.topk = topk
|
34 |
-
self.verbose = verbose
|
35 |
-
self.args = args
|
36 |
-
self.rank = rank
|
37 |
-
self.mode = mode
|
38 |
-
|
39 |
-
# Loading datasets to data
|
40 |
-
self.source = split
|
41 |
-
if self.verbose:
|
42 |
-
print('Data source: ', self.source)
|
43 |
-
|
44 |
-
# if self.args.tokenizer is None:
|
45 |
-
# self.args.tokenizer = self.args.decoder_backbone
|
46 |
-
|
47 |
-
# if 'bert' in self.args.tokenizer:
|
48 |
-
# self.tokenizer = BertTokenizerFast.from_pretrained(
|
49 |
-
# self.args.tokenizer,
|
50 |
-
# # max_length=self.args.max_text_length,
|
51 |
-
# # do_lower_case=self.args.do_lower_case
|
52 |
-
# )
|
53 |
-
# elif 'clip' in self.args.tokenizer:
|
54 |
-
# self.tokenizer = CLIPTokenizer.from_pretrained(
|
55 |
-
# self.args.tokenizer,
|
56 |
-
# # max_length=self.args.max_text_length,
|
57 |
-
# # do_lower_case=self.args.do_lower_case
|
58 |
-
# )
|
59 |
-
|
60 |
-
self.tokenizer = CLIPTokenizer.from_pretrained(
|
61 |
-
self.args.tokenizer,
|
62 |
-
# max_length=self.args.max_text_length,
|
63 |
-
# do_lower_case=self.args.do_lower_case
|
64 |
-
)
|
65 |
-
|
66 |
-
with open(coco_data_dir.joinpath('cocotalk.json')) as f:
|
67 |
-
self.vocab = list(json.load(f)['ix_to_word'].values())
|
68 |
-
popped = self.vocab.pop(-1)
|
69 |
-
assert popped == 'UNK'
|
70 |
-
if self.verbose:
|
71 |
-
print('vocab size: ', len(self.vocab))
|
72 |
-
|
73 |
-
|
74 |
-
data_info_path = coco_data_dir.joinpath('dataset_coco.json')
|
75 |
-
with open(data_info_path) as f:
|
76 |
-
karpathy_data = json.load(f)
|
77 |
-
|
78 |
-
split_rename = {
|
79 |
-
'train': 'train',
|
80 |
-
'restval': 'train',
|
81 |
-
'val': 'val',
|
82 |
-
'test': 'test'
|
83 |
-
}
|
84 |
-
|
85 |
-
n_images = 0
|
86 |
-
|
87 |
-
data = []
|
88 |
-
# self.vocab = set()
|
89 |
-
for datum in karpathy_data['images']:
|
90 |
-
re_split = split_rename[datum['split']]
|
91 |
-
|
92 |
-
# if re_split == 'train':
|
93 |
-
# for d in datum['sentences']:
|
94 |
-
# self.vocab = self.vocab.union(set(d['tokens']))
|
95 |
-
|
96 |
-
if re_split != self.source.split('_')[-1]:
|
97 |
-
continue
|
98 |
-
|
99 |
-
if re_split == 'train':
|
100 |
-
# for d in datum['sentences']:
|
101 |
-
# img_id = datum['filename'].split('.')[0]
|
102 |
-
# new_datum = {
|
103 |
-
# 'filename': datum['filename'],
|
104 |
-
# 'img_id': img_id,
|
105 |
-
# 'sent': d['raw'].strip(),
|
106 |
-
# 'targets': [d['raw'].strip() for d in datum['sentences']],
|
107 |
-
# 'is_train': True,
|
108 |
-
# 'cocoid': datum['cocoid']
|
109 |
-
# }
|
110 |
-
# data.append(new_datum)
|
111 |
-
img_id = datum['filename'].split('.')[0]
|
112 |
-
new_datum = {
|
113 |
-
'filename': datum['filename'],
|
114 |
-
'img_id': img_id,
|
115 |
-
# 'sent': d['raw'],
|
116 |
-
# 'targets': [d['raw'].strip() for d in datum['sentences']],
|
117 |
-
'targets': [" ".join(d['tokens']) for d in datum['sentences']],
|
118 |
-
'is_train': True,
|
119 |
-
'cocoid': datum['cocoid']
|
120 |
-
}
|
121 |
-
data.append(new_datum)
|
122 |
-
|
123 |
-
else:
|
124 |
-
img_id = datum['filename'].split('.')[0]
|
125 |
-
new_datum = {
|
126 |
-
'filename': datum['filename'],
|
127 |
-
'img_id': img_id,
|
128 |
-
# 'sent': d['raw'],
|
129 |
-
# 'targets': [d['raw'].strip() for d in datum['sentences']],
|
130 |
-
'targets': [" ".join(d['tokens']) for d in datum['sentences']],
|
131 |
-
'is_train': False,
|
132 |
-
'cocoid': datum['cocoid']
|
133 |
-
}
|
134 |
-
data.append(new_datum)
|
135 |
-
|
136 |
-
n_images += 1
|
137 |
-
|
138 |
-
if self.verbose:
|
139 |
-
print(f"{self.source} has {n_images} images")
|
140 |
-
# print(f"Loaded {len(data)} data from", split)
|
141 |
-
|
142 |
-
self.n_gpus = torch.cuda.device_count()
|
143 |
-
|
144 |
-
if self.topk > 0:
|
145 |
-
data = data[:self.topk]
|
146 |
-
if self.verbose:
|
147 |
-
print(f"Use only {self.topk} data")
|
148 |
-
|
149 |
-
self.data = data
|
150 |
-
|
151 |
-
# if self.verbose:
|
152 |
-
# print("# all sentences:", len(self.data))
|
153 |
-
|
154 |
-
if self.args.load_feat:
|
155 |
-
# feat_dir = coco_dir.joinpath(''
|
156 |
-
# self.feat_loader = HybridLoader('/scratch-space/CLIP-ViL/CLIP-ViL-Direct/caption/data/cocotalk_clipscore_vis', ext='.npy', in_memory=False)
|
157 |
-
self.feat_loader = HybridLoader(
|
158 |
-
coco_data_dir.joinpath('cocotalk_clipscore_vis'),
|
159 |
-
ext='.npy', in_memory=False)
|
160 |
-
else:
|
161 |
-
if 'openai/clip' in self.args.encoder_backbone:
|
162 |
-
# from transformers import CLIPProcessor
|
163 |
-
# self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32",
|
164 |
-
# size=args.image_size,
|
165 |
-
# do_resize=True,
|
166 |
-
# do_center_crop=False,
|
167 |
-
# )
|
168 |
-
# self.img_transform = lambda image: self.processor.feature_extractor(
|
169 |
-
# image,
|
170 |
-
# return_tensors='pt')['pixel_values'][0]
|
171 |
-
|
172 |
-
self.image_mean = [0.48145466, 0.4578275, 0.40821073]
|
173 |
-
self.image_std = [0.26862954, 0.26130258, 0.27577711]
|
174 |
-
|
175 |
-
# captioning
|
176 |
-
# self.img_transform = T.Compose([
|
177 |
-
# T.Resize((self.args.image_size, self.args.image_size))
|
178 |
-
# ])
|
179 |
-
|
180 |
-
# retrieval
|
181 |
-
self.img_transform = T.Compose([
|
182 |
-
T.Resize(self.args.image_size, interpolation=T.functional.InterpolationMode.BICUBIC),
|
183 |
-
T.CenterCrop(self.args.image_size)
|
184 |
-
])
|
185 |
-
|
186 |
-
self.img_tensor_transform = T.Compose([
|
187 |
-
# T.RandomCrop(224),
|
188 |
-
# T.RandomHorizontalFlip(p=0.3),
|
189 |
-
T.ConvertImageDtype(torch.float),
|
190 |
-
T.Normalize(self.image_mean, self.image_std)
|
191 |
-
]
|
192 |
-
)
|
193 |
-
# elif 'google/vit' in self.args.encoder_backbone:
|
194 |
-
# self.image_mean = [0.5, 0.5, 0.5]
|
195 |
-
# self.image_std = [0.5, 0.5, 0.5]
|
196 |
-
|
197 |
-
# self.img_transform = T.Compose([
|
198 |
-
# # T.PILToTensor(),
|
199 |
-
# T.Resize((self.args.image_size, self.args.image_size))
|
200 |
-
# ])
|
201 |
-
|
202 |
-
# self.img_tensor_transform = T.Compose([
|
203 |
-
# # T.RandomCrop(224),
|
204 |
-
# # T.RandomHorizontalFlip(p=0.3),
|
205 |
-
# T.ConvertImageDtype(torch.float),
|
206 |
-
# T.Normalize(self.image_mean, self.image_std)
|
207 |
-
# ]
|
208 |
-
# )
|
209 |
-
|
210 |
-
def get_negative_text(self, text):
|
211 |
-
neg_type = random.choice(['repeat', 'remove', 'insert', 'swap', 'shuffle'])
|
212 |
-
|
213 |
-
if neg_type == 'repeat':
|
214 |
-
text = text_utils.repeat(text)
|
215 |
-
elif neg_type == 'remove':
|
216 |
-
text = text_utils.remove(text)
|
217 |
-
elif neg_type == 'insert':
|
218 |
-
text = text_utils.insert(text, self.vocab)
|
219 |
-
elif neg_type == 'swap':
|
220 |
-
text = text_utils.swap(text, self.vocab)
|
221 |
-
elif neg_type == 'shuffle':
|
222 |
-
text = text_utils.shuffle(text)
|
223 |
-
|
224 |
-
return text, neg_type
|
225 |
-
|
226 |
-
def __len__(self):
|
227 |
-
return len(self.data)
|
228 |
-
|
229 |
-
def __getitem__(self, idx):
|
230 |
-
datum = self.data[idx]
|
231 |
-
return self.process_datum(datum)
|
232 |
-
|
233 |
-
def process_datum(self, datum):
|
234 |
-
out_dict = {}
|
235 |
-
|
236 |
-
###### Image ######
|
237 |
-
|
238 |
-
if self.args.load_feat:
|
239 |
-
cocoid = datum['cocoid']
|
240 |
-
out_dict['cocoid'] = str(cocoid)
|
241 |
-
img_feat = self.feat_loader.get(str(cocoid))
|
242 |
-
out_dict['img_feat'] = torch.from_numpy(img_feat)
|
243 |
-
|
244 |
-
else:
|
245 |
-
img_id = datum['img_id']
|
246 |
-
out_dict['img_id'] = img_id
|
247 |
-
|
248 |
-
if 'train' in datum['filename']:
|
249 |
-
img_split = 'train2014'
|
250 |
-
elif 'val' in datum['filename']:
|
251 |
-
img_split = 'val2014'
|
252 |
-
img_path = coco_img_dir.joinpath(img_split).joinpath(datum['filename']).with_suffix('.jpg')
|
253 |
-
assert img_path.exists()
|
254 |
-
img_path = str(img_path)
|
255 |
-
out_dict['img_path'] = img_path
|
256 |
-
|
257 |
-
img_tensor = torchvision.io.read_image(img_path)
|
258 |
-
# out_dict['img_tensor'] = img
|
259 |
-
|
260 |
-
# img = Image.open(img_path).convert('RGB')
|
261 |
-
# img_tensor = torch.as_tensor(np.asarray(img))
|
262 |
-
out_dict['img_tensor'] = self.img_transform(img_tensor)
|
263 |
-
# self.img_transform(img_tensor)
|
264 |
-
# out_dict['img_tensor'] = self.img_transform(img)
|
265 |
-
|
266 |
-
###### Text #####
|
267 |
-
# if datum['is_train']:
|
268 |
-
# sent = datum['sent'].strip()
|
269 |
-
|
270 |
-
sent = random.choice(datum['targets'])
|
271 |
-
|
272 |
-
# target_ids = self.tokenizer.encode(
|
273 |
-
# sent, max_length=self.args.gen_max_length, truncation=True)
|
274 |
-
|
275 |
-
# assert len(target_ids) <= self.args.gen_max_length, len(target_ids)
|
276 |
-
out_dict['sent'] = sent
|
277 |
-
# out_dict['target_ids'] = torch.LongTensor(target_ids)
|
278 |
-
# out_dict['target_length'] = len(target_ids)
|
279 |
-
|
280 |
-
|
281 |
-
# negative sample
|
282 |
-
neg_sent, neg_type = self.get_negative_text(sent)
|
283 |
-
|
284 |
-
# neg_target_ids = self.tokenizer.encode(
|
285 |
-
# neg_sent, max_length=self.args.gen_max_length, truncation=True)
|
286 |
-
|
287 |
-
# assert len(neg_target_ids) <= self.args.gen_max_length, len(neg_target_ids)
|
288 |
-
out_dict['neg_sent'] = neg_sent
|
289 |
-
out_dict['neg_type'] = neg_type
|
290 |
-
# out_dict['neg_target_ids'] = torch.LongTensor(neg_target_ids)
|
291 |
-
# out_dict['neg_target_length'] = len(neg_target_ids)
|
292 |
-
|
293 |
-
|
294 |
-
if 'targets' in datum:
|
295 |
-
out_dict['targets'] = datum['targets']
|
296 |
-
|
297 |
-
return out_dict
|
298 |
-
|
299 |
-
def collate_fn(self, batch):
|
300 |
-
batch_entry = {}
|
301 |
-
|
302 |
-
B = len(batch)
|
303 |
-
|
304 |
-
# if 'target_ids' in batch[0]:
|
305 |
-
# T_W_L = max(entry['target_length'] for entry in batch)
|
306 |
-
# target_ids = torch.ones(
|
307 |
-
# B, T_W_L, dtype=torch.long) * self.tokenizer.pad_token_id
|
308 |
-
|
309 |
-
# if 'target_ids' in batch[0]:
|
310 |
-
# T_W_L = max(entry['target_length'] for entry in batch)
|
311 |
-
# target_ids = torch.ones(
|
312 |
-
# B, T_W_L, dtype=torch.long) * self.tokenizer.pad_token_id
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
targets = []
|
317 |
-
img_ids = []
|
318 |
-
img_paths = []
|
319 |
-
|
320 |
-
coco_ids = []
|
321 |
-
|
322 |
-
if self.args.load_feat:
|
323 |
-
img_feats = torch.zeros(B, 512, dtype=torch.float)
|
324 |
-
else:
|
325 |
-
# imgs = []
|
326 |
-
img_tensor = torch.zeros(B, 3, self.args.image_size, self.args.image_size, dtype=torch.uint8)
|
327 |
-
|
328 |
-
for i, entry in enumerate(batch):
|
329 |
-
|
330 |
-
if self.args.load_feat:
|
331 |
-
coco_ids.append(entry['cocoid'])
|
332 |
-
img_feats[i] = entry['img_feat']
|
333 |
-
|
334 |
-
else:
|
335 |
-
|
336 |
-
img_ids.append(entry['img_id'])
|
337 |
-
img_paths.append(entry['img_path'])
|
338 |
-
img_tensor[i] = entry['img_tensor']
|
339 |
-
|
340 |
-
# if 'target_ids' in entry:
|
341 |
-
# target_ids[i, :entry['target_length']] = entry['target_ids']
|
342 |
-
|
343 |
-
if 'targets' in entry:
|
344 |
-
targets.append(entry['targets'])
|
345 |
-
|
346 |
-
if 'sent' in batch[0]:
|
347 |
-
# word_mask = target_ids != self.tokenizer.pad_token_id
|
348 |
-
# target_ids[~word_mask] = -100
|
349 |
-
# batch_entry['target_ids'] = target_ids
|
350 |
-
|
351 |
-
tokenized = self.tokenizer([entry['sent'] for entry in batch], truncation=True, padding=True, return_tensors='pt')
|
352 |
-
neg_tokenized = self.tokenizer([entry['neg_sent'] for entry in batch], truncation=True, padding=True, return_tensors='pt')
|
353 |
-
# sent, max_length=self.args.gen_max_length, truncation=True)
|
354 |
-
|
355 |
-
batch_entry['text'] = (tokenized.input_ids, tokenized.attention_mask)
|
356 |
-
batch_entry['neg_text'] = (neg_tokenized.input_ids, neg_tokenized.attention_mask)
|
357 |
-
|
358 |
-
|
359 |
-
if self.args.load_feat:
|
360 |
-
batch_entry['coco_ids'] = coco_ids
|
361 |
-
batch_entry['img_feats'] = img_feats
|
362 |
-
|
363 |
-
else:
|
364 |
-
|
365 |
-
img_tensor = self.img_tensor_transform(img_tensor)
|
366 |
-
|
367 |
-
batch_entry['img_id'] = img_ids
|
368 |
-
batch_entry['img_paths'] = img_paths
|
369 |
-
batch_entry['img_tensor'] = img_tensor
|
370 |
-
|
371 |
-
batch_entry['targets'] = targets
|
372 |
-
|
373 |
-
# print('batch created')
|
374 |
-
|
375 |
-
# batch_entry['task'] = 'caption'
|
376 |
-
|
377 |
-
return batch_entry
|
378 |
-
|
379 |
-
|
380 |
-
# def get_loader(args, split='karpathy_train', mode='train',
|
381 |
-
# batch_size=32, workers=4, distributed=False, gpu=0,
|
382 |
-
# topk=-1):
|
383 |
-
|
384 |
-
# verbose = (gpu == 0)
|
385 |
-
|
386 |
-
# dataset = COCORetrievalDataset(
|
387 |
-
# split,
|
388 |
-
# rank=gpu,
|
389 |
-
# topk=topk,
|
390 |
-
# verbose=verbose,
|
391 |
-
# args=args,
|
392 |
-
# mode=mode)
|
393 |
-
|
394 |
-
# # if distributed:
|
395 |
-
# # sampler = DistributedSampler(dataset)
|
396 |
-
# # else:
|
397 |
-
# # sampler = None
|
398 |
-
|
399 |
-
# if mode == 'train':
|
400 |
-
# loader = DataLoader(
|
401 |
-
# dataset, batch_size=batch_size, shuffle=(sampler is None),
|
402 |
-
# num_workers=workers, pin_memory=True, sampler=sampler,
|
403 |
-
# collate_fn=dataset.collate_fn)
|
404 |
-
# else:
|
405 |
-
# loader = DataLoader(
|
406 |
-
# dataset,
|
407 |
-
# batch_size=batch_size, shuffle=False,
|
408 |
-
# num_workers=workers, pin_memory=True,
|
409 |
-
# sampler=sampler,
|
410 |
-
# collate_fn=dataset.collate_fn,
|
411 |
-
# drop_last=False)
|
412 |
-
|
413 |
-
# # if verbose:
|
414 |
-
# # loader.evaluator = COCOCaptionEvaluator()
|
415 |
-
|
416 |
-
# # loader.task = 'caption'
|
417 |
-
|
418 |
-
# return loader
|
419 |
-
|
420 |
-
|
421 |
-
# class COCOCaptionEvaluator:
|
422 |
-
# def __init__(self):
|
423 |
-
# import language_evaluation
|
424 |
-
# self.evaluator = language_evaluation.CocoEvaluator(verbose=False)
|
425 |
-
|
426 |
-
# def evaluate(self, predicts, answers):
|
427 |
-
|
428 |
-
# results = self.evaluator.run_evaluation(predicts, answers)
|
429 |
-
|
430 |
-
# return results
|
431 |
-
|
432 |
-
import six
|
433 |
-
import os
|
434 |
-
import h5py
|
435 |
-
|
436 |
-
class HybridLoader:
|
437 |
-
"""
|
438 |
-
If db_path is a director, then use normal file loading
|
439 |
-
If lmdb, then load from lmdb
|
440 |
-
The loading method depend on extention.
|
441 |
-
|
442 |
-
in_memory: if in_memory is True, we save all the features in memory
|
443 |
-
For individual np(y|z)s, we don't need to do that because the system will do this for us.
|
444 |
-
Should be useful for lmdb or h5.
|
445 |
-
(Copied this idea from vilbert)
|
446 |
-
"""
|
447 |
-
|
448 |
-
def __init__(self, db_path, ext='.npy', in_memory=False):
|
449 |
-
self.db_path = db_path
|
450 |
-
self.ext = ext
|
451 |
-
if self.ext == '.npy':
|
452 |
-
self.loader = lambda x: np.load(six.BytesIO(x))
|
453 |
-
else:
|
454 |
-
self.loader = lambda x: np.load(six.BytesIO(x))['feat']
|
455 |
-
# if db_path.endswith('.lmdb'):
|
456 |
-
# self.db_type = 'lmdb'
|
457 |
-
# self.lmdb = lmdbdict(db_path, unsafe=True)
|
458 |
-
# self.lmdb._key_dumps = DUMPS_FUNC['ascii']
|
459 |
-
# self.lmdb._value_loads = LOADS_FUNC['identity']
|
460 |
-
# elif db_path.endswith('.pth'): # Assume a key,value dictionary
|
461 |
-
# self.db_type = 'pth'
|
462 |
-
# self.feat_file = torch.load(db_path)
|
463 |
-
# self.loader = lambda x: x
|
464 |
-
# print('HybridLoader: ext is ignored')
|
465 |
-
# elif db_path.endswith('h5'):
|
466 |
-
# self.db_type = 'h5'
|
467 |
-
# self.loader = lambda x: np.array(x).astype('float32')
|
468 |
-
# else:
|
469 |
-
# self.db_type = 'dir'
|
470 |
-
|
471 |
-
self.in_memory = in_memory
|
472 |
-
if self.in_memory:
|
473 |
-
self.features = {}
|
474 |
-
|
475 |
-
def get(self, key):
|
476 |
-
|
477 |
-
# if self.in_memory and key in self.features:
|
478 |
-
# # We save f_input because we want to save the
|
479 |
-
# # compressed bytes to save memory
|
480 |
-
# f_input = self.features[key]
|
481 |
-
# elif self.db_type == 'lmdb':
|
482 |
-
# f_input = self.lmdb[key]
|
483 |
-
# elif self.db_type == 'pth':
|
484 |
-
# f_input = self.feat_file[key]
|
485 |
-
# elif self.db_type == 'h5':
|
486 |
-
# f_input = h5py.File(self.db_path, 'r')[key]
|
487 |
-
# else:
|
488 |
-
# f_input = open(os.path.join(
|
489 |
-
# self.db_path, key + self.ext), 'rb').read()
|
490 |
-
|
491 |
-
f_input = open(os.path.join(
|
492 |
-
self.db_path, key + self.ext), 'rb').read()
|
493 |
-
|
494 |
-
if self.in_memory and key not in self.features:
|
495 |
-
self.features[key] = f_input
|
496 |
-
|
497 |
-
# load image
|
498 |
-
feat = self.loader(f_input)
|
499 |
-
|
500 |
-
return feat
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
retrieval/clip_model.py
DELETED
@@ -1,350 +0,0 @@
|
|
1 |
-
from transformers import CLIPModel, CLIPTokenizer
|
2 |
-
import os
|
3 |
-
import json
|
4 |
-
import argparse
|
5 |
-
from random import shuffle, seed
|
6 |
-
import string
|
7 |
-
# non-standard dependencies:
|
8 |
-
import h5py
|
9 |
-
from six.moves import cPickle
|
10 |
-
import numpy as np
|
11 |
-
import torch
|
12 |
-
import torchvision.models as models
|
13 |
-
import skimage.io
|
14 |
-
|
15 |
-
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
16 |
-
from PIL import Image
|
17 |
-
from torch import nn
|
18 |
-
|
19 |
-
|
20 |
-
class CLIPScore(nn.Module):
|
21 |
-
def __init__(self, clipscore_w=2.5, image_size=224, mode='clip_s', use_grammar=False, joint_out=False):
|
22 |
-
super(CLIPScore, self).__init__()
|
23 |
-
# from transformers import CLIPModel, CLIPTokenizer
|
24 |
-
self.clip_model = CLIPModel.from_pretrained(
|
25 |
-
'openai/clip-vit-base-patch32')
|
26 |
-
self.tokenizer = CLIPTokenizer.from_pretrained(
|
27 |
-
'openai/clip-vit-base-patch32')
|
28 |
-
|
29 |
-
self.clip_model.eval()
|
30 |
-
|
31 |
-
self.clipscore_w = clipscore_w
|
32 |
-
|
33 |
-
self.image_transform = self._transform(image_size)
|
34 |
-
|
35 |
-
self.mode = mode
|
36 |
-
assert mode in ['clip_s', 'refclip_s']
|
37 |
-
|
38 |
-
self.use_grammar = use_grammar
|
39 |
-
self.joint_out = joint_out
|
40 |
-
|
41 |
-
if self.use_grammar and self.joint_out is False:
|
42 |
-
self.grammar_score_head = nn.Sequential(
|
43 |
-
nn.Linear(self.clip_model.text_embed_dim, self.clip_model.projection_dim, bias=False),
|
44 |
-
nn.ReLU(),
|
45 |
-
nn.Linear(self.clip_model.projection_dim, 2, bias=False)
|
46 |
-
)
|
47 |
-
|
48 |
-
def _transform(self, n_px):
|
49 |
-
return Compose([
|
50 |
-
Resize(n_px, interpolation=Image.BICUBIC),
|
51 |
-
CenterCrop(n_px),
|
52 |
-
lambda image: image.convert("RGB"),
|
53 |
-
ToTensor(),
|
54 |
-
Normalize((0.48145466, 0.4578275, 0.40821073),
|
55 |
-
(0.26862954, 0.26130258, 0.27577711)),
|
56 |
-
])
|
57 |
-
|
58 |
-
def load_image(self, image_path):
|
59 |
-
image = Image.open(image_path)
|
60 |
-
return image
|
61 |
-
|
62 |
-
# @torch.no_grad()
|
63 |
-
def image_extract(self, image):
|
64 |
-
if isinstance(image, str):
|
65 |
-
image = self.load_image(image)
|
66 |
-
if not isinstance(image, torch.Tensor):
|
67 |
-
image = self.image_transform(image)
|
68 |
-
|
69 |
-
img_tensor = image.view(-1, 3, 224, 224)
|
70 |
-
device = next(self.clip_model.parameters()).device
|
71 |
-
img_tensor = img_tensor.to(device)
|
72 |
-
|
73 |
-
clip_model = self.clip_model
|
74 |
-
|
75 |
-
img_feat = clip_model.vision_model(img_tensor).pooler_output
|
76 |
-
img_feat = clip_model.visual_projection(img_feat)
|
77 |
-
img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
|
78 |
-
|
79 |
-
return img_feat
|
80 |
-
|
81 |
-
# @torch.no_grad()
|
82 |
-
def text_extract(self, text, prompt="A photo depicts", proj_norm=True):
|
83 |
-
if isinstance(text, str):
|
84 |
-
text_batch = [" ".join([prompt, text])]
|
85 |
-
elif isinstance(text, list):
|
86 |
-
text_batch = [" ".join([prompt, txt]) for txt in text]
|
87 |
-
|
88 |
-
if isinstance(text, tuple) and isinstance(text[0], torch.Tensor):
|
89 |
-
input_ids, attention_mask = text
|
90 |
-
else:
|
91 |
-
input_text = text_batch
|
92 |
-
|
93 |
-
tokenized = self.tokenizer(
|
94 |
-
input_text, return_tensors='pt', padding=True)
|
95 |
-
|
96 |
-
input_ids = tokenized.input_ids
|
97 |
-
attention_mask = tokenized.attention_mask
|
98 |
-
|
99 |
-
clip_model = self.clip_model
|
100 |
-
device = next(self.clip_model.parameters()).device
|
101 |
-
input_ids = input_ids.to(device)
|
102 |
-
attention_mask = attention_mask.to(device)
|
103 |
-
|
104 |
-
text_feat = clip_model.text_model(input_ids, attention_mask).pooler_output
|
105 |
-
|
106 |
-
if proj_norm:
|
107 |
-
text_feat = clip_model.text_projection(text_feat)
|
108 |
-
text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True)
|
109 |
-
|
110 |
-
return text_feat
|
111 |
-
|
112 |
-
# @torch.no_grad()
|
113 |
-
def calc_clip_s(self, img_feat, text_feat):
|
114 |
-
return self.clipscore_w * torch.relu((img_feat * text_feat).sum(dim=-1))
|
115 |
-
|
116 |
-
# @torch.no_grad()
|
117 |
-
def calc_refclip_s(self, img_feat=None, text_feat=None, ref_text_feat=None, ref_text_mask=None, clip_s=None):
|
118 |
-
|
119 |
-
if clip_s is None:
|
120 |
-
clip_s = self.calc_clip_s(img_feat, text_feat)
|
121 |
-
|
122 |
-
B, dim = img_feat.size()
|
123 |
-
|
124 |
-
ref_text_feat = ref_text_feat.view(B, -1, dim)
|
125 |
-
|
126 |
-
K = ref_text_feat.size(1)
|
127 |
-
|
128 |
-
text_feat = text_feat.view(B, 1, dim).expand(-1, K, -1)
|
129 |
-
assert ref_text_feat.size() == text_feat.size(
|
130 |
-
), (ref_text_feat.size(), text_feat.size())
|
131 |
-
|
132 |
-
ref_score = self.calc_clip_s(text_feat, ref_text_feat)
|
133 |
-
if ref_text_mask is not None:
|
134 |
-
if not isinstance(ref_text_mask, torch.Tensor):
|
135 |
-
ref_text_mask = torch.tensor(
|
136 |
-
ref_text_mask, dtype=ref_score.dtype, device=ref_score.device)
|
137 |
-
ref_score = ref_score.view(B, K) * ref_text_mask.view(B, K)
|
138 |
-
|
139 |
-
ref_score = ref_score.view(B, K).max(dim=1).values
|
140 |
-
|
141 |
-
assert clip_s.size() == (B,)
|
142 |
-
assert clip_s.size() == ref_score.size()
|
143 |
-
|
144 |
-
# harmonic mean
|
145 |
-
refclip_s = 2 / (1 / clip_s + 1 / ref_score)
|
146 |
-
return refclip_s
|
147 |
-
|
148 |
-
# # @torch.no_grad()
|
149 |
-
# def forward(self,
|
150 |
-
# images=None, text=None,
|
151 |
-
# img_feat=None, text_feat=None,
|
152 |
-
# ref_text=None, ref_text_feat=None, ref_text_mask=None,
|
153 |
-
# prompt="A photo depicts",
|
154 |
-
# mode=None):
|
155 |
-
# if img_feat is None:
|
156 |
-
# img_feat = self.image_extract(images)
|
157 |
-
# img_feat = img_feat.view(-1, 512)
|
158 |
-
|
159 |
-
# if text_feat is None:
|
160 |
-
# text_feat = self.text_extract(text, prompt=prompt)
|
161 |
-
# text_feat = text_feat.view(-1, 512)
|
162 |
-
|
163 |
-
# if mode is None:
|
164 |
-
# mode = self.mode
|
165 |
-
# assert mode in ['clip_s', 'refclip_s']
|
166 |
-
|
167 |
-
# if mode == 'clip_s':
|
168 |
-
# clip_s = self.calc_clip_s(img_feat, text_feat)
|
169 |
-
# return clip_s
|
170 |
-
# elif mode == 'refclip_s':
|
171 |
-
# if ref_text_feat is None:
|
172 |
-
# ref_text_feat = self.text_extract(ref_text, prompt=prompt)
|
173 |
-
# ref_text_feat = ref_text_feat.view(-1, 512)
|
174 |
-
|
175 |
-
# refclip_s = self.calc_refclip_s(
|
176 |
-
# img_feat, text_feat, ref_text_feat, ref_text_mask=ref_text_mask)
|
177 |
-
# return refclip_s
|
178 |
-
|
179 |
-
|
180 |
-
def train_step(self,
|
181 |
-
images=None, text=None,
|
182 |
-
img_feat=None, text_feat=None,
|
183 |
-
neg_text=None, neg_text_feat=None,
|
184 |
-
# ref_text=None, ref_text_feat=None, ref_text_mask=None,
|
185 |
-
prompt="A photo depicts",
|
186 |
-
# return_loss=True,
|
187 |
-
**kwargs):
|
188 |
-
|
189 |
-
if img_feat is None:
|
190 |
-
img_feat = self.image_extract(images)
|
191 |
-
img_feat = img_feat.view(-1, 512)
|
192 |
-
|
193 |
-
B = img_feat.size(0)
|
194 |
-
|
195 |
-
if self.joint_out:
|
196 |
-
pos_text_feat = self.text_extract(text, prompt=prompt, proj_norm=False).view(B, 512)
|
197 |
-
neg_text_feat = self.text_extract(neg_text, prompt=prompt, proj_norm=False).view(-1, 512)
|
198 |
-
neg_B = neg_text_feat.size(0)
|
199 |
-
|
200 |
-
# [B+neg_B, 512]
|
201 |
-
text_feat = torch.cat([pos_text_feat, neg_text_feat], dim=0)
|
202 |
-
|
203 |
-
text_cont_feat = self.clip_model.text_projection(text_feat)
|
204 |
-
text_cont_feat = text_cont_feat / text_cont_feat.norm(dim=-1, keepdim=True)
|
205 |
-
|
206 |
-
text_cont_feat = text_cont_feat.view(B+neg_B, 512)
|
207 |
-
|
208 |
-
logit_scale = self.clip_model.logit_scale.exp()
|
209 |
-
|
210 |
-
# [B+neg_B * B]
|
211 |
-
logits_per_text = torch.matmul(text_cont_feat, img_feat.t()) * logit_scale
|
212 |
-
|
213 |
-
# image-to-text label: positive text
|
214 |
-
caption_loss = -torch.diag(nn.functional.log_softmax(logits_per_text, dim=0)[:B]).mean()
|
215 |
-
|
216 |
-
# calculate text-to-image only on positive text
|
217 |
-
image_loss = -torch.diag(nn.functional.log_softmax(logits_per_text[:B], dim=1)).mean()
|
218 |
-
|
219 |
-
clip_loss = (caption_loss + image_loss) / 2.0
|
220 |
-
|
221 |
-
out = {
|
222 |
-
'clip_loss': clip_loss,
|
223 |
-
'img_feat': img_feat,
|
224 |
-
'text_feat': text_cont_feat[:B].detach(),
|
225 |
-
# 'neg_text_feat': neg_text_feat,
|
226 |
-
}
|
227 |
-
|
228 |
-
return out
|
229 |
-
|
230 |
-
|
231 |
-
else:
|
232 |
-
if text_feat is None:
|
233 |
-
text_feat = self.text_extract(text, prompt=prompt, proj_norm=False)
|
234 |
-
|
235 |
-
text_cont_feat = self.clip_model.text_projection(text_feat)
|
236 |
-
text_cont_feat = text_cont_feat / \
|
237 |
-
text_cont_feat.norm(dim=-1, keepdim=True)
|
238 |
-
|
239 |
-
text_cont_feat = text_cont_feat.view(B, 512)
|
240 |
-
|
241 |
-
|
242 |
-
# cosine similarity as logits
|
243 |
-
logit_scale = self.clip_model.logit_scale.exp()
|
244 |
-
logits_per_text = torch.matmul(text_cont_feat, img_feat.t()) * logit_scale
|
245 |
-
# logits_per_image = logits_per_text.T
|
246 |
-
|
247 |
-
clip_loss = clip_loss_fn(logits_per_text)
|
248 |
-
|
249 |
-
|
250 |
-
# negative sampling
|
251 |
-
pos_text_feat = text_feat.view(B, 512)
|
252 |
-
neg_text_feat = self.text_extract(neg_text, prompt=prompt, proj_norm=False).view(B, 512)
|
253 |
-
|
254 |
-
grammar_text_feat = torch.cat([pos_text_feat, neg_text_feat], dim=0)
|
255 |
-
|
256 |
-
# 2B, 1
|
257 |
-
grammar_text_logit = self.grammar_score_head(grammar_text_feat)
|
258 |
-
grammar_labels = torch.LongTensor([1] * B + [0] * B).to(grammar_text_logit.device).view(2 * B)
|
259 |
-
|
260 |
-
grammar_loss = torch.nn.functional.cross_entropy(grammar_text_logit, grammar_labels)
|
261 |
-
|
262 |
-
grammar_pred = grammar_text_logit.argmax(dim=1, keepdim=False)
|
263 |
-
grammar_pos_pred = grammar_pred[:B]
|
264 |
-
grammar_neg_pred = grammar_pred[B:]
|
265 |
-
# grammar_acc = (grammar_pred == grammar_labels).float().mean()
|
266 |
-
|
267 |
-
out = {
|
268 |
-
'clip_loss': clip_loss,
|
269 |
-
'grammar_loss': grammar_loss,
|
270 |
-
'img_feat': img_feat,
|
271 |
-
'text_feat': text_cont_feat,
|
272 |
-
'neg_text_feat': neg_text_feat,
|
273 |
-
'grammar_pos_pred': grammar_pos_pred,
|
274 |
-
'grammar_neg_pred': grammar_neg_pred,
|
275 |
-
}
|
276 |
-
|
277 |
-
return out
|
278 |
-
|
279 |
-
def train_step_old(self,
|
280 |
-
images=None, text=None,
|
281 |
-
img_feat=None, text_feat=None,
|
282 |
-
neg_text=None, neg_text_feat=None,
|
283 |
-
# ref_text=None, ref_text_feat=None, ref_text_mask=None,
|
284 |
-
prompt="A photo depicts",
|
285 |
-
# return_loss=True,
|
286 |
-
**kwargs):
|
287 |
-
|
288 |
-
if img_feat is None:
|
289 |
-
img_feat = self.image_extract(images)
|
290 |
-
img_feat = img_feat.view(-1, 512)
|
291 |
-
|
292 |
-
B = img_feat.size(0)
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
if text_feat is None:
|
297 |
-
text_feat = self.text_extract(text, prompt=prompt, proj_norm=False)
|
298 |
-
|
299 |
-
text_cont_feat = self.clip_model.text_projection(text_feat)
|
300 |
-
text_cont_feat = text_cont_feat / text_cont_feat.norm(dim=-1, keepdim=True)
|
301 |
-
text_cont_feat = text_cont_feat.view(B, 512)
|
302 |
-
|
303 |
-
# cosine similarity as logits
|
304 |
-
logit_scale = self.clip_model.logit_scale.exp()
|
305 |
-
logits_per_text = torch.matmul(text_cont_feat, img_feat.t()) * logit_scale
|
306 |
-
# logits_per_image = logits_per_text.T
|
307 |
-
|
308 |
-
clip_loss = clip_loss_fn(logits_per_text)
|
309 |
-
|
310 |
-
|
311 |
-
# negative sampling
|
312 |
-
pos_text_feat = text_feat.view(B, 512)
|
313 |
-
neg_text_feat = self.text_extract(neg_text, prompt=prompt, proj_norm=False).view(B, 512)
|
314 |
-
|
315 |
-
grammar_text_feat = torch.cat([pos_text_feat, neg_text_feat], dim=0)
|
316 |
-
|
317 |
-
# 2B, 1
|
318 |
-
grammar_text_logit = self.grammar_score_head(grammar_text_feat)
|
319 |
-
grammar_labels = torch.LongTensor([1] * B + [0] * B).to(grammar_text_logit.device).view(2 * B)
|
320 |
-
|
321 |
-
grammar_loss = torch.nn.functional.cross_entropy(grammar_text_logit, grammar_labels)
|
322 |
-
|
323 |
-
grammar_pred = grammar_text_logit.argmax(dim=1, keepdim=False)
|
324 |
-
grammar_pos_pred = grammar_pred[:B]
|
325 |
-
grammar_neg_pred = grammar_pred[B:]
|
326 |
-
# grammar_acc = (grammar_pred == grammar_labels).float().mean()
|
327 |
-
|
328 |
-
out = {
|
329 |
-
'clip_loss': clip_loss,
|
330 |
-
'grammar_loss': grammar_loss,
|
331 |
-
'img_feat': img_feat,
|
332 |
-
'text_feat': text_cont_feat,
|
333 |
-
'neg_text_feat': neg_text_feat,
|
334 |
-
'grammar_pos_pred': grammar_pos_pred,
|
335 |
-
'grammar_neg_pred': grammar_neg_pred,
|
336 |
-
}
|
337 |
-
|
338 |
-
return out
|
339 |
-
|
340 |
-
# contrastive loss function, adapted from
|
341 |
-
# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
|
342 |
-
def contrastive_loss(logits: torch.Tensor, dim: int) -> torch.Tensor:
|
343 |
-
neg_ce = torch.diag(nn.functional.log_softmax(logits, dim=dim))
|
344 |
-
return -neg_ce.mean()
|
345 |
-
|
346 |
-
|
347 |
-
def clip_loss_fn(similarity: torch.Tensor) -> torch.Tensor:
|
348 |
-
caption_loss = contrastive_loss(similarity, dim=0)
|
349 |
-
image_loss = contrastive_loss(similarity, dim=1)
|
350 |
-
return (caption_loss + image_loss) / 2.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
retrieval/configs/clip_negative_text.yaml
DELETED
@@ -1,14 +0,0 @@
|
|
1 |
-
checkpoint_dir: ./save/clip_negative_text/
|
2 |
-
|
3 |
-
losses_log_every: 25
|
4 |
-
precision: 32
|
5 |
-
load_feat: true
|
6 |
-
data_in_memory: false
|
7 |
-
|
8 |
-
batch_size: 1600
|
9 |
-
valid_batch_size: 200
|
10 |
-
clip_grad_norm: 0
|
11 |
-
|
12 |
-
epochs: 30
|
13 |
-
use_grammar: true
|
14 |
-
joint_out: false
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
retrieval/param.py
DELETED
@@ -1,209 +0,0 @@
|
|
1 |
-
import argparse
|
2 |
-
import random
|
3 |
-
|
4 |
-
import numpy as np
|
5 |
-
import torch
|
6 |
-
|
7 |
-
import pprint
|
8 |
-
import yaml
|
9 |
-
|
10 |
-
|
11 |
-
def str2bool(v):
|
12 |
-
if v.lower() in ('yes', 'true', 't', 'y', '1'):
|
13 |
-
return True
|
14 |
-
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
|
15 |
-
return False
|
16 |
-
else:
|
17 |
-
raise argparse.ArgumentTypeError('Boolean value expected.')
|
18 |
-
|
19 |
-
|
20 |
-
def is_interactive():
|
21 |
-
import __main__ as main
|
22 |
-
return not hasattr(main, '__file__')
|
23 |
-
|
24 |
-
|
25 |
-
def get_optimizer(optim, verbose=False):
|
26 |
-
# Bind the optimizer
|
27 |
-
if optim == 'rms':
|
28 |
-
if verbose:
|
29 |
-
print("Optimizer: Using RMSProp")
|
30 |
-
optimizer = torch.optim.RMSprop
|
31 |
-
elif optim == 'adam':
|
32 |
-
if verbose:
|
33 |
-
print("Optimizer: Using Adam")
|
34 |
-
optimizer = torch.optim.Adam
|
35 |
-
elif optim == 'adamw':
|
36 |
-
if verbose:
|
37 |
-
print("Optimizer: Using AdamW")
|
38 |
-
# optimizer = torch.optim.AdamW
|
39 |
-
optimizer = 'adamw'
|
40 |
-
elif optim == 'adamax':
|
41 |
-
if verbose:
|
42 |
-
print("Optimizer: Using Adamax")
|
43 |
-
optimizer = torch.optim.Adamax
|
44 |
-
elif optim == 'sgd':
|
45 |
-
if verbose:
|
46 |
-
print("Optimizer: SGD")
|
47 |
-
optimizer = torch.optim.SGD
|
48 |
-
else:
|
49 |
-
assert False, "Please add your optimizer %s in the list." % optim
|
50 |
-
|
51 |
-
return optimizer
|
52 |
-
|
53 |
-
|
54 |
-
def parse_args(parse=True, **optional_kwargs):
|
55 |
-
parser = argparse.ArgumentParser()
|
56 |
-
|
57 |
-
parser.add_argument('--seed', type=int, default=9595, help='random seed')
|
58 |
-
|
59 |
-
# Data Splits
|
60 |
-
parser.add_argument("--train", default='karpathy_train')
|
61 |
-
parser.add_argument("--valid", default='karpathy_val')
|
62 |
-
parser.add_argument("--test", default='karpathy_test')
|
63 |
-
# parser.add_argument('--test_only', action='store_true')
|
64 |
-
|
65 |
-
# Quick experiments
|
66 |
-
parser.add_argument('--train_topk', type=int, default=-1)
|
67 |
-
parser.add_argument('--valid_topk', type=int, default=-1)
|
68 |
-
|
69 |
-
# Checkpoint
|
70 |
-
parser.add_argument('--output', type=str, default='snap/test')
|
71 |
-
parser.add_argument('--load', type=str, default=None, help='Load the model (usually the fine-tuned model).')
|
72 |
-
parser.add_argument('--from_scratch', action='store_true')
|
73 |
-
|
74 |
-
# CPU/GPU
|
75 |
-
parser.add_argument("--multiGPU", action='store_const', default=False, const=True)
|
76 |
-
parser.add_argument('--fp16', action='store_true')
|
77 |
-
parser.add_argument("--distributed", action='store_true')
|
78 |
-
parser.add_argument("--num_workers", default=0, type=int)
|
79 |
-
parser.add_argument('--local_rank', type=int, default=-1)
|
80 |
-
# parser.add_argument('--rank', type=int, default=-1)
|
81 |
-
|
82 |
-
# Model Config
|
83 |
-
# parser.add_argument('--encoder_backbone', type=str, default='openai/clip-vit-base-patch32')
|
84 |
-
# parser.add_argument('--decoder_backbone', type=str, default='bert-base-uncased')
|
85 |
-
parser.add_argument('--tokenizer', type=str, default='openai/clip-vit-base-patch32')
|
86 |
-
|
87 |
-
# parser.add_argument('--position_embedding_type', type=str, default='absolute')
|
88 |
-
|
89 |
-
# parser.add_argument('--encoder_transform', action='store_true')
|
90 |
-
|
91 |
-
parser.add_argument('--max_text_length', type=int, default=40)
|
92 |
-
|
93 |
-
# parser.add_argument('--image_size', type=int, default=224)
|
94 |
-
# parser.add_argument('--patch_size', type=int, default=32)
|
95 |
-
|
96 |
-
# parser.add_argument('--decoder_num_layers', type=int, default=12)
|
97 |
-
|
98 |
-
# Training
|
99 |
-
parser.add_argument('--batch_size', type=int, default=256)
|
100 |
-
parser.add_argument('--valid_batch_size', type=int, default=None)
|
101 |
-
|
102 |
-
parser.add_argument('--optim', default='adamw')
|
103 |
-
|
104 |
-
parser.add_argument('--warmup_ratio', type=float, default=0.05)
|
105 |
-
parser.add_argument('--weight_decay', type=float, default=0.01)
|
106 |
-
parser.add_argument('--clip_grad_norm', type=float, default=-1.0)
|
107 |
-
parser.add_argument('--gradient_accumulation_steps', type=int, default=1)
|
108 |
-
parser.add_argument('--lr', type=float, default=1e-4)
|
109 |
-
parser.add_argument('--adam_eps', type=float, default=1e-6)
|
110 |
-
parser.add_argument('--adam_beta1', type=float, default=0.9)
|
111 |
-
parser.add_argument('--adam_beta2', type=float, default=0.999)
|
112 |
-
|
113 |
-
parser.add_argument('--epochs', type=int, default=20)
|
114 |
-
# parser.add_argument('--dropout', type=float, default=0.1)
|
115 |
-
|
116 |
-
|
117 |
-
# Inference
|
118 |
-
# parser.add_argument('--num_beams', type=int, default=1)
|
119 |
-
# parser.add_argument('--gen_max_length', type=int, default=20)
|
120 |
-
|
121 |
-
parser.add_argument('--start_from', type=str, default=None)
|
122 |
-
|
123 |
-
# Data
|
124 |
-
# parser.add_argument('--do_lower_case', type=str2bool, default=None)
|
125 |
-
|
126 |
-
# parser.add_argument('--prefix', type=str, default=None)
|
127 |
-
|
128 |
-
|
129 |
-
# COCO Caption
|
130 |
-
# parser.add_argument('--no_prefix', action='store_true')
|
131 |
-
|
132 |
-
parser.add_argument('--no_cls', action='store_true')
|
133 |
-
|
134 |
-
parser.add_argument('--cfg', type=str, default=None)
|
135 |
-
parser.add_argument('--id', type=str, default=None)
|
136 |
-
|
137 |
-
# Etc.
|
138 |
-
parser.add_argument('--comment', type=str, default='')
|
139 |
-
parser.add_argument("--dry", action='store_true')
|
140 |
-
|
141 |
-
# Parse the arguments.
|
142 |
-
if parse:
|
143 |
-
args = parser.parse_args()
|
144 |
-
# For interative engironmnet (ex. jupyter)
|
145 |
-
else:
|
146 |
-
args = parser.parse_known_args()[0]
|
147 |
-
|
148 |
-
loaded_kwargs = {}
|
149 |
-
if args.cfg is not None:
|
150 |
-
cfg_path = f'configs/{args.cfg}.yaml'
|
151 |
-
with open(cfg_path, 'r') as f:
|
152 |
-
loaded_kwargs = yaml.safe_load(f)
|
153 |
-
|
154 |
-
# Namespace => Dictionary
|
155 |
-
parsed_kwargs = vars(args)
|
156 |
-
parsed_kwargs.update(optional_kwargs)
|
157 |
-
|
158 |
-
kwargs = {}
|
159 |
-
kwargs.update(parsed_kwargs)
|
160 |
-
kwargs.update(loaded_kwargs)
|
161 |
-
|
162 |
-
args = Config(**kwargs)
|
163 |
-
|
164 |
-
# Bind optimizer class.
|
165 |
-
verbose = False
|
166 |
-
args.optimizer = get_optimizer(args.optim, verbose=verbose)
|
167 |
-
|
168 |
-
# Set seeds
|
169 |
-
torch.manual_seed(args.seed)
|
170 |
-
random.seed(args.seed)
|
171 |
-
np.random.seed(args.seed)
|
172 |
-
|
173 |
-
return args
|
174 |
-
|
175 |
-
|
176 |
-
class Config(object):
|
177 |
-
def __init__(self, **kwargs):
|
178 |
-
"""Configuration Class: set kwargs as class attributes with setattr"""
|
179 |
-
for k, v in kwargs.items():
|
180 |
-
setattr(self, k, v)
|
181 |
-
|
182 |
-
@property
|
183 |
-
def config_str(self):
|
184 |
-
return pprint.pformat(self.__dict__)
|
185 |
-
|
186 |
-
def __repr__(self):
|
187 |
-
"""Pretty-print configurations in alphabetical order"""
|
188 |
-
config_str = 'Configurations\n'
|
189 |
-
config_str += self.config_str
|
190 |
-
return config_str
|
191 |
-
|
192 |
-
# def update(self, **kwargs):
|
193 |
-
# for k, v in kwargs.items():
|
194 |
-
# setattr(self, k, v)
|
195 |
-
|
196 |
-
# def save(self, path):
|
197 |
-
# with open(path, 'w') as f:
|
198 |
-
# yaml.dump(self.__dict__, f, default_flow_style=False)
|
199 |
-
|
200 |
-
# @classmethod
|
201 |
-
# def load(cls, path):
|
202 |
-
# with open(path, 'r') as f:
|
203 |
-
# kwargs = yaml.load(f)
|
204 |
-
|
205 |
-
# return Config(**kwargs)
|
206 |
-
|
207 |
-
|
208 |
-
if __name__ == '__main__':
|
209 |
-
args = parse_args(True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
retrieval/pth_loader.py
DELETED
@@ -1,334 +0,0 @@
|
|
1 |
-
from __future__ import absolute_import
|
2 |
-
from __future__ import division
|
3 |
-
from __future__ import print_function
|
4 |
-
|
5 |
-
import json
|
6 |
-
import h5py
|
7 |
-
from lmdbdict import lmdbdict
|
8 |
-
from lmdbdict.methods import DUMPS_FUNC, LOADS_FUNC
|
9 |
-
import os
|
10 |
-
import numpy as np
|
11 |
-
import numpy.random as npr
|
12 |
-
import random
|
13 |
-
|
14 |
-
import torch
|
15 |
-
import torch.utils.data as data
|
16 |
-
|
17 |
-
import multiprocessing
|
18 |
-
import six
|
19 |
-
|
20 |
-
verbose = True
|
21 |
-
# import torch
|
22 |
-
# if torch.cuda.current_device() in [0, -1]:
|
23 |
-
if 'LOCAL_RANK' in os.environ and os.environ['LOCAL_RANK'] != '0':
|
24 |
-
verbose = False
|
25 |
-
|
26 |
-
class HybridLoader:
|
27 |
-
"""
|
28 |
-
If db_path is a director, then use normal file loading
|
29 |
-
If lmdb, then load from lmdb
|
30 |
-
The loading method depend on extention.
|
31 |
-
|
32 |
-
in_memory: if in_memory is True, we save all the features in memory
|
33 |
-
For individual np(y|z)s, we don't need to do that because the system will do this for us.
|
34 |
-
Should be useful for lmdb or h5.
|
35 |
-
(Copied this idea from vilbert)
|
36 |
-
"""
|
37 |
-
def __init__(self, db_path, ext, in_memory=False):
|
38 |
-
self.db_path = db_path
|
39 |
-
self.ext = ext
|
40 |
-
if self.ext == '.npy':
|
41 |
-
self.loader = lambda x: np.load(six.BytesIO(x))
|
42 |
-
else:
|
43 |
-
self.loader = lambda x: np.load(six.BytesIO(x))['feat']
|
44 |
-
if db_path.endswith('.lmdb'):
|
45 |
-
self.db_type = 'lmdb'
|
46 |
-
self.lmdb = lmdbdict(db_path, unsafe=True)
|
47 |
-
self.lmdb._key_dumps = DUMPS_FUNC['ascii']
|
48 |
-
self.lmdb._value_loads = LOADS_FUNC['identity']
|
49 |
-
elif db_path.endswith('.pth'): # Assume a key,value dictionary
|
50 |
-
self.db_type = 'pth'
|
51 |
-
self.feat_file = torch.load(db_path)
|
52 |
-
self.loader = lambda x: x
|
53 |
-
print('HybridLoader: ext is ignored')
|
54 |
-
elif db_path.endswith('h5'):
|
55 |
-
self.db_type = 'h5'
|
56 |
-
self.loader = lambda x: np.array(x).astype('float32')
|
57 |
-
else:
|
58 |
-
self.db_type = 'dir'
|
59 |
-
|
60 |
-
self.in_memory = in_memory
|
61 |
-
if self.in_memory:
|
62 |
-
self.features = {}
|
63 |
-
|
64 |
-
def get(self, key):
|
65 |
-
|
66 |
-
if self.in_memory and key in self.features:
|
67 |
-
# We save f_input because we want to save the
|
68 |
-
# compressed bytes to save memory
|
69 |
-
f_input = self.features[key]
|
70 |
-
elif self.db_type == 'lmdb':
|
71 |
-
f_input = self.lmdb[key]
|
72 |
-
elif self.db_type == 'pth':
|
73 |
-
f_input = self.feat_file[key]
|
74 |
-
elif self.db_type == 'h5':
|
75 |
-
f_input = h5py.File(self.db_path, 'r')[key]
|
76 |
-
else:
|
77 |
-
f_input = open(os.path.join(self.db_path, key + self.ext), 'rb').read()
|
78 |
-
|
79 |
-
if self.in_memory and key not in self.features:
|
80 |
-
self.features[key] = f_input
|
81 |
-
|
82 |
-
# load image
|
83 |
-
feat = self.loader(f_input)
|
84 |
-
|
85 |
-
return feat
|
86 |
-
|
87 |
-
class CaptionDataset(data.Dataset):
|
88 |
-
|
89 |
-
def get_vocab_size(self):
|
90 |
-
return self.vocab_size
|
91 |
-
|
92 |
-
def get_vocab(self):
|
93 |
-
return self.ix_to_word
|
94 |
-
|
95 |
-
def get_seq_length(self):
|
96 |
-
return self.seq_length
|
97 |
-
|
98 |
-
def __init__(self, opt):
|
99 |
-
self.opt = opt
|
100 |
-
self.seq_per_img = opt.seq_per_img
|
101 |
-
|
102 |
-
# feature related options
|
103 |
-
self.use_fc = getattr(opt, 'use_fc', True)
|
104 |
-
self.use_att = getattr(opt, 'use_att', True)
|
105 |
-
self.use_box = getattr(opt, 'use_box', 0)
|
106 |
-
self.norm_att_feat = getattr(opt, 'norm_att_feat', 0)
|
107 |
-
self.norm_box_feat = getattr(opt, 'norm_box_feat', 0)
|
108 |
-
|
109 |
-
# load the json file which contains additional information about the dataset
|
110 |
-
if verbose:
|
111 |
-
print('DataLoader loading json file: ', opt.input_json)
|
112 |
-
self.info = json.load(open(self.opt.input_json))
|
113 |
-
if 'ix_to_word' in self.info:
|
114 |
-
self.ix_to_word = self.info['ix_to_word']
|
115 |
-
self.vocab_size = len(self.ix_to_word)
|
116 |
-
if verbose:
|
117 |
-
print('vocab size is ', self.vocab_size)
|
118 |
-
|
119 |
-
# open the hdf5 file
|
120 |
-
if verbose:
|
121 |
-
print('DataLoader loading h5 file: ', opt.input_fc_dir, opt.input_att_dir, opt.input_box_dir, opt.input_label_h5)
|
122 |
-
"""
|
123 |
-
Setting input_label_h5 to none is used when only doing generation.
|
124 |
-
For example, when you need to test on coco test set.
|
125 |
-
"""
|
126 |
-
if self.opt.input_label_h5 != 'none':
|
127 |
-
self.h5_label_file = h5py.File(self.opt.input_label_h5, 'r', driver='core')
|
128 |
-
# load in the sequence data
|
129 |
-
seq_size = self.h5_label_file['labels'].shape
|
130 |
-
self.label = self.h5_label_file['labels'][:]
|
131 |
-
self.seq_length = seq_size[1]
|
132 |
-
if verbose:
|
133 |
-
print('max sequence length in data is', self.seq_length)
|
134 |
-
# load the pointers in full to RAM (should be small enough)
|
135 |
-
self.label_start_ix = self.h5_label_file['label_start_ix'][:]
|
136 |
-
self.label_end_ix = self.h5_label_file['label_end_ix'][:]
|
137 |
-
else:
|
138 |
-
self.seq_length = 1
|
139 |
-
|
140 |
-
        self.data_in_memory = getattr(opt, 'data_in_memory', False)
        self.fc_loader = HybridLoader(self.opt.input_fc_dir, '.npy', in_memory=self.data_in_memory)
        self.att_loader = HybridLoader(self.opt.input_att_dir, '.npz', in_memory=self.data_in_memory)
        self.box_loader = HybridLoader(self.opt.input_box_dir, '.npy', in_memory=self.data_in_memory)

        self.use_clipscore = getattr(opt, 'use_clipscore', False)
        if self.use_clipscore:
            self.clipscore_loader = HybridLoader(self.opt.input_clipscore_vis_dir, '.npy', in_memory=self.data_in_memory)

        self.num_images = len(self.info['images'])  # self.label_start_ix.shape[0]
        if verbose:
            print('read %d image features' % (self.num_images))

        # separate out indexes for each of the provided splits
        self.split_ix = {'train': [], 'val': [], 'test': []}
        for ix in range(len(self.info['images'])):
            img = self.info['images'][ix]
            if not 'split' in img:
                self.split_ix['train'].append(ix)
                self.split_ix['val'].append(ix)
                self.split_ix['test'].append(ix)
            elif img['split'] == 'train':
                self.split_ix['train'].append(ix)
            elif img['split'] == 'val':
                self.split_ix['val'].append(ix)
            elif img['split'] == 'test':
                self.split_ix['test'].append(ix)
            elif opt.train_only == 0:  # restval
                self.split_ix['train'].append(ix)

        if verbose:
            print('assigned %d images to split train' % len(self.split_ix['train']))
            print('assigned %d images to split val' % len(self.split_ix['val']))
            print('assigned %d images to split test' % len(self.split_ix['test']))

    def get_captions(self, ix, seq_per_img):
        # fetch the sequence labels
        ix1 = self.label_start_ix[ix] - 1  # label_start_ix starts from 1
        ix2 = self.label_end_ix[ix] - 1
        ncap = ix2 - ix1 + 1  # number of captions available for this image
        assert ncap > 0, 'an image does not have any label. this can be handled but right now isn\'t'

        if ncap < seq_per_img:
            # we need to subsample (with replacement)
            seq = np.zeros([seq_per_img, self.seq_length], dtype='int')
            for q in range(seq_per_img):
                ixl = random.randint(ix1, ix2)
                seq[q, :] = self.label[ixl, :self.seq_length]
        else:
            ixl = random.randint(ix1, ix2 - seq_per_img + 1)
            seq = self.label[ixl: ixl + seq_per_img, :self.seq_length]

        return seq

    def collate_func(self, batch):
        seq_per_img = self.seq_per_img

        fc_batch = []
        att_batch = []
        label_batch = []

        clip_vis_feat_batch = []

        wrapped = False

        infos = []
        gts = []

        for sample in batch:
            # fetch image
            if self.use_clipscore:
                tmp_fc, tmp_att, tmp_seq, \
                    ix, tmp_clip_vis_feat = sample

                clip_vis_feat_batch.append(tmp_clip_vis_feat)
            else:
                tmp_fc, tmp_att, tmp_seq, \
                    ix = sample

            fc_batch.append(tmp_fc)
            att_batch.append(tmp_att)

            tmp_label = np.zeros([seq_per_img, self.seq_length + 2], dtype='int')
            if hasattr(self, 'h5_label_file'):
                # if there is ground truth
                tmp_label[:, 1: self.seq_length + 1] = tmp_seq
            label_batch.append(tmp_label)

            # Used for reward evaluation
            if hasattr(self, 'h5_label_file'):
                # if there is ground truth
                gts.append(self.label[self.label_start_ix[ix] - 1: self.label_end_ix[ix]])
            else:
                gts.append([])

            # record associated info as well
            info_dict = {}
            info_dict['ix'] = ix
            info_dict['id'] = self.info['images'][ix]['id']
            info_dict['file_path'] = self.info['images'][ix].get('file_path', '')
            infos.append(info_dict)

        # #sort by att_feat length
        # fc_batch, att_batch, label_batch, gts, infos = \
        #     zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: len(x[1]), reverse=True))
        if self.use_clipscore:
            fc_batch, att_batch, label_batch, clip_vis_feat_batch, gts, infos = \
                zip(*sorted(zip(fc_batch, att_batch, label_batch, clip_vis_feat_batch, gts, infos), key=lambda x: 0, reverse=True))
        else:
            fc_batch, att_batch, label_batch, gts, infos = \
                zip(*sorted(zip(fc_batch, att_batch, label_batch, gts, infos), key=lambda x: 0, reverse=True))
        data = {}
        data['fc_feats'] = np.stack(fc_batch)
        # merge att_feats
        max_att_len = max([_.shape[0] for _ in att_batch])
        data['att_feats'] = np.zeros([len(att_batch), max_att_len, att_batch[0].shape[1]], dtype='float32')
        for i in range(len(att_batch)):
            data['att_feats'][i, :att_batch[i].shape[0]] = att_batch[i]
        data['att_masks'] = np.zeros(data['att_feats'].shape[:2], dtype='float32')
        for i in range(len(att_batch)):
            data['att_masks'][i, :att_batch[i].shape[0]] = 1
        # set att_masks to None if attention features have same length
        if data['att_masks'].sum() == data['att_masks'].size:
            data['att_masks'] = None

        if self.use_clipscore:
            data['clip_vis_feats'] = np.stack(clip_vis_feat_batch)

        data['labels'] = np.vstack(label_batch)
        # generate mask
        nonzeros = np.array(list(map(lambda x: (x != 0).sum()+2, data['labels'])))
        mask_batch = np.zeros([data['labels'].shape[0], self.seq_length + 2], dtype='float32')
        for ix, row in enumerate(mask_batch):
            row[:nonzeros[ix]] = 1
        data['masks'] = mask_batch
        data['labels'] = data['labels'].reshape(len(batch), seq_per_img, -1)
        data['masks'] = data['masks'].reshape(len(batch), seq_per_img, -1)

        data['gts'] = gts  # all ground truth captions of each images
        data['infos'] = infos

        data = {k: torch.from_numpy(v) if type(v) is np.ndarray else v for k, v in data.items()}  # Turn all ndarray to torch tensor

        return data

    def __getitem__(self, ix):
        """This function returns a tuple that is further passed to collate_fn
        """
        if self.use_att:
            att_feat = self.att_loader.get(str(self.info['images'][ix]['id']))
            # Reshape to K x C
            att_feat = att_feat.reshape(-1, att_feat.shape[-1])
            if self.norm_att_feat:
                att_feat = att_feat / np.linalg.norm(att_feat, 2, 1, keepdims=True)
            if self.use_box:
                box_feat = self.box_loader.get(str(self.info['images'][ix]['id']))
                # devided by image width and height
                x1, y1, x2, y2 = np.hsplit(box_feat, 4)
                h, w = self.info['images'][ix]['height'], self.info['images'][ix]['width']
                box_feat = np.hstack((x1/w, y1/h, x2/w, y2/h, (x2-x1)*(y2-y1)/(w*h)))  # question? x2-x1+1??
                if self.norm_box_feat:
                    box_feat = box_feat / np.linalg.norm(box_feat, 2, 1, keepdims=True)
                att_feat = np.hstack([att_feat, box_feat])
                # sort the features by the size of boxes
                att_feat = np.stack(sorted(att_feat, key=lambda x: x[-1], reverse=True))
        else:
            att_feat = np.zeros((0, 0), dtype='float32')
        if self.use_fc:
            try:
                fc_feat = self.fc_loader.get(str(self.info['images'][ix]['id']))
            except:
                # Use average of attention when there is no fc provided (For bottomup feature)
                fc_feat = att_feat.mean(0)
        else:
            fc_feat = np.zeros((0), dtype='float32')
        if hasattr(self, 'h5_label_file'):
            seq = self.get_captions(ix, self.seq_per_img)
        else:
            seq = None

        if self.use_clipscore:
            clip_vis_feat = self.clipscore_loader.get(
                str(self.info['images'][ix]['id']))

            return (fc_feat,
                    att_feat, seq,
                    ix, clip_vis_feat)

        return (fc_feat,
                att_feat, seq,
                ix)

    def __len__(self):
        return len(self.info['images'])
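Note: the collate_func above zero-pads attention features and builds label/att masks; a minimal usage sketch is shown below. The dataset class name (CaptionDataset) and the opt namespace are assumptions for illustration, not taken from this diff.

import torch

dataset = CaptionDataset(opt)  # assumed constructor defined earlier in this file
loader = torch.utils.data.DataLoader(
    torch.utils.data.Subset(dataset, dataset.split_ix['train']),
    batch_size=16,
    shuffle=True,
    num_workers=4,
    collate_fn=dataset.collate_func,  # pads att_feats and builds masks as above
)

batch = next(iter(loader))
# fc_feats: (B, D_fc); att_feats: (B, max_att_len, D_att), zero-padded;
# labels/masks: (B, seq_per_img, seq_length + 2); clip_vis_feats only if use_clipscore.
print(batch['fc_feats'].shape, batch['labels'].shape)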
retrieval/text_utils.py DELETED
@@ -1,74 +0,0 @@
import random

def repeat(text, n_max_gram=3, n_max_repeat=3):
    """repeat n-grams"""
    tokens = text.split()

    n_gram = random.randint(1, n_max_gram)

    repeat_token_idx = random.randint(0, len(tokens) - n_gram)

    repeated_tokens = tokens[repeat_token_idx:repeat_token_idx+n_gram]

    n_repeat = random.randint(1, n_max_repeat)
    for _ in range(n_repeat):
        insert_idx = random.randint(0, len(tokens))
        tokens = tokens[:insert_idx] + \
            repeated_tokens + tokens[insert_idx:]

    new_text = " ".join(tokens)
    return new_text

def remove(text, n_max_gram=3):
    """remove n-grams"""
    tokens = text.split()

    n_gram = random.randint(1, n_max_gram)

    remove_token_idx = random.randint(0, len(tokens) - n_gram)

    tokens = tokens[:remove_token_idx] + tokens[remove_token_idx + n_gram:]

    new_text = " ".join(tokens)
    return new_text

def insert(text, vocab, n_max_tokens=3):
    """Insert tokens"""
    tokens = text.split()

    n_insert_token = random.randint(1, n_max_tokens)

    for _ in range(n_insert_token):
        insert_token_idx = random.randint(0, len(tokens) - 1)
        insert_token = random.choice(vocab)
        tokens = tokens[:insert_token_idx] + [insert_token] + tokens[insert_token_idx:]

    new_text = " ".join(tokens)
    return new_text

def swap(text, vocab, n_max_tokens=3):
    """Swap tokens"""
    tokens = text.split()

    n_swap_tokens = random.randint(1, n_max_tokens)

    for _ in range(n_swap_tokens):
        swap_token_idx = random.randint(0, len(tokens) - 1)

        swap_token = random.choice(vocab)
        while swap_token == tokens[swap_token_idx]:
            swap_token = random.choice(vocab)

        tokens[swap_token_idx] = swap_token

    new_text = " ".join(tokens)
    return new_text

def shuffle(text):
    """shuffle tokens"""
    tokens = text.split()

    random.shuffle(tokens)

    new_text = " ".join(tokens)
    return new_text
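These perturbations appear intended to synthesize corrupted "negative" captions for contrastive/grammar training; below is a minimal sketch of how they could be combined. The example caption, vocabulary, and random seed are invented for illustration.

import random

random.seed(0)
caption = "a man riding a motor bike on a dirt road"
vocab = ["dog", "street", "blue", "horse"]  # made-up vocabulary

corrupt = random.choice([repeat, remove, insert, swap, shuffle])
# insert and swap need a vocabulary; the others only take the text
neg_caption = corrupt(caption, vocab) if corrupt in (insert, swap) else corrupt(caption)
print(neg_caption)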
retrieval/train_pl.py DELETED
@@ -1,661 +0,0 @@
from ast import parse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np

import time
import os
from collections import defaultdict

# import captioning.utils.opts as opts
# import captioning.models as models
# from captioning.data.pth_loader import CaptionDataset
# import captioning.utils.eval_utils as eval_utils
# import captioning.utils.misc as utils
# from captioning.utils.rewards import init_scorer, get_self_critical_reward
# from captioning.modules.loss_wrapper import LossWrapper

from clip_model import CLIPScore
from caption_data import COCORetrievalDataset

import pytorch_lightning as pl

import detectron2.utils.comm as d2comm
from detectron2.utils.env import seed_all_rng
seed_all_rng(1234)


class LitModel(pl.LightningModule):
    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        self.args = args
        # Intilaize dataset
        # self.dataset = CaptionDataset(opt)

        # opt.vocab_size = self.dataset.vocab_size
        # opt.seq_length = self.dataset.seq_length
        # self.batch_size = opt.batch_size

        # Build model
        # opt.vocab = self.dataset.get_vocab()
        # model = models.setup(opt)
        # print(model)
        # del opt.vocab

        # wrapper with loss in it.
        # lw_model = LossWrapper(model, opt)

        self.model = CLIPScore(use_grammar=opt.use_grammar, joint_out=opt.joint_out)
        # self.lw_model = lw_model

        for p in self.model.clip_model.vision_model.parameters():
            p.requires_grad = False
        for p in self.model.clip_model.visual_projection.parameters():
            p.requires_grad = False

        # self.struc_flag = None
        # self.sc_flag = None

    def forward(self, *args, **kwargs):
        """
        I hate this design. Never pretend it as a nn.Module
        """
        raise NotImplementedError

    def train_dataloader(self):
        # train_dataset = torch.utils.data.Subset(
        #     self.dataset,
        #     self.dataset.split_ix['train']
        # )

        # train_loader = torch.utils.data.DataLoader(
        #     dataset=train_dataset,
        #     batch_size=self.batch_size,
        #     shuffle=True,
        #     num_workers=4,
        #     collate_fn=self.dataset.collate_func
        # )

        train_dataset = COCORetrievalDataset(
            split='karpathy_train', mode='train',
            args=opt,
            verbose=verbose
        )

        train_loader = torch.utils.data.DataLoader(
            dataset=train_dataset,
            batch_size=opt.batch_size,
            shuffle=True,
            num_workers=4,
            collate_fn=train_dataset.collate_fn
        )

        return train_loader

    def val_dataloader(self, split='karpathy_val'):
        # val_dataset = torch.utils.data.Subset(
        #     self.dataset,
        #     self.dataset.split_ix[split]
        # )
        # val_loader = torch.utils.data.DataLoader(
        #     val_dataset,
        #     batch_size=self.batch_size,
        #     shuffle=False,
        #     num_workers=4,
        #     drop_last=False,
        #     collate_fn=self.dataset.collate_func
        # )

        val_dataset = COCORetrievalDataset(
            split=split, mode='val',
            args=opt,
            verbose=verbose
        )

        val_loader = torch.utils.data.DataLoader(
            dataset=val_dataset,
            batch_size=opt.valid_batch_size,
            shuffle=False,
            num_workers=4,
            drop_last=False,
            collate_fn=val_dataset.collate_fn
        )

        return val_loader

    def test_dataloader(self):

        return self.val_dataloader('karpathy_test')

    def training_step(self, data, batch_idx):

        batch = data
        self.model.train()

        model_out = self.model.train_step(
            img_feat=batch['img_feats'],
            text=batch['text'],
            neg_text=batch['neg_text'],
        )

        clip_loss = model_out['clip_loss']

        if self.opt.joint_out:
            loss = clip_loss
        else:
            grammar_loss = model_out['grammar_loss']
            loss = clip_loss + grammar_loss

        data_time = self.trainer.profiler.recorded_durations["get_train_batch"][-1]
        data_time = torch.tensor(data_time)

        # print('batch_idx', batch_idx)
        # print('loss:', loss)

        # logger_logs = model_out.copy()
        logger_logs = {}

        logger_logs['loss'] = loss.detach()

        logger_logs['clip_loss'] = clip_loss.detach()

        if not self.opt.joint_out:
            logger_logs['grammar_loss'] = grammar_loss.detach()

        logger_logs['data_time'] = data_time.detach()

        # UserWarning: The {progress_bar:dict keyword} was deprecated in 0.9.1 and will be removed in 1.0.0
        # Please use self.log(...) inside the lightningModule instead.

        # # log on a step or aggregate epoch metric to the logger and/or progress bar
        # # (inside LightningModule)
        # self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        # warnings.warn(*args, **kwargs)
        # UserWarning: The {log:dict keyword} was deprecated in 0.9.1 and will be removed in 1.0.0
        # Please use self.log(...) inside the lightningModule instead.

        # output = {
        #     'loss': loss,
        #     'log': logger_logs,
        #     'progress_bar': {'data_time': data_time}
        # }

        for k, v in logger_logs.items():
            if k in ['data_time', 'clip_loss', 'grammar_loss']:
                self.log('train/'+k, v, prog_bar=True)
            else:
                self.log('train/'+k, v)

        # print('training step logged')

        return loss

    def validation_step(self, data, batch_idx):

        batch = data
        self.model.eval()

        with torch.no_grad():
            model_out = self.model.train_step(
                img_feat=batch['img_feats'],
                text=batch['text'],
                neg_text=batch['neg_text'],
            )

            if self.opt.joint_out:
                clip_loss = model_out['clip_loss']
                loss = clip_loss

                output = {
                    # 'val_loss': loss,
                    'loss': loss.detach(),
                    'clip_loss': clip_loss.detach(),
                    # 'grammar_loss': grammar_loss.detach(),

                    'img_feat': model_out['img_feat'].detach(),
                    'text_feat': model_out['text_feat'].detach(),
                    # 'neg_text_feat': model_out['neg_text_feat'].detach(),
                    # 'grammar_pos_pred': model_out['grammar_pos_pred'].detach(),
                    # 'grammar_neg_pred': model_out['grammar_neg_pred'].detach(),
                    # 'predictions': predictions,
                    # 'n_predictions': n_predictions,
                }
            else:
                clip_loss = model_out['clip_loss']
                grammar_loss = model_out['grammar_loss']
                loss = clip_loss + grammar_loss

                output = {
                    # 'val_loss': loss,
                    'loss': loss.detach(),
                    'clip_loss': clip_loss.detach(),
                    'grammar_loss': grammar_loss.detach(),

                    'img_feat': model_out['img_feat'].detach(),
                    'text_feat': model_out['text_feat'].detach(),
                    # 'neg_text_feat': model_out['neg_text_feat'].detach(),
                    'grammar_pos_pred': model_out['grammar_pos_pred'].detach(),
                    'grammar_neg_pred': model_out['grammar_neg_pred'].detach(),
                    # 'predictions': predictions,
                    # 'n_predictions': n_predictions,
                }
        return output

    def test_step(self, *args, **kwargs):
        return self.validation_step(*args, **kwargs)

    def validation_epoch_end(self, outputs, split='val'):
        outputs = d2comm.gather(outputs)
        # master node
        if d2comm.is_main_process():
            assert self.trainer.node_rank == 0 and self.trainer.local_rank == 0
            outputs = sum(outputs, [])

            out = {}

            val_loss_mean = sum([_['loss'].cpu() for _ in outputs]) / len(outputs)
            val_clip_loss_mean = sum([_['clip_loss'].cpu() for _ in outputs]) / len(outputs)
            if not self.opt.joint_out:
                val_grammar_loss_mean = sum([_['grammar_loss'].cpu() for _ in outputs]) / len(outputs)

            print('loss', val_loss_mean.item())
            print('clip_loss', val_clip_loss_mean.item())
            if not self.opt.joint_out:
                print('grammar_loss', val_grammar_loss_mean.item())

            logit_scale = self.model.clip_model.logit_scale.exp().cpu()

            text_feats = torch.cat([_['text_feat'].cpu() for _ in outputs], dim=0)
            img_feats = torch.cat([_['img_feat'].cpu() for _ in outputs], dim=0)

            assert text_feats.size() == (5000, 512), text_feats.size()
            assert img_feats.size() == (5000, 512), img_feats.size()

            logits_per_text = torch.matmul(text_feats, img_feats.t()) * logit_scale
            logits_per_image = logits_per_text.T

            # text-to-image retrieval
            print('Text-to-Image retrieval')
            for k in [1, 5, 10]:
                text_to_image_topk = logits_per_text.topk(k, dim=1).indices

                n_text = len(text_to_image_topk)

                labels = torch.arange(0, n_text).view(-1, 1)

                n_retrieved = ((text_to_image_topk == labels).sum(dim=1) > 0).sum()

                recall_k = n_retrieved / n_text * 100

                out[f'text_to_image_recall_{k}'] = recall_k.item()

                print(f'R@{k}: {recall_k.item():.2f}%')

            # image-to-text retrieval
            print('Image-to-Text retrieval')
            for k in [1, 5, 10]:
                image_to_text_topk = logits_per_image.topk(k, dim=1).indices

                n_image = len(image_to_text_topk)

                labels = torch.arange(0, n_image).view(-1, 1)

                n_retrieved = ((image_to_text_topk == labels).sum(dim=1) > 0).sum()

                recall_k = n_retrieved / n_image * 100

                out[f'image_to_text_recall_{k}'] = recall_k.item()

                print(f'R@{k}: {recall_k.item():.2f}%')

            out.update({
                'loss': val_loss_mean.item(),
                'clip_loss': val_clip_loss_mean.item()
            })

            if not self.opt.joint_out:
                # grammar scoring
                grammar_pos_pred = torch.cat([_['grammar_pos_pred'].cpu() for _ in outputs], dim=0)
                grammar_neg_pred = torch.cat([_['grammar_neg_pred'].cpu() for _ in outputs], dim=0)

                TP = (grammar_pos_pred == 1).sum().item()
                FP = (grammar_pos_pred == 0).sum().item()
                FN = (grammar_neg_pred == 1).sum().item()
                TN = (grammar_neg_pred == 0).sum().item()
                print('Grammar check')
                print(f'TP: {TP} FP: {FP} FN: {FN} TN: {TN}')

                precision = TP / (TP + FP) * 100
                recall = TP / (TP + FN) * 100
                accuracy = (TP + TN) / (TP + FP + FN + TN) * 100
                f1 = 2 * precision * recall / (precision + recall)
                print(f'Precision: {precision:.2f}%')
                print(f'Recall: {recall:.2f}%')
                print(f'Accuracy: {accuracy:.2f}%')
                print(f'F1: {f1:.2f}%')
                print('Total: {}'.format(len(grammar_pos_pred)))

                out.update({
                    'grammar_loss': val_grammar_loss_mean,

                    'grammar_precision': precision,
                    'grammar_recall': recall,
                    'grammar_accuracy': accuracy,
                    'grammar_f1': f1,

                })

        else:
            out = {}

        out = d2comm.all_gather(out)[0]  # Only the one from master node
        assert len(out) > 0  # make sure the head has index 0

        # must all be tensors
        out = {k: torch.tensor(v) if not torch.is_tensor(
            v) else v for k, v in out.items()}

        for k, v in out.items():
            self.log(f'{split}/{k}', v)

    def test_epoch_end(self, outputs):

        self.validation_epoch_end(outputs, 'test')

    def configure_optimizers(self):
        # opt = self.opt
        # model = self.model

        # parameters = [p for p in model.parameters() if p.requires_grad]

        # if opt.noamopt:
        #     # assert opt.caption_model in ['transformer', 'bert', 'm2transformer'], 'noamopt can only work with transformer'
        #     optimizer = utils.get_std_opt(
        #         model, optim_func=opt.optim, factor=opt.noamopt_factor, warmup=opt.noamopt_warmup)
        # elif opt.reduce_on_plateau:
        #     # optimizer = utils.build_optimizer(model.parameters(), opt)
        #     optimizer = utils.build_optimizer(parameters, opt)
        #     optimizer = utils.ReduceLROnPlateau(optimizer,
        #                                         factor=opt.reduce_on_plateau_factor,
        #                                         patience=opt.reduce_on_plateau_patience)
        # else:
        #     # optimizer = utils.build_optimizer(model.parameters(), opt)
        #     optimizer = utils.build_optimizer(parameters, opt)

        # from transformers.optimization import AdamW, get_linear_schedule_with_warmup
        # batch_per_epoch = len(self.train_loader)
        # t_total = batch_per_epoch // self.args.gradient_accumulation_steps * self.args.epochs
        # warmup_ratio = self.args.warmup_ratio
        # warmup_iters = int(t_total * warmup_ratio)
        # if self.verbose:
        #     print("Batch per epoch: %d" % batch_per_epoch)
        #     print("Total Iters: %d" % t_total)
        #     print('Warmup ratio:', warmup_ratio)
        #     print("Warm up Iters: %d" % warmup_iters)

        if self.args.optim == 'adamw':
            no_decay = ["bias", "LayerNorm.weight"]
            optimizer_grouped_parameters = [
                {
                    "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
                    "weight_decay": 0.0,
                },
            ]

            for group in optimizer_grouped_parameters:
                group['params'] = [p for p in group['params'] if p.requires_grad]

            from transformers.optimization import AdamW
            optim = AdamW(optimizer_grouped_parameters,
                          lr=self.args.lr, eps=self.args.adam_eps)
            # lr_scheduler = get_linear_schedule_with_warmup(
            #     optim, warmup_iters, t_total)

        # optimizers = []
        optimizers = [optim]
        lr_schedulers = []

        return optimizers, lr_schedulers

    def optimizer_step(self, epoch, batch_idx, optimizer,
                       optimizer_idx, *args, **kwargs):
        # # warm up lr
        # opt = self.opt
        # iteration = self.trainer.global_step
        # if opt.use_warmup and (iteration < opt.noamopt_warmup):
        #     opt.current_lr = opt.learning_rate * \
        #         (iteration+1) / opt.noamopt_warmup
        #     utils.set_lr(optimizer, opt.current_lr)

        super().optimizer_step(epoch, batch_idx, optimizer,
                               optimizer_idx, *args, **kwargs)

        # print('optimizer step')

    def state_dict(self):
        """
        Save the model state dict as well as opt and vocab
        """
        state_dict = self.model.state_dict()
        device = next(iter(state_dict.values())).device
        assert '_vocab' not in state_dict and '_opt' not in state_dict, 'Just in case'
        # state_dict.update({
        #     '_vocab': utils.serialize_to_tensor(self.model.vocab).to(device),
        #     '_opt': utils.serialize_to_tensor(self.opt).to(device)
        # })
        return state_dict

    def load_state_dict(self, state_dict=None, strict=True):
        # if '_vocab' in state_dict:
        #     self.model.vocab = utils.deserialize(state_dict['_vocab'])
        #     del state_dict['_vocab']
        # elif strict:
        #     raise KeyError
        # if '_opt' in state_dict:
        #     saved_model_opt = utils.deserialize(state_dict['_opt'])
        #     del state_dict['_opt']
        #     opt = self.opt
        #     # Make sure the saved opt is compatible with the curren topt
        #     need_be_same = ["caption_model",
        #                     "rnn_type", "rnn_size", "num_layers"]
        #     for checkme in need_be_same:
        #         if getattr(saved_model_opt, checkme) in ['updown', 'topdown'] and \
        #                 getattr(opt, checkme) in ['updown', 'topdown']:
        #             continue
        #         assert getattr(saved_model_opt, checkme) == getattr(
        #             opt, checkme), "Command line argument and saved model disagree on '%s' " % checkme
        # elif strict:
        #     raise KeyError
        self.model.load_state_dict(state_dict, strict)


class OnEpochStartCallback(pl.Callback):

    def on_epoch_start(self, trainer, pl_module):
        # Update lr/training stage/scheduled sampling prob etc.
        opt = pl_module.opt
        model = pl_module.model
        epoch = trainer.current_epoch
        optimizer = trainer.optimizers[0]

        # if not opt.noamopt and not opt.reduce_on_plateau:
        #     # Assign the learning rate
        #     if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
        #         frac = (
        #             epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
        #         decay_factor = opt.learning_rate_decay_rate ** frac
        #         opt.current_lr = opt.learning_rate * decay_factor
        #     else:
        #         opt.current_lr = opt.learning_rate
        #     utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
        # # Assign the scheduled sampling prob
        # if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
        #     frac = (
        #         epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
        #     opt.ss_prob = min(opt.scheduled_sampling_increase_prob *
        #                       frac, opt.scheduled_sampling_max_prob)
        #     model.ss_prob = opt.ss_prob

        # # If start self critical training
        # if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
        #     sc_flag = True
        #     init_scorer(opt.cached_tokens)
        # else:
        #     sc_flag = False

        # # If start structure loss training
        # if opt.structure_after != -1 and epoch >= opt.structure_after:
        #     struc_flag = True
        #     init_scorer(opt.cached_tokens)
        # else:
        #     struc_flag = False

        # pl_module.struc_flag = struc_flag
        # pl_module.sc_flag = sc_flag


class ModelCheckpoint(pl.callbacks.ModelCheckpoint):

    def on_keyboard_interrupt(self, trainer, pl_module):
        # Save model when keyboard interrupt
        filepath = os.path.join(self.dirpath, self.prefix + 'interrupt.ckpt')
        self._save_model(filepath)

from param import parse_args
# opt = opts.parse_opt()
args = parse_args()
opt = args

checkpoint_callback = ModelCheckpoint(
    filepath=opt.checkpoint_dir + '{epoch:02d}',
    # dirpath=opt.checkpoint_path,
    save_last=True,
    save_top_k=1,
    verbose=True,
    # monitor='to_monitor',
    # monitor='val/to_monitor',
    # monitor='val/CIDEr',
    monitor='val/loss',
    mode='min',
    # prefix=opt.id+'_',
    prefix=opt.id,
    # filename=f'{opt.id}_',
)

verbose = True
# import torch
# if torch.cuda.current_device() in [0, -1]:
if 'LOCAL_RANK' in os.environ and os.environ['LOCAL_RANK'] != '0':
    verbose = False

# if verbose:
#     print(opt)
#     print("""
#     val_image_use,
#     save_checkpoint_very
#     save_every_epoch,
#     save_history-ckpt will be ignored.
#     """)

# Lightning defines batch size as batch size per gpu
assert opt.batch_size % torch.cuda.device_count() == 0
opt.batch_size = opt.batch_size // torch.cuda.device_count()
opt.valid_batch_size = opt.valid_batch_size // torch.cuda.device_count()

# If resume from last checkpoint
# if opt.start_from is not None and os.path.isfile(os.path.join(opt.start_from, f'{opt.id}_last.ckpt')):
#     resume_from = os.path.join(opt.start_from, f'{opt.id}_last.ckpt')
if opt.start_from is not None and os.path.isfile(os.path.join(opt.start_from, f'{opt.id}-last.ckpt')):
    resume_from = os.path.join(opt.start_from, f'{opt.id}-last.ckpt')
    if verbose:
        print('resume from', resume_from)
else:
    resume_from = None

from pytorch_lightning.loggers import WandbLogger
wandb_logger = WandbLogger(
    # project='CLIP-ViL-COCOCaption',
    project='CLIP-Finetune-COCO',
    name=opt.id,
)

if verbose:
    wandb_logger.experiment.config.update(opt)
    from pathlib import Path
    import glob
    import wandb
    # src_dir = Path(__file__).resolve().parent.parent
    glob_str = "*.py"
    base_path = './'
    wandb.save(glob_str=glob_str, base_path=base_path)

    glob_str = "**/*.yaml"
    base_path = './'
    wandb.save(glob_str=glob_str, base_path=base_path)

    # code = wandb.Artifact('project-source', type='code')
    # for path in glob.glob('**/*.py', recursive=True):
    #     code.add_file(path, name='source/'+path)
    #     print(path)
    # wandb.run.use_artifact(code)


lit = LitModel(opt)
# warning grad_clip_mode is ignored.
trainer = pl.Trainer(
    callbacks=[
        OnEpochStartCallback(),
        # pl.callbacks.lr_logger.LearningRateLogger()
        pl.callbacks.LearningRateMonitor()
    ],
    default_root_dir=opt.checkpoint_dir,
    resume_from_checkpoint=resume_from,

    distributed_backend='ddp',
    gpus=torch.cuda.device_count(),

    # gpus=1,

    check_val_every_n_epoch=1,
    # max_epochs=opt.max_epochs,
    max_epochs=opt.epochs,
    # gradient_clip_val=opt.grad_clip_value,
    gradient_clip_val=opt.clip_grad_norm,

    checkpoint_callback=checkpoint_callback,
    log_gpu_memory='min_max',
    # log_save_interval=opt.losses_log_every,
    log_every_n_steps=opt.losses_log_every,
    profiler=True,
    # profiler='simple',
    # row_log_interval=10,  # what is it?
    flush_logs_every_n_steps=10,
    num_sanity_val_steps=0,
    # val_check_interval=0.01,
    # limit_train_batches=500,
    # progress_bar_refresh_rate=0,
    # fast_dev_run=True,
    precision=opt.precision,
    logger=wandb_logger
)

if os.getenv('EVALUATE', '0') == '1':
    trainer.test(lit)
else:
    trainer.fit(lit)
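For reference, the retrieval metric computed in validation_epoch_end reduces to a top-k membership test against the diagonal of the text-image similarity matrix. A self-contained toy sketch (the numbers are invented):

import torch

# 4x4 similarity matrix; row i should match column i
logits_per_text = torch.tensor([[9., 1., 0., 0.],
                                [0., 8., 2., 0.],
                                [0., 7., 1., 0.],   # hardest row: correct image only ranks 2nd
                                [1., 0., 0., 5.]])
for k in [1, 2]:
    topk = logits_per_text.topk(k, dim=1).indices
    labels = torch.arange(len(topk)).view(-1, 1)
    recall_k = ((topk == labels).sum(dim=1) > 0).float().mean() * 100
    print(f'R@{k}: {recall_k.item():.1f}%')   # prints R@1: 75.0%, R@2: 100.0%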
save/README.md DELETED
@@ -1 +0,0 @@
Directory for checkpoints
scripts/build_bpe_subword_nmt.py DELETED
@@ -1,214 +0,0 @@
"""
Preprocess a raw json dataset into hdf5/json files for use in data_loader.lua

Input: json file that has the form
[{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...]
example element in this list would look like
{'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895}

This script reads this json, does some basic preprocessing on the captions
(e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays

Output: a json file and an hdf5 file
The hdf5 file contains several fields:
/labels is (M,max_length) uint32 array of encoded labels, zero padded
/label_start_ix and /label_end_ix are (N,) uint32 arrays of pointers to the
  first and last indices (in range 1..M) of labels for each image
/label_length stores the length of the sequence for each of the M sequences

The json file has a dict that contains:
- an 'ix_to_word' field storing the vocab in form {ix:'word'}, where ix is 1-indexed
- an 'images' field that is a list holding auxiliary information for each image,
  such as in particular the 'split' it was assigned to.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import json
import argparse
from random import shuffle, seed
import string
# non-standard dependencies:
import h5py
import numpy as np
import torch
import torchvision.models as models
import skimage.io
from PIL import Image

import codecs
import tempfile
from subword_nmt import learn_bpe, apply_bpe

# python scripts/build_bpe_subword_nmt.py --input_json data/dataset_coco.json --output_json data/cocotalkbpe.json --output_h5 data/cocotalkbpe

def build_vocab(imgs, params):
    # count up the number of words
    captions = []
    for img in imgs:
        for sent in img['sentences']:
            captions.append(' '.join(sent['tokens']))
    captions = '\n'.join(captions)
    all_captions = tempfile.NamedTemporaryFile(delete=False)
    all_captions.close()
    with open(all_captions.name, 'w') as txt_file:
        txt_file.write(captions)

    #
    codecs_output = tempfile.NamedTemporaryFile(delete=False)
    codecs_output.close()
    with codecs.open(codecs_output.name, 'w', encoding='UTF-8') as output:
        learn_bpe.learn_bpe(codecs.open(all_captions.name, encoding='UTF-8'), output, params['symbol_count'])

    with codecs.open(codecs_output.name, encoding='UTF-8') as codes:
        bpe = apply_bpe.BPE(codes)

    tmp = tempfile.NamedTemporaryFile(delete=False)
    tmp.close()

    tmpout = codecs.open(tmp.name, 'w', encoding='UTF-8')

    for _, img in enumerate(imgs):
        img['final_captions'] = []
        for sent in img['sentences']:
            txt = ' '.join(sent['tokens'])
            txt = bpe.segment(txt).strip()
            img['final_captions'].append(txt.split(' '))
            tmpout.write(txt)
            tmpout.write('\n')
            if _ < 20:
                print(txt)

    tmpout.close()
    tmpin = codecs.open(tmp.name, encoding='UTF-8')

    vocab = learn_bpe.get_vocabulary(tmpin)
    vocab = sorted(vocab.keys(), key=lambda x: vocab[x], reverse=True)

    # Always insert UNK
    print('inserting the special UNK token')
    vocab.append('UNK')

    print('Vocab size:', len(vocab))

    os.remove(all_captions.name)
    with open(codecs_output.name, 'r') as codes:
        bpe = codes.read()
    os.remove(codecs_output.name)
    os.remove(tmp.name)

    return vocab, bpe

def encode_captions(imgs, params, wtoi):
    """
    encode all captions into one large array, which will be 1-indexed.
    also produces label_start_ix and label_end_ix which store 1-indexed
    and inclusive (Lua-style) pointers to the first and last caption for
    each image in the dataset.
    """

    max_length = params['max_length']
    N = len(imgs)
    M = sum(len(img['final_captions']) for img in imgs)  # total number of captions

    label_arrays = []
    label_start_ix = np.zeros(N, dtype='uint32')  # note: these will be one-indexed
    label_end_ix = np.zeros(N, dtype='uint32')
    label_length = np.zeros(M, dtype='uint32')
    caption_counter = 0
    counter = 1
    for i, img in enumerate(imgs):
        n = len(img['final_captions'])
        assert n > 0, 'error: some image has no captions'

        Li = np.zeros((n, max_length), dtype='uint32')
        for j, s in enumerate(img['final_captions']):
            label_length[caption_counter] = min(max_length, len(s))  # record the length of this sequence
            caption_counter += 1
            for k, w in enumerate(s):
                if k < max_length:
                    Li[j, k] = wtoi[w]

        # note: word indices are 1-indexed, and captions are padded with zeros
        label_arrays.append(Li)
        label_start_ix[i] = counter
        label_end_ix[i] = counter + n - 1

        counter += n

    L = np.concatenate(label_arrays, axis=0)  # put all the labels together
    assert L.shape[0] == M, 'lengths don\'t match? that\'s weird'
    assert np.all(label_length > 0), 'error: some caption had no words?'

    print('encoded captions to array of size ', L.shape)
    return L, label_start_ix, label_end_ix, label_length

def main(params):

    imgs = json.load(open(params['input_json'], 'r'))
    imgs = imgs['images']

    seed(123)  # make reproducible

    # create the vocab
    vocab, bpe = build_vocab(imgs, params)
    itow = {i+1: w for i, w in enumerate(vocab)}  # a 1-indexed vocab translation table
    wtoi = {w: i+1 for i, w in enumerate(vocab)}  # inverse table

    # encode captions in large arrays, ready to ship to hdf5 file
    L, label_start_ix, label_end_ix, label_length = encode_captions(imgs, params, wtoi)

    # create output h5 file
    N = len(imgs)
    f_lb = h5py.File(params['output_h5']+'_label.h5', "w")
    f_lb.create_dataset("labels", dtype='uint32', data=L)
    f_lb.create_dataset("label_start_ix", dtype='uint32', data=label_start_ix)
    f_lb.create_dataset("label_end_ix", dtype='uint32', data=label_end_ix)
    f_lb.create_dataset("label_length", dtype='uint32', data=label_length)
    f_lb.close()

    # create output json file
    out = {}
    out['ix_to_word'] = itow  # encode the (1-indexed) vocab
    out['images'] = []
    out['bpe'] = bpe
    for i, img in enumerate(imgs):

        jimg = {}
        jimg['split'] = img['split']
        if 'filename' in img: jimg['file_path'] = os.path.join(img['filepath'], img['filename'])  # copy it over, might need
        if 'cocoid' in img: jimg['id'] = img['cocoid']  # copy over & mantain an id, if present (e.g. coco ids, useful)

        if params['images_root'] != '':
            with Image.open(os.path.join(params['images_root'], img['filepath'], img['filename'])) as _img:
                jimg['width'], jimg['height'] = _img.size

        out['images'].append(jimg)

    json.dump(out, open(params['output_json'], 'w'))
    print('wrote ', params['output_json'])

if __name__ == "__main__":

    parser = argparse.ArgumentParser()

    # input json
    parser.add_argument('--input_json', required=True, help='input json file to process into hdf5')
    parser.add_argument('--output_json', default='data.json', help='output json file')
    parser.add_argument('--output_h5', default='data', help='output h5 file')
    parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json')

    # options
    parser.add_argument('--max_length', default=16, type=int, help='max length of a caption, in number of words. captions longer than this get clipped.')
    parser.add_argument('--symbol_count', default=10000, type=int, help='only words that occur more than this number of times will be put in vocab')

    args = parser.parse_args()
    params = vars(args)  # convert to ordinary dict
    print('parsed input parameters:')
    print(json.dumps(params, indent=2))
    main(params)
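As a hedged sketch, the 1-indexed, inclusive label_start_ix / label_end_ix convention written by this script can be read back as shown below (file name assumes the default --output_h5 prefix 'data').

import h5py

with h5py.File('data_label.h5', 'r') as f:
    labels = f['labels'][:]         # (M, max_length), zero-padded, 1-indexed word ids
    start = f['label_start_ix'][:]  # (N,), 1-indexed inclusive
    end = f['label_end_ix'][:]

    img_ix = 0                                        # first image
    caps = labels[start[img_ix] - 1: end[img_ix]]     # shift to 0-indexed; end is inclusive
    print(caps.shape)                                 # (#captions for this image, max_length)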
scripts/clip_prepro_feats.py DELETED
@@ -1,170 +0,0 @@
"""
Preprocess a raw json dataset into features files for use in data_loader.py

Input: json file that has the form
[{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...]
example element in this list would look like
{'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895}

This script reads this json, does some basic preprocessing on the captions
(e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays

Output: two folders of features
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import json
import argparse
from random import shuffle, seed
import string
# non-standard dependencies:
import h5py
from six.moves import cPickle
import numpy as np
import torch
import torchvision.models as models
import skimage.io

from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from PIL import Image
from torch import nn

preprocess = Compose([
    Resize((448, 448), interpolation=Image.BICUBIC),
    CenterCrop((448, 448)),
    ToTensor()
])


from clip.clip import load
from timm.models.vision_transformer import resize_pos_embed
import timm

from captioning.utils.resnet_utils import myResnet
import captioning.utils.resnet as resnet

from tqdm import tqdm


def main(params):
    if params["model_type"] != 'vit_base_patch32_224_in21k':
        model, transform = load(params["model_type"], jit=False)
    else:
        model = timm.create_model(params["model_type"], pretrained=True)
    model = model.cuda()

    if params["model_type"] != 'vit_base_patch32_224_in21k':
        save_model_type = params["model_type"].split("-")[0]
        mean = torch.Tensor([0.48145466, 0.4578275, 0.40821073]).to("cuda").reshape(3, 1, 1)
        std = torch.Tensor([0.26862954, 0.26130258, 0.27577711]).to("cuda").reshape(3, 1, 1)

        if "RN" in params["model_type"]:
            num_patches = 196  # 600 * 1000 // 32 // 32
            pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, model.visual.attnpool.positional_embedding.shape[-1], device='cuda'),)
            pos_embed.weight = resize_pos_embed(model.visual.attnpool.positional_embedding.unsqueeze(0), pos_embed)
            model.visual.attnpool.positional_embedding = pos_embed

    else:
        save_model_type = 'vit_base'
        mean = torch.Tensor([0.5, 0.5, 0.5]).to("cuda").reshape(3, 1, 1)
        std = torch.Tensor([0.5, 0.5, 0.5]).to("cuda").reshape(3, 1, 1)

        num_patches = 196  # 600 * 1000 // 32 // 32
        pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, 768, device='cuda'),)
        pos_embed.weight = resize_pos_embed(model.pos_embed, pos_embed)
        model.pos_embed = pos_embed

    if params["model_type"] == "ViT-B/32":
        num_patches = 196  # 600 * 1000 // 32 // 32
        pos_embed = nn.Parameter(torch.zeros(num_patches + 1, 768, device='cuda'),)
        pos_embed.weight = resize_pos_embed(model.visual.positional_embedding.unsqueeze(0), pos_embed.unsqueeze(0))
        model.visual.positional_embedding = pos_embed
    imgs = json.load(open(params['input_json'], 'r'))

    imgs = imgs['images']

    if args.n_jobs > 1:
        print('Total imgs:', len(imgs))
        print('Using {} jobs'.format(args.n_jobs))
        print('job id:', args.job_id)
        imgs = imgs[args.job_id::args.n_jobs]

    N = len(imgs)

    seed(123)  # make reproducible

    dir_fc = params['output_dir']+'_clip_'+save_model_type+'_fc'
    dir_att = params['output_dir']+'_clip_'+save_model_type+'_att'
    if not os.path.isdir(dir_fc):
        os.mkdir(dir_fc)
    if not os.path.isdir(dir_att):
        os.mkdir(dir_att)

    for i, img in enumerate(tqdm(imgs)):
        # load the image
        with torch.no_grad():

            image = preprocess(Image.open(os.path.join(params['images_root'], img['filepath'], img['filename'])).convert("RGB"))
            image = torch.tensor(np.stack([image])).cuda()
            image -= mean
            image /= std
            if "RN" in params["model_type"]:
                tmp_att, tmp_fc = model.encode_image(image)
                tmp_att = tmp_att[0].permute(1, 2, 0)
                tmp_fc = tmp_fc[0]
            elif params["model_type"] == 'vit_base_patch32_224_in21k':
                x = model(image)
                tmp_fc = x[0, 0, :]
                tmp_att = x[0, 1:, :].reshape(14, 14, 768)
            else:
                x = model.visual.conv1(image.half())  # shape = [*, width, grid, grid]
                x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
                x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
                x = torch.cat([model.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
                x = x + model.visual.positional_embedding.to(x.dtype)[:x.shape[1], :]
                x = model.visual.ln_pre(x)

                x = x.permute(1, 0, 2)  # NLD -> LND

                for layer_idx, layer in enumerate(model.visual.transformer.resblocks):
                    x = layer(x)

                x = x.permute(1, 0, 2)
                tmp_fc = x[0, 0, :]
                tmp_att = x[0, 1:, :].reshape(14, 14, 768)

        np.save(os.path.join(dir_fc, str(img['cocoid'])), tmp_fc.data.cpu().float().numpy())
        np.savez_compressed(os.path.join(dir_att, str(img['cocoid'])), feat=tmp_att.data.cpu().float().numpy())

        # if i % 1000 == 0:
        #     print('processing %d/%d (%.2f%% done)' % (i, N, i*100.0/N))
    print('wrote ', dir_fc, dir_att)

if __name__ == "__main__":

    parser = argparse.ArgumentParser()

    # input json
    parser.add_argument('--input_json', required=True, help='input json file to process into hdf5')
    parser.add_argument('--output_dir', default='data', help='output h5 file')

    # options
    parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json')
    parser.add_argument('--att_size', default=14, type=int, help='14x14 or 7x7')
    parser.add_argument('--model_type', default='RN50', type=str, help='RN50, RN101, RN50x4, ViT-B/32, vit_base_patch32_224_in21k')

    parser.add_argument('--n_jobs', default=-1, type=int, help='number of jobs to run in parallel')
    parser.add_argument('--job_id', default=0, type=int, help='job id')
    parser.add_argument('--batch_size', default=1, type=int, help='batch size')


    args = parser.parse_args()
    params = vars(args)  # convert to ordinary dict
    print('parsed input parameters:')
    print(json.dumps(params, indent=2))
    main(params)
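A hedged sketch of how the saved per-image features could be loaded back; the paths assume the default --output_dir 'data', --model_type 'RN50', and an example cocoid.

import numpy as np

cocoid = 391895  # example id
fc_feat = np.load(f'data_clip_RN50_fc/{cocoid}.npy')            # pooled feature vector
att_feat = np.load(f'data_clip_RN50_att/{cocoid}.npz')['feat']  # (H, W, C) grid feature
att_feat = att_feat.reshape(-1, att_feat.shape[-1])             # K x C, as the data loader expects
print(fc_feat.shape, att_feat.shape)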
scripts/clipscore_prepro_feats.py
DELETED
@@ -1,162 +0,0 @@
-"""
-Preprocess a raw json dataset into features files for use in data_loader.py
-
-Input: json file that has the form
-[{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...]
-example element in this list would look like
-{'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895}
-
-This script reads this json, does some basic preprocessing on the captions
-(e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays
-
-Output: two folders of features
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-import json
-import argparse
-from random import shuffle, seed
-import string
-# non-standard dependencies:
-import h5py
-from six.moves import cPickle
-import numpy as np
-import torch
-import torchvision.models as models
-import skimage.io
-
-from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
-from PIL import Image
-from torch import nn
-
-# preprocess = Compose([
-#     Resize((448, 448), interpolation=Image.BICUBIC),
-#     CenterCrop((448, 448)),
-#     ToTensor()
-# ])
-
-
-# from clip.clip import load
-# from timm.models.vision_transformer import resize_pos_embed
-# import timm
-
-# from captioning.utils.resnet_utils import myResnet
-# import captioning.utils.resnet as resnet
-
-from captioning.utils.clipscore import CLIPScore
-
-from tqdm import tqdm
-
-
-
-def main(params):
-
-    clipscore_model = CLIPScore()
-    clipscore_model.to('cuda')
-
-    imgs = json.load(open(params['input_json'], 'r'))
-    imgs = imgs['images']
-
-    if args.n_jobs > 1:
-        print('Total imgs:', len(imgs))
-        print('Using {} jobs'.format(args.n_jobs))
-        print('job id:', args.job_id)
-        imgs = imgs[args.job_id::args.n_jobs]
-
-    N = len(imgs)
-
-    seed(123) # make reproducible
-
-    # dir_fc = params['output_dir']+'_clip_'+save_model_type+'_fc'
-    # dir_att = params['output_dir']+'_clip_'+save_model_type+'_att'
-
-    vis_dir_fc = params['output_dir']+'_clipscore_vis'
-    if not os.path.isdir(vis_dir_fc):
-        os.mkdir(vis_dir_fc)
-
-    # text_dir_fc = params['output_dir']+'_clipscore_text'
-    # if not os.path.isdir(text_dir_fc):
-    #     os.mkdir(text_dir_fc)
-
-    # if not os.path.isdir(dir_att):
-    #     os.mkdir(dir_att)
-
-    for i, img in enumerate(tqdm(imgs)):
-        # load the image
-
-        img_path = os.path.join(params['images_root'], img['filepath'], img['filename'])
-        img_feat = clipscore_model.image_extract(img_path)
-        img_feat = img_feat.view(512)
-
-        # for d in img['sentences']:
-        #     text = d['raw'].strip()
-        #     text_feat = clipscore_model.text_extract(text)
-
-
-        # with torch.no_grad():
-
-        # image = preprocess(Image.open(os.path.join(params['images_root'], img['filepath'], img['filename']) ).convert("RGB"))
-        # image = torch.tensor(np.stack([image])).cuda()
-        # image -= mean
-        # image /= std
-        # if "RN" in params["model_type"]:
-        #     tmp_att, tmp_fc = model.encode_image(image)
-        #     tmp_att = tmp_att[0].permute(1, 2, 0)
-        #     tmp_fc = tmp_fc[0]
-        # elif params["model_type"] == 'vit_base_patch32_224_in21k':
-        #     x = model(image)
-        #     tmp_fc = x[0, 0, :]
-        #     tmp_att = x[0, 1:, :].reshape( 14, 14, 768 )
-        # else:
-        #     x = model.visual.conv1(image.half()) # shape = [*, width, grid, grid]
-        #     x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
-        #     x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
-        #     x = torch.cat([model.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
-        #     x = x + model.visual.positional_embedding.to(x.dtype)[:x.shape[1], :]
-        #     x = model.visual.ln_pre(x)
-
-        #     x = x.permute(1, 0, 2) # NLD -> LND
-
-        #     for layer_idx, layer in enumerate(model.visual.transformer.resblocks):
-        #         x = layer(x)
-
-        #     x = x.permute(1, 0, 2)
-        #     tmp_fc = x[0, 0, :]
-        #     tmp_att = x[0, 1:, :].reshape( 14, 14, 768 )
-
-        np.save(os.path.join(vis_dir_fc, str(img['cocoid'])), img_feat.data.cpu().float().numpy())
-        # np.save(os.path.join(text_dir_fc, str(img['cocoid'])), tmp_fc.data.cpu().float().numpy())
-
-
-        # np.savez_compressed(os.path.join(dir_att, str(img['cocoid'])), feat=tmp_att.data.cpu().float().numpy())
-
-        if i % 1000 == 0:
-            print('processing %d/%d (%.2f%% done)' % (i, N, i*100.0/N))
-    print('wrote ', vis_dir_fc)
-
-if __name__ == "__main__":
-
-    parser = argparse.ArgumentParser()
-
-    # input json
-    # dataset_coco.json
-    parser.add_argument('--input_json', required=True, help='input json file to process into hdf5')
-    parser.add_argument('--output_dir', default='data', help='output h5 file')
-
-    # options
-    parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json')
-    # parser.add_argument('--att_size', default=14, type=int, help='14x14 or 7x7')
-    # parser.add_argument('--model_type', default='RN50', type=str, help='RN50, RN101, RN50x4, ViT-B/32, vit_base_patch32_224_in21k')
-
-    parser.add_argument('--n_jobs', default=-1, type=int, help='number of jobs to run in parallel')
-    parser.add_argument('--job_id', default=0, type=int, help='job id')
-
-    args = parser.parse_args()
-    params = vars(args) # convert to ordinary dict
-    print('parsed input parameters:')
-    print(json.dumps(params, indent = 2))
-    main(params)
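Note: the deleted clipscore_prepro_feats.py wrote one 512-d CLIP image feature per image, keyed by cocoid, into <output_dir>_clipscore_vis/. A minimal sketch of reading one back; the directory and id below are placeholders:

import numpy as np
feat = np.load('data/cocotalk_clipscore_vis/391895.npy')   # <output_dir>_clipscore_vis/<cocoid>.npy
assert feat.shape == (512,)                                # saved via img_feat.view(512)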
scripts/copy_model.sh
DELETED
@@ -1,9 +0,0 @@
-#!/bin/sh
-
-if [ ! -d log_$2 ]; then
-    cp -r log_$1 log_$2
-    cd log_$2
-    mv infos_$1-best.pkl infos_$2-best.pkl
-    mv infos_$1.pkl infos_$2.pkl
-    cd ../
-fi
scripts/dump_to_h5df.py
DELETED
@@ -1,56 +0,0 @@
-import argparse
-import h5py
-import os
-import numpy as np
-import json
-from tqdm import tqdm
-
-
-def main(params):
-
-    imgs = json.load(open(params['input_json'], 'r'))
-    imgs = imgs['images']
-    N = len(imgs)
-
-    if params['fc_input_dir'] is not None:
-        print('processing fc')
-        with h5py.File(params['fc_output']) as file_fc:
-            for i, img in enumerate(tqdm(imgs)):
-                npy_fc_path = os.path.join(
-                    params['fc_input_dir'],
-                    str(img['cocoid']) + '.npy')
-
-                d_set_fc = file_fc.create_dataset(
-                    str(img['cocoid']), data=np.load(npy_fc_path))
-            file_fc.close()
-
-    if params['att_input_dir'] is not None:
-        print('processing att')
-        with h5py.File(params['att_output']) as file_att:
-            for i, img in enumerate(tqdm(imgs)):
-                npy_att_path = os.path.join(
-                    params['att_input_dir'],
-                    str(img['cocoid']) + '.npz')
-
-                d_set_att = file_att.create_dataset(
-                    str(img['cocoid']),
-                    data=np.load(npy_att_path)['feat'])
-            file_att.close()
-
-
-if __name__ == "__main__":
-
-    parser = argparse.ArgumentParser()
-
-    parser.add_argument('--input_json', required=True, help='input json file to process into hdf5')
-    parser.add_argument('--fc_output', default='data', help='output h5 filename for fc')
-    parser.add_argument('--att_output', default='data', help='output h5 file for att')
-    parser.add_argument('--fc_input_dir', default=None, help='input directory for numpy fc files')
-    parser.add_argument('--att_input_dir', default=None, help='input directory for numpy att files')
-
-    args = parser.parse_args()
-    params = vars(args) # convert to ordinary dict
-    print('parsed input parameters:')
-    print(json.dumps(params, indent=2))
-
-    main(params)
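Note: dump_to_h5df.py packed the per-image .npy/.npz features into two HDF5 files with one dataset per cocoid. A minimal read-back sketch; the file names and id are placeholders matching the --fc_output / --att_output arguments:

import h5py
with h5py.File('data/cocotalk_fc.h5', 'r') as f:    # --fc_output
    fc = f['391895'][()]                            # dataset name is str(cocoid)
with h5py.File('data/cocotalk_att.h5', 'r') as f:   # --att_output
    att = f['391895'][()]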
scripts/dump_to_lmdb.py
DELETED
@@ -1,241 +0,0 @@
-# copy from https://github.com/Lyken17/Efficient-PyTorch/tools
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-import os.path as osp
-import os, sys
-import os.path as osp
-from PIL import Image
-import six
-import string
-
-from lmdbdict import lmdbdict
-from lmdbdict.methods import DUMPS_FUNC, LOADS_FUNC
-import pickle
-import tqdm
-import numpy as np
-import argparse
-import json
-
-import torch
-import torch.utils.data as data
-from torch.utils.data import DataLoader
-
-import csv
-csv.field_size_limit(sys.maxsize)
-FIELDNAMES = ['image_id', 'status']
-
-class FolderLMDB(data.Dataset):
-    def __init__(self, db_path, fn_list=None):
-        self.db_path = db_path
-        self.lmdb = lmdbdict(db_path, unsafe=True)
-        self.lmdb._key_dumps = DUMPS_FUNC['ascii']
-        self.lmdb._value_loads = LOADS_FUNC['identity']
-        if fn_list is not None:
-            self.length = len(fn_list)
-            self.keys = fn_list
-        else:
-            raise Error
-
-    def __getitem__(self, index):
-        byteflow = self.lmdb[self.keys[index]]
-
-        # load image
-        imgbuf = byteflow
-        buf = six.BytesIO()
-        buf.write(imgbuf)
-        buf.seek(0)
-        try:
-            if args.extension == '.npz':
-                feat = np.load(buf)['feat']
-            else:
-                feat = np.load(buf)
-        except Exception as e:
-            print(self.keys[index], e)
-            return None
-
-        return feat
-
-    def __len__(self):
-        return self.length
-
-    def __repr__(self):
-        return self.__class__.__name__ + ' (' + self.db_path + ')'
-
-
-def make_dataset(dir, extension):
-    images = []
-    dir = os.path.expanduser(dir)
-    for root, _, fnames in sorted(os.walk(dir)):
-        for fname in sorted(fnames):
-            if has_file_allowed_extension(fname, [extension]):
-                path = os.path.join(root, fname)
-                images.append(path)
-
-    return images
-
-
-def raw_reader(path):
-    with open(path, 'rb') as f:
-        bin_data = f.read()
-    return bin_data
-
-
-def raw_npz_reader(path):
-    with open(path, 'rb') as f:
-        bin_data = f.read()
-        try:
-            npz_data = np.load(six.BytesIO(bin_data))['feat']
-        except Exception as e:
-            print(path)
-            npz_data = None
-            print(e)
-    return bin_data, npz_data
-
-
-def raw_npy_reader(path):
-    with open(path, 'rb') as f:
-        bin_data = f.read()
-        try:
-            npy_data = np.load(six.BytesIO(bin_data))
-        except Exception as e:
-            print(path)
-            npy_data = None
-            print(e)
-    return bin_data, npy_data
-
-
-class Folder(data.Dataset):
-
-    def __init__(self, root, loader, extension, fn_list=None):
-        super(Folder, self).__init__()
-        self.root = root
-        if fn_list:
-            samples = [os.path.join(root, str(_)+extension) for _ in fn_list]
-        else:
-            samples = make_dataset(self.root, extension)
-
-        self.loader = loader
-        self.extension = extension
-        self.samples = samples
-
-    def __getitem__(self, index):
-        """
-        Args:
-            index (int): Index
-        Returns:
-            tuple: (sample, target) where target is class_index of the target class.
-        """
-        path = self.samples[index]
-        sample = self.loader(path)
-
-        return (path.split('/')[-1].split('.')[0],) + sample
-
-    def __len__(self):
-        return len(self.samples)
-
-
-def folder2lmdb(dpath, fn_list, write_frequency=5000):
-    directory = osp.expanduser(osp.join(dpath))
-    print("Loading dataset from %s" % directory)
-    if args.extension == '.npz':
-        dataset = Folder(directory, loader=raw_npz_reader, extension='.npz',
-                         fn_list=fn_list)
-    else:
-        dataset = Folder(directory, loader=raw_npy_reader, extension='.npy',
-                         fn_list=fn_list)
-    data_loader = DataLoader(dataset, num_workers=16, collate_fn=lambda x: x)
-
-    # lmdb_path = osp.join(dpath, "%s.lmdb" % (directory.split('/')[-1]))
-    lmdb_path = osp.join("%s.lmdb" % (directory))
-    isdir = os.path.isdir(lmdb_path)
-
-    print("Generate LMDB to %s" % lmdb_path)
-    db = lmdbdict(lmdb_path, mode='w', key_method='ascii', value_method='identity')
-
-    tsvfile = open(args.output_file, 'a')
-    writer = csv.DictWriter(tsvfile, delimiter='\t', fieldnames=FIELDNAMES)
-    names = []
-    all_keys = []
-    for idx, data in enumerate(tqdm.tqdm(data_loader)):
-        # print(type(data), data)
-        name, byte, npz = data[0]
-        if npz is not None:
-            db[name] = byte
-            all_keys.append(name)
-        names.append({'image_id': name, 'status': str(npz is not None)})
-        if idx % write_frequency == 0:
-            print("[%d/%d]" % (idx, len(data_loader)))
-            print('writing')
-            db.flush()
-            # write in tsv
-            for name in names:
-                writer.writerow(name)
-            names = []
-            tsvfile.flush()
-            print('writing finished')
-    # write all keys
-    # txn.put("keys".encode(), pickle.dumps(all_keys))
-    # # finish iterating through dataset
-    # txn.commit()
-    for name in names:
-        writer.writerow(name)
-    tsvfile.flush()
-    tsvfile.close()
-
-    print("Flushing database ...")
-    db.flush()
-    del db
-
-def parse_args():
-    """
-    Parse input arguments
-    """
-    parser = argparse.ArgumentParser(description='Generate bbox output from a Fast R-CNN network')
-    # parser.add_argument('--json)
-    parser.add_argument('--input_json', default='./data/dataset_coco.json', type=str)
-    parser.add_argument('--output_file', default='.dump_cache.tsv', type=str)
-    parser.add_argument('--folder', default='./data/cocobu_att', type=str)
-    parser.add_argument('--extension', default='.npz', type=str)
-
-    args = parser.parse_args()
-    return args
-
-if __name__ == "__main__":
-    global args
-    args = parse_args()
-
-    args.output_file += args.folder.split('/')[-1]
-    if args.folder.find('/') > 0:
-        args.output_file = args.folder[:args.folder.rfind('/')+1]+args.output_file
-    print(args.output_file)
-
-    img_list = json.load(open(args.input_json, 'r'))['images']
-    fn_list = [str(_['cocoid']) for _ in img_list]
-    found_ids = set()
-    try:
-        with open(args.output_file, 'r') as tsvfile:
-            reader = csv.DictReader(tsvfile, delimiter='\t', fieldnames=FIELDNAMES)
-            for item in reader:
-                if item['status'] == 'True':
-                    found_ids.add(item['image_id'])
-    except:
-        pass
-    fn_list = [_ for _ in fn_list if _ not in found_ids]
-    folder2lmdb(args.folder, fn_list)
-
-    # Test existing.
-    found_ids = set()
-    with open(args.output_file, 'r') as tsvfile:
-        reader = csv.DictReader(tsvfile, delimiter='\t', fieldnames=FIELDNAMES)
-        for item in reader:
-            if item['status'] == 'True':
-                found_ids.add(item['image_id'])
-
-    folder_dataset = FolderLMDB(args.folder+'.lmdb', list(found_ids))
-    data_loader = DataLoader(folder_dataset, num_workers=16, collate_fn=lambda x: x)
-    for data in tqdm.tqdm(data_loader):
-        assert data[0] is not None
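Note: the self-test at the bottom of dump_to_lmdb.py doubles as the simplest way to read the resulting database back. A minimal sketch mirroring FolderLMDB above; the .lmdb path and id are placeholders:

from lmdbdict import lmdbdict
from lmdbdict.methods import DUMPS_FUNC, LOADS_FUNC
import numpy as np
import six

db = lmdbdict('data/cocobu_att.lmdb', unsafe=True)   # <folder>.lmdb, as produced by folder2lmdb
db._key_dumps = DUMPS_FUNC['ascii']
db._value_loads = LOADS_FUNC['identity']
feat = np.load(six.BytesIO(db['391895']))['feat']    # .npz entries store their array under 'feat'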
scripts/make_bu_data.py
DELETED
@@ -1,52 +0,0 @@
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-import base64
-import numpy as np
-import csv
-import sys
-import zlib
-import time
-import mmap
-import argparse
-
-parser = argparse.ArgumentParser()
-
-# output_dir
-parser.add_argument('--downloaded_feats', default='data/bu_data', help='downloaded feature directory')
-parser.add_argument('--output_dir', default='data/cocobu', help='output feature files')
-
-args = parser.parse_args()
-
-csv.field_size_limit(sys.maxsize)
-
-
-FIELDNAMES = ['image_id', 'image_w','image_h','num_boxes', 'boxes', 'features']
-infiles = ['trainval/karpathy_test_resnet101_faster_rcnn_genome.tsv',
-           'trainval/karpathy_val_resnet101_faster_rcnn_genome.tsv',\
-           'trainval/karpathy_train_resnet101_faster_rcnn_genome.tsv.0', \
-           'trainval/karpathy_train_resnet101_faster_rcnn_genome.tsv.1']
-
-os.makedirs(args.output_dir+'_att')
-os.makedirs(args.output_dir+'_fc')
-os.makedirs(args.output_dir+'_box')
-
-for infile in infiles:
-    print('Reading ' + infile)
-    with open(os.path.join(args.downloaded_feats, infile), "r") as tsv_in_file:
-        reader = csv.DictReader(tsv_in_file, delimiter='\t', fieldnames = FIELDNAMES)
-        for item in reader:
-            item['image_id'] = int(item['image_id'])
-            item['num_boxes'] = int(item['num_boxes'])
-            for field in ['boxes', 'features']:
-                item[field] = np.frombuffer(base64.decodestring(item[field].encode('ascii')),
-                                            dtype=np.float32).reshape((item['num_boxes'],-1))
-            np.savez_compressed(os.path.join(args.output_dir+'_att', str(item['image_id'])), feat=item['features'])
-            np.save(os.path.join(args.output_dir+'_fc', str(item['image_id'])), item['features'].mean(0))
-            np.save(os.path.join(args.output_dir+'_box', str(item['image_id'])), item['boxes'])
-
-
-
-
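Note: base64.decodestring, used in the deleted script above, was deprecated long ago and removed in Python 3.9. If this script were ever restored, the decoding line would need the drop-in replacement sketched below ('item' refers to the TSV row from the loop above; everything else is unchanged):

import base64
import numpy as np
encoded = item['boxes'].encode('ascii')                  # same field access as the deleted loop
boxes = np.frombuffer(base64.decodebytes(encoded),       # decodestring -> decodebytes
                      dtype=np.float32).reshape((item['num_boxes'], -1))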
scripts/prepro_feats.py
DELETED
@@ -1,103 +0,0 @@
-"""
-Preprocess a raw json dataset into features files for use in data_loader.py
-
-Input: json file that has the form
-[{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...]
-example element in this list would look like
-{'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895}
-
-This script reads this json, does some basic preprocessing on the captions
-(e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays
-
-Output: two folders of features
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-import json
-import argparse
-from random import shuffle, seed
-import string
-# non-standard dependencies:
-import h5py
-from six.moves import cPickle
-import numpy as np
-import torch
-import torchvision.models as models
-import skimage.io
-
-from torchvision import transforms as trn
-preprocess = trn.Compose([
-    #trn.ToTensor(),
-    trn.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
-])
-
-from captioning.utils.resnet_utils import myResnet
-import captioning.utils.resnet as resnet
-
-
-def main(params):
-    net = getattr(resnet, params['model'])()
-    net.load_state_dict(torch.load(os.path.join(params['model_root'],params['model']+'.pth')))
-    my_resnet = myResnet(net)
-    my_resnet.cuda()
-    my_resnet.eval()
-
-    imgs = json.load(open(params['input_json'], 'r'))
-    imgs = imgs['images']
-    N = len(imgs)
-
-    seed(123) # make reproducible
-
-    dir_fc = params['output_dir']+'_fc'
-    dir_att = params['output_dir']+'_att'
-    if not os.path.isdir(dir_fc):
-        os.mkdir(dir_fc)
-    if not os.path.isdir(dir_att):
-        os.mkdir(dir_att)
-
-    for i,img in enumerate(imgs):
-        # load the image
-        I = skimage.io.imread(os.path.join(params['images_root'], img['filepath'], img['filename']))
-        # handle grayscale input images
-        if len(I.shape) == 2:
-            I = I[:,:,np.newaxis]
-            I = np.concatenate((I,I,I), axis=2)
-
-        I = I.astype('float32')/255.0
-        I = torch.from_numpy(I.transpose([2,0,1])).cuda()
-        I = preprocess(I)
-        with torch.no_grad():
-            tmp_fc, tmp_att = my_resnet(I, params['att_size'])
-        # write to pkl
-        # print(dir_fc, str(img['cocoid']), tmp_fc.shape, tmp_att.shape, dir_att)
-        # exit()
-        np.save(os.path.join(dir_fc, str(img['cocoid'])), tmp_fc.data.cpu().float().numpy())
-        np.savez_compressed(os.path.join(dir_att, str(img['cocoid'])), feat=tmp_att.data.cpu().float().numpy())
-
-        if i % 1000 == 0:
-            print('processing %d/%d (%.2f%% done)' % (i, N, i*100.0/N))
-    print('wrote ', params['output_dir'])
-
-if __name__ == "__main__":
-
-    parser = argparse.ArgumentParser()
-
-    # input json
-    parser.add_argument('--input_json', required=True, help='input json file to process into hdf5')
-    parser.add_argument('--output_dir', default='data', help='output h5 file')
-
-    # options
-    parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json')
-    parser.add_argument('--att_size', default=14, type=int, help='14x14 or 7x7')
-    parser.add_argument('--model', default='resnet101', type=str, help='resnet101, resnet152')
-    parser.add_argument('--model_root', default='./data/imagenet_weights', type=str, help='model root')
-
-    args = parser.parse_args()
-    params = vars(args) # convert to ordinary dict
-    print('parsed input parameters:')
-    print(json.dumps(params, indent = 2))
-    main(params)
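Note: prepro_feats.py saved one pooled feature per image under <output_dir>_fc/<cocoid>.npy and one spatial grid under <output_dir>_att/<cocoid>.npz (key 'feat'). A minimal read-back sketch; paths, id, and the 2048-d shape (assumed for the default resnet101 backbone) are not guaranteed by the diff:

import numpy as np
fc = np.load('data/cocotalk_fc/391895.npy')             # assumed shape (2048,) for resnet101
att = np.load('data/cocotalk_att/391895.npz')['feat']   # assumed shape (att_size, att_size, 2048)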
scripts/prepro_labels.py
DELETED
@@ -1,206 +0,0 @@
-"""
-Preprocess a raw json dataset into hdf5/json files for use in data_loader.py
-
-Input: json file that has the form
-[{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...]
-example element in this list would look like
-{'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895}
-
-This script reads this json, does some basic preprocessing on the captions
-(e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays
-
-Output: a json file and an hdf5 file
-The hdf5 file contains several fields:
-/labels is (M,max_length) uint32 array of encoded labels, zero padded
-/label_start_ix and /label_end_ix are (N,) uint32 arrays of pointers to the
-  first and last indices (in range 1..M) of labels for each image
-/label_length stores the length of the sequence for each of the M sequences
-
-The json file has a dict that contains:
-- an 'ix_to_word' field storing the vocab in form {ix:'word'}, where ix is 1-indexed
-- an 'images' field that is a list holding auxiliary information for each image,
-  such as in particular the 'split' it was assigned to.
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-import json
-import argparse
-from random import shuffle, seed
-import string
-# non-standard dependencies:
-import h5py
-import numpy as np
-import torch
-import torchvision.models as models
-import skimage.io
-from PIL import Image
-
-
-def build_vocab(imgs, params):
-    count_thr = params['word_count_threshold']
-
-    # count up the number of words
-    counts = {}
-    for img in imgs:
-        for sent in img['sentences']:
-            for w in sent['tokens']:
-                counts[w] = counts.get(w, 0) + 1
-    cw = sorted([(count,w) for w,count in counts.items()], reverse=True)
-    print('top words and their counts:')
-    print('\n'.join(map(str,cw[:20])))
-
-    # print some stats
-    total_words = sum(counts.values())
-    print('total words:', total_words)
-    bad_words = [w for w,n in counts.items() if n <= count_thr]
-    vocab = [w for w,n in counts.items() if n > count_thr]
-    bad_count = sum(counts[w] for w in bad_words)
-    print('number of bad words: %d/%d = %.2f%%' % (len(bad_words), len(counts), len(bad_words)*100.0/len(counts)))
-    print('number of words in vocab would be %d' % (len(vocab), ))
-    print('number of UNKs: %d/%d = %.2f%%' % (bad_count, total_words, bad_count*100.0/total_words))
-
-    # lets look at the distribution of lengths as well
-    sent_lengths = {}
-    for img in imgs:
-        for sent in img['sentences']:
-            txt = sent['tokens']
-            nw = len(txt)
-            sent_lengths[nw] = sent_lengths.get(nw, 0) + 1
-    max_len = max(sent_lengths.keys())
-    print('max length sentence in raw data: ', max_len)
-    print('sentence length distribution (count, number of words):')
-    sum_len = sum(sent_lengths.values())
-    for i in range(max_len+1):
-        print('%2d: %10d %f%%' % (i, sent_lengths.get(i,0), sent_lengths.get(i,0)*100.0/sum_len))
-
-    # lets now produce the final annotations
-    if bad_count > 0:
-        # additional special UNK token we will use below to map infrequent words to
-        print('inserting the special UNK token')
-        vocab.append('UNK')
-
-    for img in imgs:
-        img['final_captions'] = []
-        for sent in img['sentences']:
-            txt = sent['tokens']
-            caption = [w if counts.get(w,0) > count_thr else 'UNK' for w in txt]
-            img['final_captions'].append(caption)
-
-    return vocab
-
-
-def encode_captions(imgs, params, wtoi):
-    """
-    encode all captions into one large array, which will be 1-indexed.
-    also produces label_start_ix and label_end_ix which store 1-indexed
-    and inclusive (Lua-style) pointers to the first and last caption for
-    each image in the dataset.
-    """
-
-    max_length = params['max_length']
-    N = len(imgs)
-    M = sum(len(img['final_captions']) for img in imgs) # total number of captions
-
-    label_arrays = []
-    label_start_ix = np.zeros(N, dtype='uint32') # note: these will be one-indexed
-    label_end_ix = np.zeros(N, dtype='uint32')
-    label_length = np.zeros(M, dtype='uint32')
-    caption_counter = 0
-    counter = 1
-    for i,img in enumerate(imgs):
-        n = len(img['final_captions'])
-        assert n > 0, 'error: some image has no captions'
-
-        Li = np.zeros((n, max_length), dtype='uint32')
-        for j,s in enumerate(img['final_captions']):
-            label_length[caption_counter] = min(max_length, len(s)) # record the length of this sequence
-            caption_counter += 1
-            for k,w in enumerate(s):
-                if k < max_length:
-                    Li[j,k] = wtoi[w]
-
-        # note: word indices are 1-indexed, and captions are padded with zeros
-        label_arrays.append(Li)
-        label_start_ix[i] = counter
-        label_end_ix[i] = counter + n - 1
-
-        counter += n
-
-    L = np.concatenate(label_arrays, axis=0) # put all the labels together
-    assert L.shape[0] == M, 'lengths don\'t match? that\'s weird'
-    assert np.all(label_length > 0), 'error: some caption had no words?'
-
-    print('encoded captions to array of size ', L.shape)
-    return L, label_start_ix, label_end_ix, label_length
-
-
-def main(params):
-
-    imgs = json.load(open(params['input_json'], 'r'))
-    imgs = imgs['images']
-
-    seed(123) # make reproducible
-
-    # create the vocab
-    vocab = build_vocab(imgs, params)
-    itow = {i+1:w for i,w in enumerate(vocab)} # a 1-indexed vocab translation table
-    wtoi = {w:i+1 for i,w in enumerate(vocab)} # inverse table
-
-    # encode captions in large arrays, ready to ship to hdf5 file
-    L, label_start_ix, label_end_ix, label_length = encode_captions(imgs, params, wtoi)
-
-    # create output h5 file
-    N = len(imgs)
-    f_lb = h5py.File(params['output_h5']+'_label.h5', "w")
-    f_lb.create_dataset("labels", dtype='uint32', data=L)
-    f_lb.create_dataset("label_start_ix", dtype='uint32', data=label_start_ix)
-    f_lb.create_dataset("label_end_ix", dtype='uint32', data=label_end_ix)
-    f_lb.create_dataset("label_length", dtype='uint32', data=label_length)
-    f_lb.close()
-
-    # create output json file
-    out = {}
-    out['ix_to_word'] = itow # encode the (1-indexed) vocab
-    out['images'] = []
-    for i,img in enumerate(imgs):
-
-        jimg = {}
-        jimg['split'] = img['split']
-        if 'filename' in img: jimg['file_path'] = os.path.join(img.get('filepath', ''), img['filename']) # copy it over, might need
-        if 'cocoid' in img:
-            jimg['id'] = img['cocoid'] # copy over & mantain an id, if present (e.g. coco ids, useful)
-        elif 'imgid' in img:
-            jimg['id'] = img['imgid']
-
-        if params['images_root'] != '':
-            with Image.open(os.path.join(params['images_root'], img['filepath'], img['filename'])) as _img:
-                jimg['width'], jimg['height'] = _img.size
-
-        out['images'].append(jimg)
-
-    json.dump(out, open(params['output_json'], 'w'))
-    print('wrote ', params['output_json'])
-
-if __name__ == "__main__":
-
-    parser = argparse.ArgumentParser()
-
-    # input json
-    parser.add_argument('--input_json', required=True, help='input json file to process into hdf5')
-    parser.add_argument('--output_json', default='data.json', help='output json file')
-    parser.add_argument('--output_h5', default='data', help='output h5 file')
-    parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json')
-
-    # options
-    parser.add_argument('--max_length', default=16, type=int, help='max length of a caption, in number of words. captions longer than this get clipped.')
-    parser.add_argument('--word_count_threshold', default=5, type=int, help='only words that occur more than this number of times will be put in vocab')
-
-    args = parser.parse_args()
-    params = vars(args) # convert to ordinary dict
-    print('parsed input parameters:')
-    print(json.dumps(params, indent = 2))
-    main(params)
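Note: the label pointers written above are 1-indexed and inclusive (Lua-style), so reading the captions of image i back out of the h5 file needs the offset shown below; a minimal sketch, with the file name a placeholder for <output_h5>_label.h5 and i the image index:

import h5py
with h5py.File('data/cocotalk_label.h5', 'r') as f:
    start = f['label_start_ix'][i] - 1        # convert 1-indexed pointer to 0-indexed slice start
    end = f['label_end_ix'][i]                # inclusive end pointer, so no -1 here
    caps = f['labels'][start:end]             # (n_captions_of_image_i, max_length), zero padded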
scripts/prepro_ngrams.py
DELETED
@@ -1,94 +0,0 @@
-"""
-Precompute ngram counts of captions, to accelerate cider computation during training time.
-"""
-
-import os
-import json
-import argparse
-from six.moves import cPickle
-import captioning.utils.misc as utils
-from collections import defaultdict
-
-import sys
-sys.path.append("cider")
-from pyciderevalcap.ciderD.ciderD_scorer import CiderScorer
-
-
-def get_doc_freq(refs, params):
-    tmp = CiderScorer(df_mode="corpus")
-    for ref in refs:
-        tmp.cook_append(None, ref)
-    tmp.compute_doc_freq()
-    return tmp.document_frequency, len(tmp.crefs)
-
-
-def build_dict(imgs, wtoi, params):
-    wtoi['<eos>'] = 0
-
-    count_imgs = 0
-
-    refs_words = []
-    refs_idxs = []
-    for img in imgs:
-        if (params['split'] == img['split']) or \
-            (params['split'] == 'train' and img['split'] == 'restval') or \
-            (params['split'] == 'all'):
-            #(params['split'] == 'val' and img['split'] == 'restval') or \
-            ref_words = []
-            ref_idxs = []
-            for sent in img['sentences']:
-                if hasattr(params, 'bpe'):
-                    sent['tokens'] = params.bpe.segment(' '.join(sent['tokens'])).strip().split(' ')
-                tmp_tokens = sent['tokens'] + ['<eos>']
-                tmp_tokens = [_ if _ in wtoi else 'UNK' for _ in tmp_tokens]
-                ref_words.append(' '.join(tmp_tokens))
-                ref_idxs.append(' '.join([str(wtoi[_]) for _ in tmp_tokens]))
-            refs_words.append(ref_words)
-            refs_idxs.append(ref_idxs)
-            count_imgs += 1
-    print('total imgs:', count_imgs)
-
-    ngram_words, count_refs = get_doc_freq(refs_words, params)
-    ngram_idxs, count_refs = get_doc_freq(refs_idxs, params)
-    print('count_refs:', count_refs)
-    return ngram_words, ngram_idxs, count_refs
-
-def main(params):
-
-    imgs = json.load(open(params['input_json'], 'r'))
-    dict_json = json.load(open(params['dict_json'], 'r'))
-    itow = dict_json['ix_to_word']
-    wtoi = {w:i for i,w in itow.items()}
-
-    # Load bpe
-    if 'bpe' in dict_json:
-        import tempfile
-        import codecs
-        codes_f = tempfile.NamedTemporaryFile(delete=False)
-        codes_f.close()
-        with open(codes_f.name, 'w') as f:
-            f.write(dict_json['bpe'])
-        with codecs.open(codes_f.name, encoding='UTF-8') as codes:
-            bpe = apply_bpe.BPE(codes)
-        params.bpe = bpe
-
-    imgs = imgs['images']
-
-    ngram_words, ngram_idxs, ref_len = build_dict(imgs, wtoi, params)
-
-    utils.pickle_dump({'document_frequency': ngram_words, 'ref_len': ref_len}, open(params['output_pkl']+'-words.p','wb'))
-    utils.pickle_dump({'document_frequency': ngram_idxs, 'ref_len': ref_len}, open(params['output_pkl']+'-idxs.p','wb'))
-
-if __name__ == "__main__":
-
-    parser = argparse.ArgumentParser()
-
-    # input json
-    parser.add_argument('--input_json', default='data/dataset_coco.json', help='input json file to process into hdf5')
-    parser.add_argument('--dict_json', default='data/cocotalk.json', help='output json file')
-    parser.add_argument('--output_pkl', default='data/coco-all', help='output pickle file')
-    parser.add_argument('--split', default='all', help='test, val, train, all')
-    args = parser.parse_args()
-    params = vars(args) # convert to ordinary dict
-
-    main(params)
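Note: the BPE branch in the deleted file references apply_bpe without ever importing it, so that path would fail at runtime. It presumably relied on the subword-nmt package, as the repo's scripts/build_bpe_subword_nmt.py does; that is an assumption, and the missing import would look like this:

from subword_nmt import apply_bpe      # assumed dependency; not imported in the deleted file
bpe = apply_bpe.BPE(codes)             # 'codes' is the codecs file handle opened just above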
scripts/prepro_reference_json.py
DELETED
@@ -1,69 +0,0 @@
-# coding: utf-8
-"""
-Create a reference json file used for evaluation with `coco-caption` repo.
-Used when reference json is not provided, (e.g., flickr30k, or you have your own split of train/val/test)
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-import json
-import argparse
-import sys
-import hashlib
-from random import shuffle, seed
-
-
-def main(params):
-
-    imgs = json.load(open(params['input_json'][0], 'r'))['images']
-    # tmp = []
-    # for k in imgs.keys():
-    #     for img in imgs[k]:
-    #         img['filename'] = img['image_id'] # k+'/'+img['image_id']
-    #         img['image_id'] = int(
-    #             int(hashlib.sha256(img['image_id']).hexdigest(), 16) % sys.maxint)
-    #         tmp.append(img)
-    # imgs = tmp
-
-    # create output json file
-    out = {'info': {'description': 'This is stable 1.0 version of the 2014 MS COCO dataset.', 'url': 'http://mscoco.org', 'version': '1.0', 'year': 2014, 'contributor': 'Microsoft COCO group', 'date_created': '2015-01-27 09:11:52.357475'}, 'licenses': [{'url': 'http://creativecommons.org/licenses/by-nc-sa/2.0/', 'id': 1, 'name': 'Attribution-NonCommercial-ShareAlike License'}, {'url': 'http://creativecommons.org/licenses/by-nc/2.0/', 'id': 2, 'name': 'Attribution-NonCommercial License'}, {'url': 'http://creativecommons.org/licenses/by-nc-nd/2.0/', 'id': 3, 'name': 'Attribution-NonCommercial-NoDerivs License'}, {'url': 'http://creativecommons.org/licenses/by/2.0/', 'id': 4, 'name': 'Attribution License'}, {'url': 'http://creativecommons.org/licenses/by-sa/2.0/', 'id': 5, 'name': 'Attribution-ShareAlike License'}, {'url': 'http://creativecommons.org/licenses/by-nd/2.0/', 'id': 6, 'name': 'Attribution-NoDerivs License'}, {'url': 'http://flickr.com/commons/usage/', 'id': 7, 'name': 'No known copyright restrictions'}, {'url': 'http://www.usa.gov/copyright.shtml', 'id': 8, 'name': 'United States Government Work'}], 'type': 'captions'}
-    out.update({'images': [], 'annotations': []})
-
-    cnt = 0
-    empty_cnt = 0
-    for i, img in enumerate(imgs):
-        if img['split'] == 'train':
-            continue
-        out['images'].append(
-            {'id': img.get('cocoid', img['imgid'])})
-        for j, s in enumerate(img['sentences']):
-            if len(s) == 0:
-                continue
-            s = ' '.join(s['tokens'])
-            out['annotations'].append(
-                {'image_id': out['images'][-1]['id'], 'caption': s, 'id': cnt})
-            cnt += 1
-
-    json.dump(out, open(params['output_json'], 'w'))
-    print('wrote ', params['output_json'])
-
-
-if __name__ == "__main__":
-
-    parser = argparse.ArgumentParser()
-
-    # input json
-    parser.add_argument('--input_json', nargs='+', required=True,
-                        help='input json file to process into hdf5')
-    parser.add_argument('--output_json', default='data.json',
-                        help='output json file')
-
-    args = parser.parse_args()
-    params = vars(args) # convert to ordinary dict
-    print('parsed input parameters:')
-    print(json.dumps(params, indent=2))
-    main(params)
-
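Note: the reference file produced above follows the COCO annotation layout that coco-caption expects. A minimal example of the shape of one record, with placeholder values:

ref = {
    'type': 'captions',
    'images': [{'id': 391895}],
    'annotations': [{'image_id': 391895, 'caption': 'a man riding a motorcycle', 'id': 0}],
}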
scripts_FineCapEval/clip_prepro_feats.py
DELETED
@@ -1,163 +0,0 @@
-"""
-Preprocess a raw json dataset into features files for use in data_loader.py
-
-Input: json file that has the form
-[{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...]
-example element in this list would look like
-{'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895}
-
-This script reads this json, does some basic preprocessing on the captions
-(e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays
-
-Output: two folders of features
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-import json
-import argparse
-from random import shuffle, seed
-import string
-# non-standard dependencies:
-import h5py
-from six.moves import cPickle
-import numpy as np
-import torch
-import torchvision.models as models
-import skimage.io
-
-from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
-from PIL import Image
-from torch import nn
-
-preprocess = Compose([
-    Resize((448, 448), interpolation=Image.BICUBIC),
-    CenterCrop((448, 448)),
-    ToTensor()
-])
-
-
-from clip.clip import load
-from timm.models.vision_transformer import resize_pos_embed
-import timm
-
-from captioning.utils.resnet_utils import myResnet
-import captioning.utils.resnet as resnet
-
-from tqdm import tqdm
-
-
-def main(params):
-    if params["model_type"] != 'vit_base_patch32_224_in21k':
-        model, transform = load(params["model_type"], jit=False)
-    else:
-        model = timm.create_model(params["model_type"], pretrained=True)
-        model = model.cuda()
-
-    if params["model_type"] != 'vit_base_patch32_224_in21k':
-        save_model_type = params["model_type"].split("-")[0]
-        mean = torch.Tensor([0.48145466, 0.4578275, 0.40821073]).to("cuda").reshape(3, 1, 1)
-        std = torch.Tensor([0.26862954, 0.26130258, 0.27577711]).to("cuda").reshape(3, 1, 1)
-
-        if "RN" in params["model_type"]:
-            num_patches = 196 #600 * 1000 // 32 // 32
-            pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, model.visual.attnpool.positional_embedding.shape[-1], device='cuda'),)
-            pos_embed.weight = resize_pos_embed(model.visual.attnpool.positional_embedding.unsqueeze(0), pos_embed)
-            model.visual.attnpool.positional_embedding = pos_embed
-
-    else:
-        save_model_type = 'vit_base'
-        mean = torch.Tensor([0.5, 0.5, 0.5]).to("cuda").reshape(3, 1, 1)
-        std = torch.Tensor([0.5, 0.5, 0.5]).to("cuda").reshape(3, 1, 1)
-
-        num_patches = 196 #600 * 1000 // 32 // 32
-        pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, 768, device='cuda'),)
-        pos_embed.weight = resize_pos_embed(model.pos_embed, pos_embed)
-        model.pos_embed = pos_embed
-
-    if params["model_type"] == "ViT-B/32":
-        num_patches = 196 #600 * 1000 // 32 // 32
-        pos_embed = nn.Parameter(torch.zeros(num_patches + 1, 768, device='cuda'),)
-        pos_embed.weight = resize_pos_embed(model.visual.positional_embedding.unsqueeze(0), pos_embed.unsqueeze(0))
-        model.visual.positional_embedding = pos_embed
-    imgs = json.load(open(params['input_json'], 'r'))
-    imgs = imgs['images']
-    N = len(imgs)
-
-    seed(123) # make reproducible
-
-    dir_fc = params['output_dir']+'_clip_'+save_model_type+'_fc'
-    dir_att = params['output_dir']+'_clip_'+save_model_type+'_att'
-    if not os.path.isdir(dir_fc):
-        os.mkdir(dir_fc)
-    if not os.path.isdir(dir_att):
-        os.mkdir(dir_att)
-
-    for i, img in enumerate(tqdm(imgs)):
-        with torch.no_grad():
-
-            # img_path = os.path.join(params['images_root'], img['filepath'], img['filename'])
-            # img_path = os.path.join(params['images_root'], img['file_name'])
-
-            img_path = os.path.join(params['images_root'], img['file_path'])
-
-            image = preprocess(Image.open( img_path ).convert("RGB"))
-            image = torch.tensor(np.stack([image])).cuda()
-            image -= mean
-            image /= std
-            if "RN" in params["model_type"]:
-                tmp_att, tmp_fc = model.encode_image(image)
-                tmp_att = tmp_att[0].permute(1, 2, 0)
-                tmp_fc = tmp_fc[0]
-            elif params["model_type"] == 'vit_base_patch32_224_in21k':
-                x = model(image)
-                tmp_fc = x[0, 0, :]
-                tmp_att = x[0, 1:, :].reshape( 14, 14, 768 )
-            else:
-                x = model.visual.conv1(image.half()) # shape = [*, width, grid, grid]
-                x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
-                x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
-                x = torch.cat([model.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
-                x = x + model.visual.positional_embedding.to(x.dtype)[:x.shape[1], :]
-                x = model.visual.ln_pre(x)
-
-                x = x.permute(1, 0, 2) # NLD -> LND
-
-                for layer_idx, layer in enumerate(model.visual.transformer.resblocks):
-                    x = layer(x)
-
-                x = x.permute(1, 0, 2)
-                tmp_fc = x[0, 0, :]
-                tmp_att = x[0, 1:, :].reshape( 14, 14, 768 )
-
-            # np.save(os.path.join(dir_fc, str(img['cocoid'])), tmp_fc.data.cpu().float().numpy())
-            # np.savez_compressed(os.path.join(dir_att, str(img['cocoid'])), feat=tmp_att.data.cpu().float().numpy())
-            np.save(os.path.join(dir_fc, str(img['id'])), tmp_fc.data.cpu().float().numpy())
-            np.savez_compressed(os.path.join(dir_att, str(img['id'])), feat=tmp_att.data.cpu().float().numpy())
-
-
-    # if i % 1000 == 0:
-    #     print('processing %d/%d (%.2f%% done)' % (i, N, i*100.0/N))
-    print('wrote ', dir_fc, dir_att)
-
-if __name__ == "__main__":
-
-    parser = argparse.ArgumentParser()
-
-    # input json
-    parser.add_argument('--input_json', required=True, help='input json file to process into hdf5')
-    parser.add_argument('--output_dir', default='data', help='output h5 file')
-
-    # options
-    parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json')
-    parser.add_argument('--att_size', default=14, type=int, help='14x14 or 7x7')
-    parser.add_argument('--model_type', default='RN50', type=str, help='RN50, RN101, RN50x4, ViT-B/32, vit_base_patch32_224_in21k')
-
-    args = parser.parse_args()
-    params = vars(args) # convert to ordinary dict
-    print('parsed input parameters:')
-    print(json.dumps(params, indent = 2))
-    main(params)
scripts_FineCapEval/clipscore_prepro_feats.py
DELETED
@@ -1,154 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
Preprocess a raw json dataset into features files for use in data_loader.py
|
3 |
-
|
4 |
-
Input: json file that has the form
|
5 |
-
[{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...]
|
6 |
-
example element in this list would look like
|
7 |
-
{'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895}
|
8 |
-
|
9 |
-
This script reads this json, does some basic preprocessing on the captions
|
10 |
-
(e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays
|
11 |
-
|
12 |
-
Output: two folders of features
|
13 |
-
"""
|
14 |
-
|
15 |
-
from __future__ import absolute_import
|
16 |
-
from __future__ import division
|
17 |
-
from __future__ import print_function
|
18 |
-
|
19 |
-
import os
|
20 |
-
import json
import argparse
from random import shuffle, seed
import string
# non-standard dependencies:
import h5py
from six.moves import cPickle
import numpy as np
import torch
import torchvision.models as models
import skimage.io

from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from PIL import Image
from torch import nn

# preprocess = Compose([
#     Resize((448, 448), interpolation=Image.BICUBIC),
#     CenterCrop((448, 448)),
#     ToTensor()
# ])

# from clip.clip import load
# from timm.models.vision_transformer import resize_pos_embed
# import timm

# from captioning.utils.resnet_utils import myResnet
# import captioning.utils.resnet as resnet

from captioning.utils.clipscore import CLIPScore

from tqdm import tqdm


def main(params):

    clipscore_model = CLIPScore()
    clipscore_model.to('cuda')

    imgs = json.load(open(params['input_json'], 'r'))
    imgs = imgs['images']
    N = len(imgs)

    seed(123)  # make reproducible

    # dir_fc = params['output_dir']+'_clip_'+save_model_type+'_fc'
    # dir_att = params['output_dir']+'_clip_'+save_model_type+'_att'

    vis_dir_fc = params['output_dir']+'_clipscore_vis'
    if not os.path.isdir(vis_dir_fc):
        os.mkdir(vis_dir_fc)

    # text_dir_fc = params['output_dir']+'_clipscore_text'
    # if not os.path.isdir(text_dir_fc):
    #     os.mkdir(text_dir_fc)

    # if not os.path.isdir(dir_att):
    #     os.mkdir(dir_att)

    for i, img in enumerate(tqdm(imgs)):
        # load the image

        # img_path = os.path.join(params['images_root'], img['filepath'], img['filename'])
        # img_path = os.path.join(params['images_root'], img['file_name'])
        img_path = os.path.join(params['images_root'], img['file_path'])

        img_feat = clipscore_model.image_extract(img_path)
        img_feat = img_feat.view(512)

        # for d in img['sentences']:
        #     text = d['raw'].strip()
        #     text_feat = clipscore_model.text_extract(text)

        # with torch.no_grad():
        #     image = preprocess(Image.open(os.path.join(params['images_root'], img['filepath'], img['filename']) ).convert("RGB"))
        #     image = torch.tensor(np.stack([image])).cuda()
        #     image -= mean
        #     image /= std
        #     if "RN" in params["model_type"]:
        #         tmp_att, tmp_fc = model.encode_image(image)
        #         tmp_att = tmp_att[0].permute(1, 2, 0)
        #         tmp_fc = tmp_fc[0]
        #     elif params["model_type"] == 'vit_base_patch32_224_in21k':
        #         x = model(image)
        #         tmp_fc = x[0, 0, :]
        #         tmp_att = x[0, 1:, :].reshape( 14, 14, 768 )
        #     else:
        #         x = model.visual.conv1(image.half())  # shape = [*, width, grid, grid]
        #         x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        #         x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        #         x = torch.cat([model.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        #         x = x + model.visual.positional_embedding.to(x.dtype)[:x.shape[1], :]
        #         x = model.visual.ln_pre(x)
        #         x = x.permute(1, 0, 2)  # NLD -> LND
        #         for layer_idx, layer in enumerate(model.visual.transformer.resblocks):
        #             x = layer(x)
        #         x = x.permute(1, 0, 2)
        #         tmp_fc = x[0, 0, :]
        #         tmp_att = x[0, 1:, :].reshape( 14, 14, 768 )

        np.save(os.path.join(vis_dir_fc, str(img['id'])), img_feat.data.cpu().float().numpy())
        # np.save(os.path.join(text_dir_fc, str(img['cocoid'])), tmp_fc.data.cpu().float().numpy())

        # np.savez_compressed(os.path.join(dir_att, str(img['cocoid'])), feat=tmp_att.data.cpu().float().numpy())

        # if i % 1000 == 0:
        #     print('processing %d/%d (%.2f%% done)' % (i, N, i*100.0/N))
    print('wrote ', vis_dir_fc)

if __name__ == "__main__":

    parser = argparse.ArgumentParser()

    # input json
    # dataset_coco.json
    parser.add_argument('--input_json', required=True, help='input json file to process into hdf5')
    parser.add_argument('--output_dir', default='data', help='output h5 file')

    # options
    parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json')
    # parser.add_argument('--att_size', default=14, type=int, help='14x14 or 7x7')
    # parser.add_argument('--model_type', default='RN50', type=str, help='RN50, RN101, RN50x4, ViT-B/32, vit_base_patch32_224_in21k')

    args = parser.parse_args()
    params = vars(args)  # convert to ordinary dict
    print('parsed input parameters:')
    print(json.dumps(params, indent=2))
    main(params)
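Editor's note (not part of the deleted file): the loop above writes one 512-dimensional CLIP image feature per image to <output_dir>_clipscore_vis/<image id>.npy. A minimal sketch of reading one back, assuming the default --output_dir data and a made-up image id 391895:

import numpy as np

# file name and id are placeholders; the array is the img_feat.view(512) tensor saved above
feat = np.load('data_clipscore_vis/391895.npy')
print(feat.shape, feat.dtype)  # (512,) float32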
scripts_FineCapEval/prepro_labels.py
DELETED
@@ -1,209 +0,0 @@
"""
Preprocess a raw json dataset into hdf5/json files for use in data_loader.py

Input: json file that has the form
[{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...]
example element in this list would look like
{'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895}

This script reads this json, does some basic preprocessing on the captions
(e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays

Output: a json file and an hdf5 file
The hdf5 file contains several fields:
/labels is (M,max_length) uint32 array of encoded labels, zero padded
/label_start_ix and /label_end_ix are (N,) uint32 arrays of pointers to the
  first and last indices (in range 1..M) of labels for each image
/label_length stores the length of the sequence for each of the M sequences

The json file has a dict that contains:
- an 'ix_to_word' field storing the vocab in form {ix:'word'}, where ix is 1-indexed
- an 'images' field that is a list holding auxiliary information for each image,
  such as in particular the 'split' it was assigned to.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import json
import argparse
from random import shuffle, seed
import string
# non-standard dependencies:
import h5py
import numpy as np
import torch
import torchvision.models as models
import skimage.io
from PIL import Image


def build_vocab(imgs, params):
    count_thr = params['word_count_threshold']

    # count up the number of words
    counts = {}
    for img in imgs:
        for sent in img['sentences']:
            for w in sent['tokens']:
                counts[w] = counts.get(w, 0) + 1
    cw = sorted([(count, w) for w, count in counts.items()], reverse=True)
    print('top words and their counts:')
    print('\n'.join(map(str, cw[:20])))

    # print some stats
    total_words = sum(counts.values())
    print('total words:', total_words)
    bad_words = [w for w, n in counts.items() if n <= count_thr]
    vocab = [w for w, n in counts.items() if n > count_thr]
    bad_count = sum(counts[w] for w in bad_words)
    print('number of bad words: %d/%d = %.2f%%' % (len(bad_words), len(counts), len(bad_words)*100.0/len(counts)))
    print('number of words in vocab would be %d' % (len(vocab), ))
    print('number of UNKs: %d/%d = %.2f%%' % (bad_count, total_words, bad_count*100.0/total_words))

    # lets look at the distribution of lengths as well
    sent_lengths = {}
    for img in imgs:
        for sent in img['sentences']:
            txt = sent['tokens']
            nw = len(txt)
            sent_lengths[nw] = sent_lengths.get(nw, 0) + 1
    max_len = max(sent_lengths.keys())
    print('max length sentence in raw data: ', max_len)
    print('sentence length distribution (count, number of words):')
    sum_len = sum(sent_lengths.values())
    for i in range(max_len+1):
        print('%2d: %10d %f%%' % (i, sent_lengths.get(i, 0), sent_lengths.get(i, 0)*100.0/sum_len))

    # lets now produce the final annotations
    if bad_count > 0:
        # additional special UNK token we will use below to map infrequent words to
        print('inserting the special UNK token')
        vocab.append('UNK')

    for img in imgs:
        img['final_captions'] = []
        for sent in img['sentences']:
            txt = sent['tokens']
            caption = [w if counts.get(w, 0) > count_thr else 'UNK' for w in txt]
            img['final_captions'].append(caption)

    return vocab


def encode_captions(imgs, params, wtoi):
    """
    encode all captions into one large array, which will be 1-indexed.
    also produces label_start_ix and label_end_ix which store 1-indexed
    and inclusive (Lua-style) pointers to the first and last caption for
    each image in the dataset.
    """

    max_length = params['max_length']
    N = len(imgs)
    M = sum(len(img['final_captions']) for img in imgs)  # total number of captions

    label_arrays = []
    label_start_ix = np.zeros(N, dtype='uint32')  # note: these will be one-indexed
    label_end_ix = np.zeros(N, dtype='uint32')
    label_length = np.zeros(M, dtype='uint32')
    caption_counter = 0
    counter = 1
    for i, img in enumerate(imgs):
        n = len(img['final_captions'])
        assert n > 0, 'error: some image has no captions'

        Li = np.zeros((n, max_length), dtype='uint32')
        for j, s in enumerate(img['final_captions']):
            label_length[caption_counter] = min(max_length, len(s))  # record the length of this sequence
            caption_counter += 1
            for k, w in enumerate(s):
                if k < max_length:
                    Li[j, k] = wtoi[w]

        # note: word indices are 1-indexed, and captions are padded with zeros
        label_arrays.append(Li)
        label_start_ix[i] = counter
        label_end_ix[i] = counter + n - 1

        counter += n

    L = np.concatenate(label_arrays, axis=0)  # put all the labels together
    assert L.shape[0] == M, 'lengths don\'t match? that\'s weird'
    assert np.all(label_length > 0), 'error: some caption had no words?'

    print('encoded captions to array of size ', L.shape)
    return L, label_start_ix, label_end_ix, label_length


def main(params):

    imgs = json.load(open(params['input_json'], 'r'))
    imgs = imgs['images']

    seed(123)  # make reproducible

    # # create the vocab
    # vocab = build_vocab(imgs, params)
    # itow = {i+1:w for i,w in enumerate(vocab)}  # a 1-indexed vocab translation table
    # wtoi = {w:i+1 for i,w in enumerate(vocab)}  # inverse table

    itow = imgs['ix_to_word']
    wtoi = {w: i for i, w in itow.items()}

    # encode captions in large arrays, ready to ship to hdf5 file
    L, label_start_ix, label_end_ix, label_length = encode_captions(imgs, params, wtoi)

    # create output h5 file
    N = len(imgs)
    f_lb = h5py.File(params['output_h5']+'_label.h5', "w")
    f_lb.create_dataset("labels", dtype='uint32', data=L)
    f_lb.create_dataset("label_start_ix", dtype='uint32', data=label_start_ix)
    f_lb.create_dataset("label_end_ix", dtype='uint32', data=label_end_ix)
    f_lb.create_dataset("label_length", dtype='uint32', data=label_length)
    f_lb.close()

    # create output json file
    out = {}
    out['ix_to_word'] = itow  # encode the (1-indexed) vocab
    out['images'] = []
    for i, img in enumerate(imgs):

        jimg = {}
        jimg['split'] = img['split']
        if 'filename' in img: jimg['file_path'] = os.path.join(img.get('filepath', ''), img['filename'])  # copy it over, might need
        if 'cocoid' in img:
            jimg['id'] = img['cocoid']  # copy over & mantain an id, if present (e.g. coco ids, useful)
        elif 'imgid' in img:
            jimg['id'] = img['imgid']

        if params['images_root'] != '':
            with Image.open(os.path.join(params['images_root'], img['filepath'], img['filename'])) as _img:
                jimg['width'], jimg['height'] = _img.size

        out['images'].append(jimg)

    json.dump(out, open(params['output_json'], 'w'))
    print('wrote ', params['output_json'])

if __name__ == "__main__":

    parser = argparse.ArgumentParser()

    # input json
    parser.add_argument('--input_json', required=True, help='input json file to process into hdf5')
    parser.add_argument('--output_json', default='data.json', help='output json file')
    parser.add_argument('--output_h5', default='data', help='output h5 file')
    parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json')

    # options
    parser.add_argument('--max_length', default=16, type=int, help='max length of a caption, in number of words. captions longer than this get clipped.')
    parser.add_argument('--word_count_threshold', default=5, type=int, help='only words that occur more than this number of times will be put in vocab')

    args = parser.parse_args()
    params = vars(args)  # convert to ordinary dict
    print('parsed input parameters:')
    print(json.dumps(params, indent=2))
    main(params)
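Editor's note (not part of the deleted file): the docstring above fixes the hdf5 layout, with /label_start_ix and /label_end_ix holding 1-indexed, inclusive pointers into /labels. A minimal sketch of slicing out the captions of one image, assuming the default --output_h5 data (so the file is data_label.h5):

import h5py

i = 0  # index into the 'images' list of the output json (assumption for illustration)
with h5py.File('data_label.h5', 'r') as f:
    start = int(f['label_start_ix'][i]) - 1   # convert the 1-indexed, inclusive pointers
    end = int(f['label_end_ix'][i])           # into a 0-indexed Python slice
    caps = f['labels'][start:end]             # (n_captions_of_image_i, max_length), zero padded
    lens = f['label_length'][start:end]       # true (clipped) length of each caption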
tools/eval.py
DELETED
@@ -1,125 +0,0 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import numpy as np

import time
import os
from six.moves import cPickle

import captioning.utils.opts as opts
import captioning.models as models
from captioning.data.dataloader import *
# from captioning.data.dataloaderraw import *
import captioning.utils.eval_utils as eval_utils
import argparse
import captioning.utils.misc as utils
import captioning.modules.losses as losses
import torch

# Input arguments and options
parser = argparse.ArgumentParser()
# Input paths
parser.add_argument('--model', type=str, default='',
                    help='path to model to evaluate')
parser.add_argument('--cnn_model', type=str, default='resnet101',
                    help='resnet101, resnet152')
parser.add_argument('--infos_path', type=str, default='',
                    help='path to infos to evaluate')
parser.add_argument('--only_lang_eval', type=int, default=0,
                    help='lang eval on saved results')
parser.add_argument('--force', type=int, default=0,
                    help='force to evaluate no matter if there are results available')
parser.add_argument('--device', type=str, default='cuda',
                    help='cpu or cuda')
opts.add_eval_options(parser)
opts.add_diversity_opts(parser)
opt = parser.parse_args()

# Load infos
with open(opt.infos_path, 'rb') as f:
    infos = utils.pickle_load(f)

# override and collect parameters
replace = ['input_fc_dir', 'input_att_dir', 'input_box_dir', 'input_label_h5', 'input_json', 'batch_size', 'id']
ignore = ['start_from']

for k in vars(infos['opt']).keys():
    if k in replace:
        setattr(opt, k, getattr(opt, k) or getattr(infos['opt'], k, ''))
    elif k not in ignore:
        if not k in vars(opt):
            vars(opt).update({k: vars(infos['opt'])[k]})  # copy over options from model

vocab = infos['vocab']  # ix -> word mapping

pred_fn = os.path.join('eval_results/', '.saved_pred_' + opt.id + '_' + opt.split + '.pth')
result_fn = os.path.join('eval_results/', opt.id + '_' + opt.split + '.json')

if opt.only_lang_eval == 1 or (not opt.force and os.path.isfile(pred_fn)):
    # if results existed, then skip, unless force is on
    if not opt.force:
        try:
            if os.path.isfile(result_fn):
                print(result_fn)
                json.load(open(result_fn, 'r'))
                print('already evaluated')
                os._exit(0)
        except:
            pass

    predictions, n_predictions = torch.load(pred_fn)
    lang_stats = eval_utils.language_eval(opt.input_json, predictions, n_predictions, vars(opt), opt.split)
    print(lang_stats)
    os._exit(0)

# At this point only_lang_eval if 0
if not opt.force:
    # Check out if
    try:
        # if no pred exists, then continue
        tmp = torch.load(pred_fn)
        # if language_eval == 1, and no pred exists, then continue
        if opt.language_eval == 1:
            json.load(open(result_fn, 'r'))
        print('Result is already there')
        os._exit(0)
    except:
        pass

# Setup the model
opt.vocab = vocab
model = models.setup(opt)
del opt.vocab
model.load_state_dict(torch.load(opt.model, map_location='cpu'))
model.to(opt.device)
model.eval()
crit = losses.LanguageModelCriterion()

# Create the Data Loader instance
if len(opt.image_folder) == 0:
    loader = DataLoader(opt)
else:
    loader = DataLoaderRaw({'folder_path': opt.image_folder,
                            'coco_json': opt.coco_json,
                            'batch_size': opt.batch_size,
                            'cnn_model': opt.cnn_model})
# When eval using provided pretrained model, the vocab may be different from what you have in your cocotalk.json
# So make sure to use the vocab in infos file.
loader.dataset.ix_to_word = infos['vocab']


# Set sample options
opt.dataset = opt.input_json
loss, split_predictions, lang_stats = eval_utils.eval_split(model, crit, loader,
                                                            vars(opt))

print('loss: ', loss)
if lang_stats:
    print(lang_stats)

if opt.dump_json == 1:
    # dump the json
    json.dump(split_predictions, open('vis/vis.json', 'w'))
tools/eval_clip_retrieval.py
DELETED
@@ -1,231 +0,0 @@
from PIL import Image
# import requests

from transformers import CLIPProcessor, CLIPModel

import torch
from torch.utils.data import DataLoader, Dataset

from pathlib import Path
from tqdm import tqdm
import json
import argparse
import numpy as np

class COCODataset(Dataset):
    def __init__(self,
                 coco_root="/nas-ssd/jmincho/datasets/COCO/",
                 gen_caption_path=None,
                 is_gt=True):
        super().__init__()

        self.coco_root = Path(coco_root)

        self.image_dir = self.coco_root.joinpath('images/val2014')

        if is_gt:
            print("Loading karpathy splits")
            data_info_path = self.coco_root.joinpath('dataset_coco.json')
            with open(data_info_path) as f:
                karpathy_data = json.load(f)

            data = []
            for datum in karpathy_data['images']:
                # karpathy test split
                if datum['split'] == 'test':
                    img_id = datum['filename'].split('.')[0]
                    new_datum = {
                        'img_id': img_id,
                        'captions': [d['raw'].strip() for d in datum['sentences']],
                    }
                    data.append(new_datum)
        else:
            print("Loading generated captions")
            gen_caption_path = Path(gen_caption_path)
            with open(gen_caption_path) as f:
                # karpathy_data = json.load(f)
                imgTogen_results = json.load(f)['imgToEval']
            data = []
            for img_id, img_data in imgTogen_results.items():
                new_datum = {
                    'img_id': img_id,
                    'captions': [img_data['caption']],
                }
                data.append(new_datum)

        self.data = data
        print('# images:', len(self.data))

        self.img_transform = processor.feature_extractor
        self.tokenizer = processor.tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        datum = self.data[idx]
        img_id = datum['img_id']
        if 'COCO' not in img_id:
            img_id = f'COCO_val2014_{str(img_id).zfill(12)}'
        img_fname = f"{img_id}.jpg"
        # COCO_val2014_000000522418.jpg
        img_path = self.image_dir.joinpath(img_fname)
        img = Image.open(img_path).convert("RGB")

        # take first caption
        caption = datum['captions'][0]

        return {
            "img": img,
            "caption": caption,
        }

    def collate_fn(self, datum_list):
        B = len(datum_list)
        imgs = [datum['img'] for datum in datum_list]
        images = self.img_transform(imgs, return_tensors="pt")

        captions = [datum['caption'] for datum in datum_list]

        text_tokens = self.tokenizer(captions, return_tensors="pt", padding=True)
        batch = {
            'images': images,
            'captions': text_tokens,
        }
        return batch


def compute_similarity(image_features, text_features, bs=1000):
    # compute similarity
    max_pairs = image_features.shape[0]
    similarity_scores = torch.zeros(max_pairs, max_pairs)
    for v in range(0, max_pairs, bs):
        for t in range(0, max_pairs, bs):
            # print('Processing Visual '+str(v)+' Text '+str(t), end='\r')
            batch_visual_emb = image_features[v:v+bs]
            batch_caption_emb = text_features[t:t+bs]

            logits = batch_visual_emb @ batch_caption_emb.t()
            similarity_scores[v:v+bs, t:t+bs] = logits

    print('Done similarity')
    return similarity_scores

def compute_retrieval(a2b_sims, return_ranks=True):
    """
    Args:
        a2b_sims: Result of computing similarity between two sets of embeddings (emb1 @ emb2.T)
            with shape (num_datapoints, num_datapoints).

    Returns:
        Retrieval metrics for that similarity.
    """
    npts = a2b_sims.shape[0]
    ranks = np.zeros(npts)
    top1 = np.zeros(npts)
    # loop source embedding indices
    for index in range(npts):
        # get order of similarities to target embeddings
        inds = np.argsort(a2b_sims[index])[::-1]
        # find where the correct embedding is ranked
        where = np.where(inds == index)
        rank = where[0][0]
        ranks[index] = rank
        # save the top1 result as well
        top1[index] = inds[0]

    # Compute metrics
    r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
    r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
    r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
    r50 = 100.0 * len(np.where(ranks < 50)[0]) / len(ranks)
    medr = np.floor(np.median(ranks)) + 1
    meanr = ranks.mean() + 1

    report_dict = {"r1": r1, "r5": r5, "r10": r10, "r50": r50, "medr": medr, "meanr": meanr, "sum": r1 + r5 + r10}

    if return_ranks:
        return report_dict, (ranks, top1)
    else:
        return report_dict


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--coco_root', type=str, default="/nas-ssd/jmincho/datasets/COCO/")
    parser.add_argument('--gt', action='store_true')
    parser.add_argument('--gen_caption_path', type=str, default="./eval_results/clipRN50_cider_test.json")
    args = parser.parse_args()

    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    device = "cuda"
    model = model.to(device)
    model.eval()
    print(f"Loaded CLIP at {device}")

    batch_size = 1000

    dataset = COCODataset(
        coco_root="/nas-ssd/jmincho/datasets/COCO/",
        gen_caption_path=args.gen_caption_path,
        is_gt=args.gt
    )
    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=dataset.collate_fn,
        shuffle=False,
        num_workers=8)

    # fwd all samples
    image_features = []
    text_features = []
    for batch_idx, batch in enumerate(tqdm(data_loader)):
        # print('Evaluating batch {}/{}'.format(batch_idx, len(data_loader)), end="\r")
        # images, texts = batch

        with torch.no_grad():
            images = batch["images"].to(device)
            texts = batch["captions"].to(device)

            vision_outputs = model.vision_model(**batch['images'])
            text_outputs = model.text_model(**batch['captions'])

            image_embeds = vision_outputs[1]
            image_embeds = model.visual_projection(image_embeds)

            text_embeds = text_outputs[1]
            text_embeds = model.text_projection(text_embeds)

            # normalized features
            image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
            text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)

            text_features.append(text_embeds.detach().cpu())
            image_features.append(image_embeds.detach().cpu())

    image_features = torch.cat(image_features, 0)
    text_features = torch.cat(text_features, 0)
    print('Done forward')

    # normalized features
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    # if not single_caption:
    #     for cap_idx in range(text_features.shape[1]):
    #         similarity_scores = compute_similarity(image_features, text_features[:,cap_idx,:])
    #         i2t_dict = compute_retrieval(similarity_scores.numpy())
    #         t2i_dict = compute_retrieval(similarity_scores.t().numpy())
    #         print(cap_idx, 'i2t', i2t_dict)
    #         print(cap_idx, 't2i', t2i_dict)
    # else:
    similarity_scores = compute_similarity(image_features, text_features)
    i2t_dict = compute_retrieval(similarity_scores.numpy())
    t2i_dict = compute_retrieval(similarity_scores.t().numpy())
    print('i2t', i2t_dict)
    print('t2i', t2i_dict)
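Editor's note (not part of the deleted file): compute_retrieval above ranks, for each row of the square similarity matrix, where the matching (diagonal) entry lands; r1/r5/r10 are the percentages of rows whose match is in the top 1/5/10. A toy check of that logic with made-up scores:

import numpy as np

sims = np.array([[0.9, 0.1, 0.3],    # match (index 0) is ranked 1st -> rank 0
                 [0.8, 0.1, 0.2],    # match (index 1) is ranked 3rd -> rank 2
                 [0.0, 0.5, 0.7]])   # match (index 2) is ranked 1st -> rank 0
report, (ranks, top1) = compute_retrieval(sims)
# ranks == [0., 2., 0.], so r1 ~= 66.67, r5 = r10 = 100.0, medr = 1.0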
tools/eval_finecapeval.py
DELETED
@@ -1,204 +0,0 @@
from tqdm import tqdm
from pprint import pprint
import pandas as pd
import argparse
import re
import json
import nltk
from nltk.tokenize import word_tokenize
from nltk.stem.porter import PorterStemmer
p_stemmer = PorterStemmer()

# nltk.download('punkt')
# nltk.download('wordnet')
# nltk.download('stopwords')

import language_evaluation
evaluator = language_evaluation.CocoEvaluator()


def nltk_process(text):
    # Tokenization
    nltk_tokenList = word_tokenize(text)

    # Stemming
    nltk_stemedList = []
    for word in nltk_tokenList:
        nltk_stemedList.append(p_stemmer.stem(word))

    filtered_sentence = nltk_stemedList

    # Removing Punctuation
    tokens = [re.sub(r'[^a-zA-Z0-9]', '', tok) for tok in filtered_sentence]

    text = " ".join(tokens)

    return text


def calculate_finegrained_scores(pred_id2sent, id2caption, use_coco_eval=False):
    if use_coco_eval:
        n_total = 0
        refs = []
        hyps = []
        for id, gt_captions in id2caption.items():
            pred_sent = pred_id2sent[id]

            refs.append(gt_captions)
            hyps.append(pred_sent)

            n_total += 1

        print('caption')
        results = evaluator.run_evaluation(hyps, refs)
        pprint(results)

    n_total = 0
    total_score = 0
    for id, gt_phrases in id2background.items():
        pred_sent = pred_id2sent[id]

        score = 0
        n_phrases = len(gt_phrases)

        for gt_phrase in gt_phrases:
            word_score = 0
            for gt_word in gt_phrase.split():
                if gt_word in pred_sent:
                    word_score += 1
            if len(gt_phrase.split()) > 0:
                score += word_score / len(gt_phrase.split())

        if n_phrases > 0:
            score /= n_phrases

        total_score += score
        n_total += 1
    print('background')
    # print('# retrieved words:', n_retrieved)
    print(f'Acc: {total_score / n_total * 100:.2f}')

    n_total = 0
    total_score = 0
    for id, gt_phrases in id2object.items():
        pred_sent = pred_id2sent[id]

        score = 0
        n_phrases = len(gt_phrases)

        for gt_phrase in gt_phrases:
            word_score = 0
            for gt_word in gt_phrase.split():
                if gt_word in pred_sent:
                    word_score += 1
            if len(gt_phrase.split()) > 0:
                score += word_score / len(gt_phrase.split())

        if n_phrases > 0:
            score /= n_phrases

        total_score += score
        n_total += 1
    print('object')
    # print('# retrieved words:', n_retrieved)
    print(f'Acc: {total_score / n_total * 100:.2f}')

    n_total = 0
    total_score = 0
    for id, gt_phrases in id2relation.items():
        pred_sent = pred_id2sent[id]

        score = 0
        n_phrases = len(gt_phrases)

        for gt_phrase in gt_phrases:
            word_score = 0
            for gt_word in gt_phrase.split():
                if gt_word in pred_sent:
                    word_score += 1
            if len(gt_phrase.split()) > 0:
                score += word_score / len(gt_phrase.split())

        if n_phrases > 0:
            score /= n_phrases

        total_score += score
        n_total += 1
    print('relation')
    # print('# retrieved words:', n_retrieved)
    print(f'Acc: {total_score / n_total * 100:.2f}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--finecapeval_path', type=str, default="data/FineCapEval.csv")
    parser.add_argument('--generated_id2caption', type=str, default="FineCapEval_results/mle.json")
    args = parser.parse_args()

    df = pd.read_csv(args.finecapeval_path)
    assert df.shape == (5000, 5)

    generated_id2caption = json.load(open(args.generated_id2caption, 'r'))

    print("Preprocessing GT FineCapEval data...")
    id2caption = {}
    id2background = {}
    id2object = {}
    id2relation = {}

    for row in tqdm(df.itertuples(), total=len(df)):

        id = row.image.split('.')[0]
        caption = row.caption
        background = row.background
        object = row.object
        relation = row.relation

        if not isinstance(caption, str):
            continue
        if not isinstance(background, str):
            continue
        if not isinstance(object, str):
            continue
        if not isinstance(relation, str):
            continue

        if id not in id2caption:
            id2caption[id] = []
            id2background[id] = []
            id2object[id] = []
            id2relation[id] = []

        id2caption[id].append(caption)

        phrases = []
        for phrase in background.lower().split('\;'):
            if len(phrase) > 1:
                phrase = nltk_process(phrase)
                phrases.append(phrase)
        id2background[id].extend(phrases)

        phrases = []
        for phrase in object.lower().split('\;'):
            if len(phrase) > 1:
                phrase = nltk_process(phrase)
                phrases.append(phrase)
        id2object[id].extend(phrases)

        phrases = []
        for phrase in relation.lower().split('\;'):
            if len(phrase) > 1:
                phrase = nltk_process(phrase)
                phrases.append(phrase)
        id2relation[id].extend(phrases)

    print("Calculating scores...")
    calculate_finegrained_scores(
        generated_id2caption,
        id2caption,
        use_coco_eval=True)
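Editor's note (not part of the deleted file): the background/object/relation accuracies printed by calculate_finegrained_scores are averaged word recall: for each ground-truth phrase, the fraction of its (stemmed) words found in the predicted caption via the substring test gt_word in pred_sent, averaged over phrases and then over images. A small illustration of the per-image score with hypothetical strings:

pred_sent = "a red bike leaning against a brick wall"
gt_phrases = ["red bike", "stone wall"]

score = 0.0
for gt_phrase in gt_phrases:
    words = gt_phrase.split()
    hits = sum(1 for w in words if w in pred_sent)  # substring match, as in the script
    score += hits / len(words)
score /= len(gt_phrases)
# "red bike" -> 2/2, "stone wall" -> 1/2, so score == 0.75 for this image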
tools/finecapeval_inference.py
DELETED
@@ -1,186 +0,0 @@
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np

import time
import os
from collections import defaultdict
import json

import captioning.utils.opts as opts
import captioning.models as models
from captioning.data.pth_loader import CaptionDataset
import captioning.utils.eval_utils as eval_utils
# import captioning.utils.vizwiz_eval_utils as vizwiz_eval_utils
import captioning.utils.misc as utils
from captioning.utils.rewards import init_scorer, get_self_critical_reward
from captioning.modules.loss_wrapper import LossWrapper

import pytorch_lightning as pl


class ModelCheckpoint(pl.callbacks.ModelCheckpoint):

    def on_keyboard_interrupt(self, trainer, pl_module):
        # Save model when keyboard interrupt
        filepath = os.path.join(self.dirpath, self.prefix + 'interrupt.ckpt')
        self._save_model(filepath)


if __name__ == '__main__':

    device = 'cuda'

    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--reward', type=str, default='mle')
    args = parser.parse_args()

    if args.reward == 'mle':
        cfg = f'configs/phase1/fg_clipRN50_{args.reward}.yml'
    else:
        cfg = f'configs/phase2/fg_clipRN50_{args.reward}.yml'

    print("Loading cfg from", cfg)

    opt = opts.parse_opt(parse=False, cfg=cfg)

    dataset = CaptionDataset(opt)

    opt.vocab_size = dataset.vocab_size
    opt.seq_length = dataset.seq_length

    opt.batch_size = 40

    opt.vocab = dataset.get_vocab()

    model = models.setup(opt)
    del opt.vocab

    ckpt_path = opt.checkpoint_path + '-last.ckpt'

    print("Loading checkpoint from", ckpt_path)
    raw_state_dict = torch.load(
        ckpt_path,
        map_location=device)

    strict = True

    state_dict = raw_state_dict['state_dict']

    if '_vocab' in state_dict:
        model.vocab = utils.deserialize(state_dict['_vocab'])
        del state_dict['_vocab']
    elif strict:
        raise KeyError
    if '_opt' in state_dict:
        saved_model_opt = utils.deserialize(state_dict['_opt'])
        del state_dict['_opt']
        # Make sure the saved opt is compatible with the curren topt
        need_be_same = ["caption_model",
                        "rnn_type", "rnn_size", "num_layers"]
        for checkme in need_be_same:
            if getattr(saved_model_opt, checkme) in ['updown', 'topdown'] and \
                    getattr(opt, checkme) in ['updown', 'topdown']:
                continue
            assert getattr(saved_model_opt, checkme) == getattr(
                opt, checkme), "Command line argument and saved model disagree on '%s' " % checkme
    elif strict:
        raise KeyError
    res = model.load_state_dict(state_dict, strict)
    print(res)

    opt.use_grammar = False

    lw_model = LossWrapper(model, opt)

    split = 'test'

    print("Building dataloader...")

    test_dataset = torch.utils.data.Subset(
        dataset,
        dataset.split_ix[split]
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=opt.batch_size,
        shuffle=False,
        num_workers=4,
        drop_last=False,
        collate_fn=dataset.collate_func
    )

    eval_kwargs = {'dataset': opt.input_json}
    eval_kwargs.update(vars(opt))

    verbose = eval_kwargs.get('verbose', True)
    verbose_beam = eval_kwargs.get('verbose_beam', 0)
    verbose_loss = eval_kwargs.get('verbose_loss', 1)
    # num_images = eval_kwargs.get('num_images', eval_kwargs.get('val_images_use', -1))
    # lang_eval = eval_kwargs.get('language_eval', 0)
    dataset = eval_kwargs.get('dataset', 'coco')
    beam_size = eval_kwargs.get('beam_size', 1)
    sample_n = eval_kwargs.get('sample_n', 1)
    remove_bad_endings = eval_kwargs.get('remove_bad_endings', 0)

    crit = lw_model.crit

    model = model.to(device)

    from tqdm import tqdm

    test_id2sent = {}

    model.eval()

    print("running inference...")

    for data in tqdm(test_loader):
        with torch.no_grad():
            # forward the model to get loss
            tmp = [data['fc_feats'], data['att_feats'],
                   data['labels'], data['masks'], data['att_masks']]
            tmp = [d.to(device) if isinstance(d, torch.Tensor) else d for d in tmp]

            fc_feats, att_feats, labels, masks, att_masks = tmp

            loss = crit(model(fc_feats, att_feats,
                              labels[..., :-1], att_masks), labels[..., 1:], masks[..., 1:])

            # forward the model to also get generated samples for each image
            # Only leave one feature for each image, in case duplicate sample
            tmp_eval_kwargs = eval_kwargs.copy()
            tmp_eval_kwargs.update({'sample_n': 1})
            seq, seq_logprobs = model(
                fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
            seq = seq.data
            entropy = - (F.softmax(seq_logprobs, dim=2) *
                         seq_logprobs).sum(2).sum(1) / ((seq > 0).to(seq_logprobs).sum(1)+1)
            perplexity = - \
                seq_logprobs.gather(2, seq.unsqueeze(2)).squeeze(
                    2).sum(1) / ((seq > 0).to(seq_logprobs).sum(1)+1)

            # Print beam search
            if beam_size > 1 and verbose_beam:
                for i in range(fc_feats.shape[0]):
                    print('\n'.join([utils.decode_sequence(model.vocab, _[
                          'seq'].unsqueeze(0))[0] for _ in model.done_beams[i]]))
                    print('--' * 10)
            sents = utils.decode_sequence(model.vocab, seq)

            for d, sent in zip(data['infos'], sents):
                test_id2sent[d['id']] = sent

    res_path = f'FineCapEval_results/clipRN50_{args.reward}.json'

    print("Results save at {}".format(res_path))

    with open(res_path, 'w') as f:
        json.dump(test_id2sent, f)
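Editor's note (not part of the deleted file): the file this script writes, FineCapEval_results/clipRN50_<reward>.json, is a flat {image_id: caption} mapping, which is what tools/eval_finecapeval.py above consumes through --generated_id2caption. A minimal sketch of the round trip; the id and caption strings are placeholders:

import json

# shape of what finecapeval_inference.py dumps (illustrative content)
example = {"COCO_val2014_000000391895": "a man riding a motorcycle on a dirt road"}
with open('FineCapEval_results/clipRN50_cider.json', 'w') as f:
    json.dump(example, f)

# what eval_finecapeval.py later loads as generated_id2caption / pred_id2sent
pred_id2sent = json.load(open('FineCapEval_results/clipRN50_cider.json'))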
tools/train_pl.py
DELETED
@@ -1,709 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np

import time
import os
from collections import defaultdict

import captioning.utils.opts as opts
import captioning.models as models
from captioning.data.pth_loader import CaptionDataset
import captioning.utils.eval_utils as eval_utils
import captioning.utils.misc as utils
from captioning.utils.rewards import init_scorer, get_self_critical_reward
from captioning.modules.loss_wrapper import LossWrapper

import pytorch_lightning as pl

import detectron2.utils.comm as d2comm
from detectron2.utils.env import seed_all_rng
seed_all_rng(1234)


class LitModel(pl.LightningModule):
    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        # Intilaize dataset
        self.dataset = CaptionDataset(opt)
        opt.vocab_size = self.dataset.vocab_size
        opt.seq_length = self.dataset.seq_length
        self.batch_size = opt.batch_size

        # Build model
        opt.vocab = self.dataset.get_vocab()
        model = models.setup(opt)
        # print(model)
        del opt.vocab

        # wrapper with loss in it.
        lw_model = LossWrapper(model, opt)

        self.model = model
        self.lw_model = lw_model

        self.struc_flag = None
        self.sc_flag = None

        # if self.opt.use_clipscore:
        # if self.opt.use_clipscore or os.getenv('EVALUATE', '0') == '1':
        # if CLIP-S+Grammar is used in reward -> Launch another CLIP-S where parameter is unchanged
        if getattr(self.opt, 'use_grammar', False):
            from captioning.utils.clipscore import CLIPScore
            self.val_clipscore_model = CLIPScore(
                mode=opt.clipscore_mode, use_grammar=False)
            for p in self.val_clipscore_model.parameters():
                p.requires_grad = False
        else:
            if self.lw_model.clipscore_model is not None:
                self.val_clipscore_model = self.lw_model.clipscore_model
            else:
                from captioning.utils.clipscore import CLIPScore
                self.val_clipscore_model = CLIPScore(
                    mode=opt.clipscore_mode, use_grammar=False)
                for p in self.val_clipscore_model.parameters():
                    p.requires_grad = False
        self.val_clipscore_model.eval()

        # BERTSCORE
        from bert_score import BERTScorer
        self.bert_scorer = BERTScorer(
            lang="en",
            # rescale_with_baseline=True,
            rescale_with_baseline=False,
            device='cpu'
        )

    def forward(self, *args, **kwargs):
        """
        I hate this design. Never pretend it as a nn.Module
        """
        raise NotImplementedError

    def train_dataloader(self):
        train_dataset = torch.utils.data.Subset(
            self.dataset,
            self.dataset.split_ix['train']
        )

        train_loader = torch.utils.data.DataLoader(
            dataset=train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=4,
            collate_fn=self.dataset.collate_func
        )
        return train_loader

    def val_dataloader(self, split='val'):
        val_dataset = torch.utils.data.Subset(
            self.dataset,
            self.dataset.split_ix[split]
        )
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=4,
            drop_last=False,
            collate_fn=self.dataset.collate_func
        )
        return val_loader

    def test_dataloader(self):
        return self.val_dataloader('test')

    def training_step(self, data, batch_idx):
        sc_flag, struc_flag = self.sc_flag, self.struc_flag

        tmp = [data['fc_feats'], data['att_feats'],
               data['labels'], data['masks'], data['att_masks']]
        fc_feats, att_feats, labels, masks, att_masks = tmp
        if int(os.getenv('M2_cider', '0')) != 0:
            data['gts'] = data['rawgts']

        if self.opt.use_clipscore:
            clip_vis_feats = data['clip_vis_feats']
            model_out = self.lw_model(fc_feats, att_feats, labels, masks, att_masks,
                                      data['gts'], torch.arange(0, len(data['gts'])), sc_flag, struc_flag,
                                      clip_vis_feats=clip_vis_feats)
        else:
            model_out = self.lw_model(fc_feats, att_feats, labels, masks, att_masks,
                                      data['gts'], torch.arange(0, len(data['gts'])), sc_flag, struc_flag)
        loss = model_out['loss']

        data_time = self.trainer.profiler.recorded_durations["get_train_batch"][-1]
        data_time = torch.tensor(data_time)

        logger_logs = model_out.copy()
        # if struc_flag or sc_flag:
        #     logger_logs['reward'] = model_out['reward'].mean()
        #     logger_logs['reward_var'] = model_out['reward'].var(1).mean()
        if struc_flag or sc_flag:
            logger_logs['reward'] = model_out['reward'].mean()
            for k in ['CLIP-S', 'RefCLIP-S', 'CIDEr', 'grammar_reward']:
                if k in model_out:
                    logger_logs[k] = model_out[k]
        if struc_flag:
            logger_logs['reward_var'] = model_out['reward'].var(1).mean()

        logger_logs['scheduled_sampling_prob'] = torch.tensor(
            self.model.ss_prob)
        # logger_logs['training_loss'] = loss
        logger_logs['loss'] = loss
        logger_logs['data_time'] = data_time

        # UserWarning: The {progress_bar:dict keyword} was deprecated in 0.9.1 and will be removed in 1.0.0
        # Please use self.log(...) inside the lightningModule instead.

        # # log on a step or aggregate epoch metric to the logger and/or progress bar
        # # (inside LightningModule)
        # self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        # warnings.warn(*args, **kwargs)
        # UserWarning: The {log:dict keyword} was deprecated in 0.9.1 and will be removed in 1.0.0
        # Please use self.log(...) inside the lightningModule instead.

        # output = {
        #     'loss': loss,
        #     'log': logger_logs,
        #     'progress_bar': {'data_time': data_time}
        # }

        for k, v in logger_logs.items():
            if k in ['reward', 'reward_var', 'data_time', 'CLIP-S', 'RefCLIP-S', 'CIDEr', 'grammar_reward']:
                self.log('train/'+k, v, prog_bar=True)
            else:
                self.log('train/'+k, v)

        return loss

    def validation_step(self, data, batch_idx):
        model = self.model
        crit = self.lw_model.crit

        opt = self.opt
        eval_kwargs = {'dataset': opt.input_json}
        eval_kwargs.update(vars(opt))

        # CLIPScore
        use_grammar = getattr(self.opt, 'use_grammar', False)
        joint_out = getattr(self.opt, 'joint_out', False)

        verbose = eval_kwargs.get('verbose', True)
        verbose_beam = eval_kwargs.get('verbose_beam', 0)
        verbose_loss = eval_kwargs.get('verbose_loss', 1)
        # num_images = eval_kwargs.get('num_images', eval_kwargs.get('val_images_use', -1))
        # lang_eval = eval_kwargs.get('language_eval', 0)
        dataset = eval_kwargs.get('dataset', 'coco')
        beam_size = eval_kwargs.get('beam_size', 1)
        sample_n = eval_kwargs.get('sample_n', 1)
        remove_bad_endings = eval_kwargs.get('remove_bad_endings', 0)
        # Use this nasty way to make other code clean since it's a global configuration
        os.environ["REMOVE_BAD_ENDINGS"] = str(remove_bad_endings)

        predictions = []
        n_predictions = []

        loss = torch.tensor(0)
        if data.get('labels', None) is not None and verbose_loss:
            # forward the model to get loss
            tmp = [data['fc_feats'], data['att_feats'],
                   data['labels'], data['masks'], data['att_masks']]
            fc_feats, att_feats, labels, masks, att_masks = tmp

            loss = crit(model(fc_feats, att_feats,
                              labels[..., :-1], att_masks), labels[..., 1:], masks[..., 1:])

        # forward the model to also get generated samples for each image
        # Only leave one feature for each image, in case duplicate sample
        tmp_eval_kwargs = eval_kwargs.copy()
        tmp_eval_kwargs.update({'sample_n': 1})
        seq, seq_logprobs = model(
            fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
        seq = seq.data
        entropy = - (F.softmax(seq_logprobs, dim=2) *
                     seq_logprobs).sum(2).sum(1) / ((seq > 0).to(seq_logprobs).sum(1)+1)
        perplexity = - \
            seq_logprobs.gather(2, seq.unsqueeze(2)).squeeze(
                2).sum(1) / ((seq > 0).to(seq_logprobs).sum(1)+1)

        # Print beam search
        if beam_size > 1 and verbose_beam:
            for i in range(fc_feats.shape[0]):
                print('\n'.join([utils.decode_sequence(model.vocab, _[
                      'seq'].unsqueeze(0))[0] for _ in model.done_beams[i]]))
                print('--' * 10)
        sents = utils.decode_sequence(model.vocab, seq)

        # if self.opt.use_clipscore or os.getenv('EVALUATE', '0') == '1':
        # text_feat = self.lw_model.clipscore_model.text_extract(sents)
        text_feat = self.val_clipscore_model.text_extract(sents, proj_norm=False)

        text_cont_feat = self.val_clipscore_model.clip_model.text_projection(text_feat)
        text_cont_feat = text_cont_feat / text_cont_feat.norm(dim=-1, keepdim=True)

        vis_feat = data['clip_vis_feats']
        # if self.opt.clipscore_mode == 'clip_s':
        #     clip_s = self.val_clipscore_model(text_feat=text_cont_feat, img_feat=vis_feat, mode='clip_s')

        # elif self.opt.clipscore_mode == 'refclip_s':
        clip_s = self.val_clipscore_model(text_feat=text_cont_feat, img_feat=vis_feat, mode='clip_s')
        # ref_text = utils.decode_sequence(model.vocab, data['gts'])

        gt_indices = torch.arange(0, len(data['gts']))
        data_gts = [data['gts'][_] for _ in gt_indices.tolist()]

        B = len(data_gts)

        gts = []
        gts_valid_mask = []
        max_n_refs = max([len(_gts) for _gts in data_gts])
        for i in range(len(data_gts)):
            _gts = utils.decode_sequence(model.vocab, data_gts[i])
            # pad references
            n_ref = len(_gts)
            _gts.extend([''] * (max_n_refs - n_ref))
            gts.extend(_gts)
            gts_valid_mask.extend([1] * n_ref + [0] * (max_n_refs - n_ref))
        assert len(gts) == B * max_n_refs
        assert len(gts_valid_mask) == B * max_n_refs

        ref_text = gts
        ref_text_mask = gts_valid_mask

        refclip_s = self.val_clipscore_model(
            text_feat=text_cont_feat, img_feat=vis_feat,
            ref_text=ref_text, ref_text_mask=ref_text_mask, mode='refclip_s')

        # use_grammar = getattr(self.opt, 'use_grammar', False)
        # joint_out = getattr(self.opt, 'joint_out', False)
        if use_grammar and not joint_out:
            with torch.no_grad():
                # grammar_logit = self.val_clipscore_model.grammar_score_head(text_feat.view(-1, 512))
                grammar_logit = self.lw_model.clipscore_model.grammar_score_head(text_feat.view(-1, 512))
                grammar_prob = torch.softmax(grammar_logit, dim=-1)[:, 1]

        # BERTScore
        if next(self.bert_scorer._model.parameters()).device != self.device:
            self.bert_scorer._model.to(self.device)
            self.bert_scorer.device = self.device

        # [B*K] -> [B, K]
        ref_text_per_example = []
        for i in range(B):
            ref_text_list_example = []
            for k in range(max_n_refs):
                ref = ref_text[i * max_n_refs + k]
                if len(ref) > 0:
                    ref_text_list_example.append(ref)
            # assert len(ref_text_list_example) == max_n_refs
            ref_text_per_example.append(ref_text_list_example)
        assert len(ref_text_per_example) == B

        P, R, F1 = self.bert_scorer.score(
            sents,
            ref_text_per_example,
        )
        bertscore_f1 = F1
        # print('Example 5:')
        # for i in range(5):
        #     print('Generated:', sents[i])
        #     print('ref_text:', ref_text_per_example[i])
        #     print('BERT-Score:', F1[i].item())

        for k, sent in enumerate(sents):
            entry = {'image_id': data['infos'][k]['id'], 'caption': sent,
                     'perplexity': perplexity[k].item(), 'entropy': entropy[k].item()}
            if self.opt.use_clipscore or os.getenv('EVALUATE', '0') == '1':
                # if self.opt.clipscore_mode == 'clip_s':
                #     entry['clipscore'] = clipscore[k].item()
                #     entry['CLIP-S'] = clip_s[k].item()
                # elif self.opt.clipscore_mode == 'refclip_s':
                entry['CLIP-S'] = clip_s[k].item()
                entry['RefCLIP-S'] = refclip_s[k].item()

                if use_grammar and not joint_out:
                    entry['grammar_prob'] = grammar_prob[k].item()

                # BERT-S
                entry['BERT-S'] = bertscore_f1[k].item()

            if eval_kwargs.get('dump_path', 0) == 1:
                entry['file_name'] = data['infos'][k]['file_path']
            predictions.append(entry)
            if eval_kwargs.get('dump_images', 0) == 1:
                # dump the raw image to vis/ folder
                cmd = 'cp "' + os.path.join(eval_kwargs['image_root'], data['infos'][k]['file_path']) + \
                    '" vis/imgs/img' + \
                    str(len(predictions)) + '.jpg'  # bit gross
                print(cmd)
                os.system(cmd)

            if verbose:
                print('image %s: %s' %
                      (entry['image_id'], entry['caption']))

        if sample_n > 1:
            eval_utils.eval_split_n(model, n_predictions, [
                fc_feats, att_feats, att_masks, data], eval_kwargs)

        output = {
            # 'val_loss': loss,
            'loss': loss,
            'predictions': predictions,
            'n_predictions': n_predictions,
        }
        return output

    def test_step(self, *args, **kwargs):
        return self.validation_step(*args, **kwargs)

    def validation_epoch_end(self, outputs, split='val'):
        outputs = d2comm.gather(outputs)
        # master node
        if d2comm.is_main_process():
            assert self.trainer.node_rank == 0 and self.trainer.local_rank == 0
            outputs = sum(outputs, [])

            opt = self.opt
            # val_loss_mean = sum([_['val_loss']
            # val_loss_mean = sum([_['val_loss'].cpu()
            val_loss_mean = sum([_['loss'].cpu()
                                 for _ in outputs]) / len(outputs)

            predictions = sum([_['predictions'] for _ in outputs], [])
            if len(outputs[0]['n_predictions']) != 0:
                n_predictions = sum([_['n_predictions'] for _ in outputs], [])
            else:
                n_predictions = []

            lang_stats = None
            if len(n_predictions) > 0 and 'perplexity' in n_predictions[0]:
                n_predictions = sorted(
                    n_predictions, key=lambda x: x['perplexity'])

            if not os.path.isdir('eval_results'):
                os.mkdir('eval_results')
            torch.save((predictions, n_predictions), os.path.join(
                'eval_results/', '.saved_pred_' + opt.id + '_' + split + '.pth'))

            if opt.language_eval:
                lang_stats = eval_utils.language_eval(
                    opt.input_json, predictions, n_predictions, vars(opt), split)

            if opt.reduce_on_plateau:
                optimizer = self.trainer.optimizers[0]
                if 'CIDEr' in lang_stats:
                    optimizer.scheduler_step(-lang_stats['CIDEr'])
                else:
                    optimizer.scheduler_step(val_loss_mean)

            # out = {
            #     'val_loss': val_loss_mean
            # }
            out = {
                'loss': val_loss_mean
            }
            out.update(lang_stats)
            # out['to_monitor'] = lang_stats['CIDEr'] if lang_stats is not None else -val_loss_mean
            if self.opt.use_clipscore or os.getenv('EVALUATE', '0') == '1':
                # if self.opt.clipscore_mode == 'clip_s':
                #     out['clipscore'] = sum([p['clipscore'] for p in predictions]) / len(predictions)
                #     print('CLIPScore', out['clipscore'])
                #     out['CLIP-S'] = sum([p['CLIP-S'] for p in predictions]) / len(predictions)
                #     print('CLIP-S', out['CLIP-S'])
                # elif self.opt.clipscore_mode == 'refclip_s':
                out['CLIP-S'] = sum([p['CLIP-S'] for p in predictions]) / len(predictions)
                print('CLIP-S', out['CLIP-S'])

                out['RefCLIP-S'] = sum([p['RefCLIP-S'] for p in predictions]) / len(predictions)
                print('RefCLIP-S', out['RefCLIP-S'])

                if getattr(self.opt, 'use_grammar', False) and not getattr(self.opt, 'joint_out', False):
                    out['grammar_prob'] = sum([p['grammar_prob'] for p in predictions]) / len(predictions)
                    print('grammar_prob', out['grammar_prob'])

                out['BERT-S'] = sum([p['BERT-S'] for p in predictions]) / len(predictions)
                print('BERT-S', out['BERT-S'])
            else:
-
out = {}
|
437 |
-
|
438 |
-
out = d2comm.all_gather(out)[0] # Only the one from master node
|
439 |
-
assert len(out) > 0 # make sure the head has index 0
|
440 |
-
|
441 |
-
# must all be tensors
|
442 |
-
out = {k: torch.tensor(v) if not torch.is_tensor(
|
443 |
-
v) else v for k, v in out.items()}
|
444 |
-
|
445 |
-
# return {
|
446 |
-
# 'progress_bar': {'val_loss': out['val_loss']},
|
447 |
-
# 'log': out,
|
448 |
-
# }
|
449 |
-
for k, v in out.items():
|
450 |
-
# if k in ['loss', 'clipscore', 'RefCLIP-S', 'CIDEr']:
|
451 |
-
# if split != 'test':
|
452 |
-
# self.log(f'{split}/{k}', v, prog_bar=True)
|
453 |
-
# elif k == 'to_monitor':
|
454 |
-
# if split != 'test':
|
455 |
-
# self.log(f'{split}/{k}', v)
|
456 |
-
# else:
|
457 |
-
self.log(f'{split}/{k}', v)
|
458 |
-
|
459 |
-
def test_epoch_end(self, outputs):
|
460 |
-
# out = self.validation_epoch_end(outputs, 'test')
|
461 |
-
# out['progress_bar'] = {
|
462 |
-
# # 'test_loss': out['progress_bar']['val_loss']
|
463 |
-
# 'test_loss': out['progress_bar']['loss']
|
464 |
-
# }
|
465 |
-
# out['log']['test_loss'] = out['log']['val_loss']
|
466 |
-
# del out['log']['val_loss']
|
467 |
-
# del out['log']['to_monitor']
|
468 |
-
|
469 |
-
# out['log'] = {'test_'+k if 'test' not in k else k:v \
|
470 |
-
# for k,v in out['log'].items()}
|
471 |
-
|
472 |
-
# return out
|
473 |
-
self.validation_epoch_end(outputs, 'test')
|
474 |
-
|
475 |
-
def configure_optimizers(self):
|
476 |
-
opt = self.opt
|
477 |
-
model = self.model
|
478 |
-
|
479 |
-
parameters = [p for p in model.parameters() if p.requires_grad]
|
480 |
-
|
481 |
-
if opt.noamopt:
|
482 |
-
# assert opt.caption_model in ['transformer', 'bert', 'm2transformer'], 'noamopt can only work with transformer'
|
483 |
-
optimizer = utils.get_std_opt(
|
484 |
-
model, optim_func=opt.optim, factor=opt.noamopt_factor, warmup=opt.noamopt_warmup)
|
485 |
-
elif opt.reduce_on_plateau:
|
486 |
-
# optimizer = utils.build_optimizer(model.parameters(), opt)
|
487 |
-
optimizer = utils.build_optimizer(parameters, opt)
|
488 |
-
optimizer = utils.ReduceLROnPlateau(optimizer,
|
489 |
-
factor=opt.reduce_on_plateau_factor,
|
490 |
-
patience=opt.reduce_on_plateau_patience)
|
491 |
-
else:
|
492 |
-
# optimizer = utils.build_optimizer(model.parameters(), opt)
|
493 |
-
optimizer = utils.build_optimizer(parameters, opt)
|
494 |
-
return [optimizer], []
|
495 |
-
|
496 |
-
def optimizer_step(self, epoch, batch_idx, optimizer,
|
497 |
-
optimizer_idx, *args, **kwargs):
|
498 |
-
# warm up lr
|
499 |
-
opt = self.opt
|
500 |
-
iteration = self.trainer.global_step
|
501 |
-
if opt.use_warmup and (iteration < opt.noamopt_warmup):
|
502 |
-
opt.current_lr = opt.learning_rate * \
|
503 |
-
(iteration+1) / opt.noamopt_warmup
|
504 |
-
utils.set_lr(optimizer, opt.current_lr)
|
505 |
-
|
506 |
-
super().optimizer_step(epoch, batch_idx, optimizer,
|
507 |
-
optimizer_idx, *args, **kwargs)
|
508 |
-
|
509 |
-
def state_dict(self):
|
510 |
-
"""
|
511 |
-
Save the model state dict as well as opt and vocab
|
512 |
-
"""
|
513 |
-
state_dict = self.model.state_dict()
|
514 |
-
device = next(iter(state_dict.values())).device
|
515 |
-
assert '_vocab' not in state_dict and '_opt' not in state_dict, 'Just in case'
|
516 |
-
state_dict.update({
|
517 |
-
'_vocab': utils.serialize_to_tensor(self.model.vocab).to(device),
|
518 |
-
'_opt': utils.serialize_to_tensor(self.opt).to(device)
|
519 |
-
})
|
520 |
-
return state_dict
|
521 |
-
|
522 |
-
def load_state_dict(self, state_dict=None, strict=True):
|
523 |
-
if '_vocab' in state_dict:
|
524 |
-
self.model.vocab = utils.deserialize(state_dict['_vocab'])
|
525 |
-
del state_dict['_vocab']
|
526 |
-
# elif strict:
|
527 |
-
# raise KeyError
|
528 |
-
if '_opt' in state_dict:
|
529 |
-
saved_model_opt = utils.deserialize(state_dict['_opt'])
|
530 |
-
del state_dict['_opt']
|
531 |
-
opt = self.opt
|
532 |
-
# Make sure the saved opt is compatible with the curren topt
|
533 |
-
need_be_same = ["caption_model",
|
534 |
-
"rnn_type", "rnn_size", "num_layers"]
|
535 |
-
for checkme in need_be_same:
|
536 |
-
if getattr(saved_model_opt, checkme) in ['updown', 'topdown'] and \
|
537 |
-
getattr(opt, checkme) in ['updown', 'topdown']:
|
538 |
-
continue
|
539 |
-
assert getattr(saved_model_opt, checkme) == getattr(
|
540 |
-
opt, checkme), "Command line argument and saved model disagree on '%s' " % checkme
|
541 |
-
# elif strict:
|
542 |
-
# raise KeyError
|
543 |
-
self.model.load_state_dict(state_dict, strict)
|
544 |
-
|
545 |
-
|
546 |
-
class OnEpochStartCallback(pl.Callback):
|
547 |
-
|
548 |
-
def on_epoch_start(self, trainer, pl_module):
|
549 |
-
# Update lr/training stage/scheduled sampling prob etc.
|
550 |
-
opt = pl_module.opt
|
551 |
-
model = pl_module.model
|
552 |
-
epoch = trainer.current_epoch
|
553 |
-
optimizer = trainer.optimizers[0]
|
554 |
-
|
555 |
-
if not opt.noamopt and not opt.reduce_on_plateau:
|
556 |
-
# Assign the learning rate
|
557 |
-
if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
|
558 |
-
frac = (
|
559 |
-
epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
|
560 |
-
decay_factor = opt.learning_rate_decay_rate ** frac
|
561 |
-
opt.current_lr = opt.learning_rate * decay_factor
|
562 |
-
else:
|
563 |
-
opt.current_lr = opt.learning_rate
|
564 |
-
utils.set_lr(optimizer, opt.current_lr) # set the decayed rate
|
565 |
-
# Assign the scheduled sampling prob
|
566 |
-
if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
|
567 |
-
frac = (
|
568 |
-
epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
|
569 |
-
opt.ss_prob = min(opt.scheduled_sampling_increase_prob *
|
570 |
-
frac, opt.scheduled_sampling_max_prob)
|
571 |
-
model.ss_prob = opt.ss_prob
|
572 |
-
|
573 |
-
# If start self critical training
|
574 |
-
if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
|
575 |
-
sc_flag = True
|
576 |
-
init_scorer(opt.cached_tokens)
|
577 |
-
else:
|
578 |
-
sc_flag = False
|
579 |
-
|
580 |
-
# If start structure loss training
|
581 |
-
if opt.structure_after != -1 and epoch >= opt.structure_after:
|
582 |
-
struc_flag = True
|
583 |
-
init_scorer(opt.cached_tokens)
|
584 |
-
else:
|
585 |
-
struc_flag = False
|
586 |
-
|
587 |
-
pl_module.struc_flag = struc_flag
|
588 |
-
pl_module.sc_flag = sc_flag
|
589 |
-
|
590 |
-
|
591 |
-
class ModelCheckpoint(pl.callbacks.ModelCheckpoint):
|
592 |
-
|
593 |
-
def on_keyboard_interrupt(self, trainer, pl_module):
|
594 |
-
# Save model when keyboard interrupt
|
595 |
-
filepath = os.path.join(self.dirpath, self.prefix + 'interrupt.ckpt')
|
596 |
-
self._save_model(filepath)
|
597 |
-
|
598 |
-
|
599 |
-
opt = opts.parse_opt()
|
600 |
-
|
601 |
-
checkpoint_callback = ModelCheckpoint(
|
602 |
-
filepath=opt.checkpoint_path,
|
603 |
-
# dirpath=opt.checkpoint_path,
|
604 |
-
save_last=True,
|
605 |
-
save_top_k=1,
|
606 |
-
verbose=True,
|
607 |
-
# monitor='to_monitor',
|
608 |
-
# monitor='val/to_monitor',
|
609 |
-
monitor='val/CIDEr',
|
610 |
-
mode='max',
|
611 |
-
# prefix=opt.id+'_',
|
612 |
-
prefix=opt.id,
|
613 |
-
# filename=f'{opt.id}_',
|
614 |
-
)
|
615 |
-
|
616 |
-
verbose = True
|
617 |
-
# import torch
|
618 |
-
# if torch.cuda.current_device() in [0, -1]:
|
619 |
-
if 'LOCAL_RANK' in os.environ and os.environ['LOCAL_RANK'] != '0':
|
620 |
-
verbose = False
|
621 |
-
|
622 |
-
if verbose:
|
623 |
-
print(opt)
|
624 |
-
print("""
|
625 |
-
val_image_use,
|
626 |
-
save_checkpoint_very
|
627 |
-
save_every_epoch,
|
628 |
-
save_history-ckpt will be ignored.
|
629 |
-
""")
|
630 |
-
|
631 |
-
# Lightning defines batch size as batch size per gpu
|
632 |
-
assert opt.batch_size % torch.cuda.device_count() == 0
|
633 |
-
opt.batch_size = opt.batch_size // torch.cuda.device_count()
|
634 |
-
|
635 |
-
# If resume from last checkpoint
|
636 |
-
# if opt.start_from is not None and os.path.isfile(os.path.join(opt.start_from, f'{opt.id}_last.ckpt')):
|
637 |
-
# resume_from = os.path.join(opt.start_from, f'{opt.id}_last.ckpt')
|
638 |
-
if opt.start_from is not None:
|
639 |
-
resume_from = os.path.join(opt.start_from, f'{opt.id}-last.ckpt')
|
640 |
-
if os.path.isfile(resume_from):
|
641 |
-
if verbose:
|
642 |
-
print('Loading checkpoint from', resume_from)
|
643 |
-
else:
|
644 |
-
print("Checkpoint not found:", resume_from)
|
645 |
-
resume_from = None
|
646 |
-
else:
|
647 |
-
resume_from = None
|
648 |
-
|
649 |
-
from pytorch_lightning.loggers import WandbLogger
|
650 |
-
wandb_logger = WandbLogger(
|
651 |
-
project='CLIP-ViL-COCOCaption',
|
652 |
-
name=opt.id,
|
653 |
-
)
|
654 |
-
|
655 |
-
if verbose:
|
656 |
-
wandb_logger.experiment.config.update(opt)
|
657 |
-
from pathlib import Path
|
658 |
-
import glob
|
659 |
-
import wandb
|
660 |
-
# src_dir = Path(__file__).resolve().parent.parent
|
661 |
-
glob_str = "**/*.py"
|
662 |
-
base_path = './'
|
663 |
-
wandb.save(glob_str=glob_str, base_path=base_path)
|
664 |
-
|
665 |
-
# code = wandb.Artifact('project-source', type='code')
|
666 |
-
# for path in glob.glob('**/*.py', recursive=True):
|
667 |
-
# code.add_file(path, name='source/'+path)
|
668 |
-
# print(path)
|
669 |
-
# wandb.run.use_artifact(code)
|
670 |
-
|
671 |
-
|
672 |
-
|
673 |
-
|
674 |
-
lit = LitModel(opt)
|
675 |
-
# warning grad_clip_mode is ignored.
|
676 |
-
trainer = pl.Trainer(
|
677 |
-
callbacks=[
|
678 |
-
OnEpochStartCallback(),
|
679 |
-
# pl.callbacks.lr_logger.LearningRateLogger()
|
680 |
-
pl.callbacks.LearningRateMonitor()
|
681 |
-
],
|
682 |
-
default_root_dir=opt.checkpoint_path,
|
683 |
-
resume_from_checkpoint=resume_from,
|
684 |
-
distributed_backend='ddp',
|
685 |
-
check_val_every_n_epoch=1,
|
686 |
-
max_epochs=opt.max_epochs,
|
687 |
-
gradient_clip_val=opt.grad_clip_value,
|
688 |
-
gpus=torch.cuda.device_count(),
|
689 |
-
checkpoint_callback=checkpoint_callback,
|
690 |
-
log_gpu_memory='min_max',
|
691 |
-
# log_save_interval=opt.losses_log_every,
|
692 |
-
log_every_n_steps=opt.losses_log_every,
|
693 |
-
profiler=True,
|
694 |
-
# profiler='simple',
|
695 |
-
# row_log_interval=10, # what is it?
|
696 |
-
flush_logs_every_n_steps=10,
|
697 |
-
num_sanity_val_steps=0,
|
698 |
-
# val_check_interval=0.01,
|
699 |
-
# limit_train_batches=500,
|
700 |
-
# progress_bar_refresh_rate=0,
|
701 |
-
# fast_dev_run=True,
|
702 |
-
precision=opt.precision,
|
703 |
-
logger=wandb_logger
|
704 |
-
)
|
705 |
-
|
706 |
-
if os.getenv('EVALUATE', '0') == '1':
|
707 |
-
trainer.test(lit)
|
708 |
-
else:
|
709 |
-
trainer.fit(lit)