Praveen Malla committed
Commit fd687ca · 1 Parent(s): c680bdc

changing to cpu

Browse files: got_vision_b.py (+52 -32)
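The only functional change in this commit is the new module-level call torch.set_default_device("cpu"); every other hunk just re-wraps long lines or removes blank lines. With a CPU default device (available in PyTorch 2.0 and later), tensor factory functions and parameters of modules constructed afterwards are allocated on the CPU unless an explicit device is passed. A minimal sketch of that effect (illustrative only, not part of the diff):

import torch

torch.set_default_device("cpu")

# Factory calls now allocate on the CPU unless a device argument is given.
x = torch.zeros(2, 3)
print(x.device)  # cpu

# Parameters of modules constructed after the call are also created on the CPU.
lin = torch.nn.Linear(4, 8)
print(lin.weight.device)  # cpu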
got_vision_b.py
CHANGED
@@ -5,6 +5,7 @@ from functools import partial
 import torch.nn as nn
 from typing import Type
 
+torch.set_default_device("cpu")
 
 
 class MLPBlock(nn.Module):
@@ -23,7 +24,6 @@ class MLPBlock(nn.Module):
         return self.lin2(self.act(self.lin1(x)))
 
 
-
 class LayerNorm2d(nn.Module):
     def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
         super().__init__()
@@ -39,7 +39,6 @@ class LayerNorm2d(nn.Module):
         return x
 
 
-
 class ImageEncoderViT(nn.Module):
     def __init__(
         self,
@@ -92,7 +91,9 @@ class ImageEncoderViT(nn.Module):
         if use_abs_pos:
             # Initialize absolute positional embedding with pretrain image size.
             self.pos_embed = nn.Parameter(
-                torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
+                torch.zeros(
+                    1, img_size // patch_size, img_size // patch_size, embed_dim
+                )
             )
 
         self.blocks = nn.ModuleList()
@@ -129,9 +130,10 @@ class ImageEncoderViT(nn.Module):
             LayerNorm2d(out_chans),
         )
 
-
         self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)
-        self.net_3 = nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1, bias=False)
+        self.net_3 = nn.Conv2d(
+            512, 1024, kernel_size=3, stride=2, padding=1, bias=False
+        )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.patch_embed(x)
@@ -145,7 +147,6 @@ class ImageEncoderViT(nn.Module):
         x = self.net_2(x)
         x = self.net_3(x)
 
-
         return x
 
 
@@ -192,7 +193,9 @@ class Block(nn.Module):
         )
 
         self.norm2 = norm_layer(dim)
-        self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
+        self.mlp = MLPBlock(
+            embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer
+        )
 
         self.window_size = window_size
 
@@ -257,23 +260,34 @@ class Attention(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, H, W, _ = x.shape
         # qkv with shape (3, B, nHead, H * W, C)
-        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        qkv = (
+            self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        )
         # q, k, v with shape (B * nHead, H * W, C)
         q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
 
         attn = (q * self.scale) @ k.transpose(-2, -1)
 
         if self.use_rel_pos:
-            attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
+            attn = add_decomposed_rel_pos(
+                attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)
+            )
 
         attn = attn.softmax(dim=-1)
-        x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
+        x = (
+            (attn @ v)
+            .view(B, self.num_heads, H, W, -1)
+            .permute(0, 2, 3, 1, 4)
+            .reshape(B, H, W, -1)
+        )
         x = self.proj(x)
 
         return x
 
 
-def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
+def window_partition(
+    x: torch.Tensor, window_size: int
+) -> Tuple[torch.Tensor, Tuple[int, int]]:
     """
     Partition into non-overlapping windows with padding if needed.
     Args:
@@ -293,12 +307,17 @@ def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
     Hp, Wp = H + pad_h, W + pad_w
 
     x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
-    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+    windows = (
+        x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+    )
     return windows, (Hp, Wp)
 
 
 def window_unpartition(
-    windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
+    windows: torch.Tensor,
+    window_size: int,
+    pad_hw: Tuple[int, int],
+    hw: Tuple[int, int],
 ) -> torch.Tensor:
     """
     Window unpartition into original sequences and removing padding.
@@ -314,7 +333,9 @@ def window_unpartition(
     Hp, Wp = pad_hw
     H, W = hw
     B = windows.shape[0] // (Hp * Wp // window_size // window_size)
-    x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
+    x = windows.view(
+        B, Hp // window_size, Wp // window_size, window_size, window_size, -1
+    )
     x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
 
     if Hp > H or Wp > W:
@@ -386,7 +407,9 @@ def add_decomposed_rel_pos(
     rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
 
     attn = (
-        attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
+        attn.view(B, q_h, q_w, k_h, k_w)
+        + rel_h[:, :, :, :, None]
+        + rel_w[:, :, :, None, :]
     ).view(B, q_h * q_w, k_h * k_w)
 
     return attn
@@ -426,7 +449,6 @@ class PatchEmbed(nn.Module):
         return x
 
 
-
 def build_GOT_vit_b(checkpoint=None):
     return _build_GOT_vision(
         encoder_embed_dim=768,
@@ -448,21 +470,19 @@ def _build_GOT_vision(
     image_size = 1024
     vit_patch_size = 16
     image_embedding_size = image_size // vit_patch_size
-    image_encoder=ImageEncoderViT(
-            depth=encoder_depth,
-            embed_dim=encoder_embed_dim,
-            img_size=image_size,
-            mlp_ratio=4,
-            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
-            num_heads=encoder_num_heads,
-            patch_size=vit_patch_size,
-            qkv_bias=True,
-            use_rel_pos=True,
-            global_attn_indexes=encoder_global_attn_indexes,
-            window_size=14,
-            out_chans=prompt_embed_dim,
-        )
-
+    image_encoder = ImageEncoderViT(
+        depth=encoder_depth,
+        embed_dim=encoder_embed_dim,
+        img_size=image_size,
+        mlp_ratio=4,
+        norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
+        num_heads=encoder_num_heads,
+        patch_size=vit_patch_size,
+        qkv_bias=True,
+        use_rel_pos=True,
+        global_attn_indexes=encoder_global_attn_indexes,
+        window_size=14,
+        out_chans=prompt_embed_dim,
+    )
 
     return image_encoder
-
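A hypothetical smoke test of the patched encoder, assuming got_vision_b.py is on the import path and PyTorch 2.0 or newer is installed (the variable names below are illustrative and not part of the commit):

import torch
from got_vision_b import build_GOT_vit_b  # the module changed by this commit

encoder = build_GOT_vit_b()             # parameters are created on the CPU default device
dummy = torch.zeros(1, 3, 1024, 1024)   # image_size hard-coded in _build_GOT_vision
with torch.no_grad():
    feats = encoder(dummy)
print(feats.device)  # expected: cpu
print(feats.shape)   # with this config the output should be roughly (1, 1024, 16, 16)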