juansensio committed on
Commit c5d0550 · verified · 1 Parent(s): 5159f3a

Upload 5 files

models/__pycache__/backbone.cpython-39.pyc ADDED
Binary file (30 kB).
 
models/__pycache__/dpt_head.cpython-39.pyc ADDED
Binary file (15.6 kB).
 
models/backbone.py ADDED
@@ -0,0 +1,957 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ from torch import nn
8
+ import torchvision
9
+ from torch.nn.modules.batchnorm import _BatchNorm
10
+ from torch.nn.modules.utils import _pair as to_2tuple
11
+ import math
12
+ import warnings
13
+ from collections import OrderedDict
14
+ from torch import Tensor
15
+
16
+ import torch.nn.functional as F
17
+ from typing import Callable, Optional, Tuple, Union
18
+ from functools import partial
19
+ import pdb
+ import random
+ import logging
+ import numpy as np
+
+ # DropPath and trunc_normal_ are used below but were not imported in the uploaded
+ # file; they are assumed to come from timm, which this code appears to build on.
+ from timm.models.layers import DropPath, trunc_normal_
+
+ logger = logging.getLogger(__name__)
+
+ # Note: a few helpers referenced only on non-default code paths (e.g. SwiGLUFFN,
+ # drop_add_residual_stochastic_depth, get_2d_sincos_pos_embed_cached_device,
+ # named_apply, get_init_weights_vit, _load_weights) are still undefined here and
+ # come from the DINOv2/timm code this file was adapted from.
20
+
21
+ class MaskingGenerator:
22
+ def __init__(
23
+ self,
24
+ input_size,
25
+ num_masking_patches=None,
26
+ min_num_patches=4,
27
+ max_num_patches=None,
28
+ min_aspect=0.3,
29
+ max_aspect=None,
30
+ ):
31
+ if not isinstance(input_size, tuple):
32
+ input_size = (input_size,) * 2
33
+ self.height, self.width = input_size
34
+
35
+ self.num_patches = self.height * self.width
36
+ self.num_masking_patches = num_masking_patches
37
+
38
+ self.min_num_patches = min_num_patches
39
+ self.max_num_patches = num_masking_patches if max_num_patches is None else max_num_patches
40
+
41
+ max_aspect = max_aspect or 1 / min_aspect
42
+ self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
43
+
44
+ def __repr__(self):
45
+ repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
46
+ self.height,
47
+ self.width,
48
+ self.min_num_patches,
49
+ self.max_num_patches,
50
+ self.num_masking_patches,
51
+ self.log_aspect_ratio[0],
52
+ self.log_aspect_ratio[1],
53
+ )
54
+ return repr_str
55
+
56
+ def get_shape(self):
57
+ return self.height, self.width
58
+
59
+ def _mask(self, mask, max_mask_patches):
60
+ delta = 0
61
+ for attempt in range(10):
62
+ target_area = random.uniform(self.min_num_patches, max_mask_patches)
63
+ aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
64
+ h = int(round(math.sqrt(target_area * aspect_ratio)))
65
+ w = int(round(math.sqrt(target_area / aspect_ratio)))
66
+ if w < self.width and h < self.height:
67
+ top = random.randint(0, self.height - h)
68
+ left = random.randint(0, self.width - w)
69
+
70
+ num_masked = mask[top : top + h, left : left + w].sum()
71
+ # Overlap
72
+ if 0 < h * w - num_masked <= max_mask_patches:
73
+ for i in range(top, top + h):
74
+ for j in range(left, left + w):
75
+ if mask[i, j] == 0:
76
+ mask[i, j] = 1
77
+ delta += 1
78
+
79
+ if delta > 0:
80
+ break
81
+ return delta
82
+
83
+ def __call__(self, num_masking_patches=0):
84
+ mask = np.zeros(shape=self.get_shape(), dtype=bool)  # np.bool was removed in NumPy >= 1.24
85
+ mask_count = 0
86
+ while mask_count < num_masking_patches:
87
+ max_mask_patches = num_masking_patches - mask_count
88
+ max_mask_patches = min(max_mask_patches, self.max_num_patches)
89
+
90
+ delta = self._mask(mask, max_mask_patches)
91
+ if delta == 0:
92
+ break
93
+ else:
94
+ mask_count += delta
95
+
96
+ return mask
97
+
98
+
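A minimal usage sketch for the block-wise mask generator above (illustration only, not part of the uploaded file); it assumes `random` and `numpy` (as `np`) are importable and asks for roughly 30% of a 14x14 patch grid to be masked.

# Illustration only: request ~30% of a 14x14 patch grid as rectangular blocks.
gen = MaskingGenerator(input_size=14, max_num_patches=0.5 * 14 * 14)
mask = gen(num_masking_patches=int(0.3 * 14 * 14))
print(mask.shape, mask.sum())  # (14, 14) boolean array, roughly 58 masked patches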
99
+ def resize(input,
100
+ size=None,
101
+ scale_factor=None,
102
+ mode='nearest',
103
+ align_corners=None,
104
+ warning=False):
105
+ if warning:
106
+ if size is not None and align_corners:
107
+ input_h, input_w = tuple(int(x) for x in input.shape[2:])
108
+ output_h, output_w = tuple(int(x) for x in size)
109
+ if output_h > input_h or output_w > input_w:
110
+ if ((output_h > 1 and output_w > 1 and input_h > 1
111
+ and input_w > 1) and (output_h - 1) % (input_h - 1)
112
+ and (output_w - 1) % (input_w - 1)):
113
+ warnings.warn(
114
+ f'When align_corners={align_corners}, '
115
+ 'the output would be more aligned if '
116
+ f'input size {(input_h, input_w)} is `x+1` and '
117
+ f'out size {(output_h, output_w)} is `nx+1`')
118
+
119
+ return F.interpolate(input, size, scale_factor, mode, align_corners)
120
+
121
+
122
+ class Mlp(nn.Module):
123
+ def __init__(
124
+ self,
125
+ in_features: int,
126
+ hidden_features: Optional[int] = None,
127
+ out_features: Optional[int] = None,
128
+ act_layer: Callable[..., nn.Module] = nn.GELU,
129
+ drop: float = 0.0,
130
+ ) -> None:
131
+ super().__init__()
132
+ out_features = out_features or in_features
133
+ hidden_features = hidden_features or in_features
134
+ self.fc1 = nn.Linear(in_features, hidden_features)
135
+ self.act = act_layer()
136
+ self.fc2 = nn.Linear(hidden_features, out_features)
137
+ self.drop = nn.Dropout(drop)
138
+
139
+ def forward(self, x: Tensor) -> Tensor:
140
+ x = self.fc1(x)
141
+ x = self.act(x)
142
+ x = self.drop(x)
143
+ x = self.fc2(x)
144
+ x = self.drop(x)
145
+ return x
146
+
147
+
148
+ class Attention(nn.Module):
149
+ def __init__(
150
+ self,
151
+ dim: int,
152
+ num_heads: int = 8,
153
+ qkv_bias: bool = False,
154
+ attn_drop: float = 0.0,
155
+ proj_drop: float = 0.0,
156
+ ) -> None:
157
+ super().__init__()
158
+ self.num_heads = num_heads
159
+ head_dim = dim // num_heads
160
+ self.scale = head_dim**-0.5
161
+
162
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
163
+ self.attn_drop = nn.Dropout(attn_drop)
164
+ self.proj = nn.Linear(dim, dim)
165
+ self.proj_drop = nn.Dropout(proj_drop)
166
+
167
+ def forward(self, x: Tensor) -> Tensor:
168
+ B, N, C = x.shape
169
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
170
+
171
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
172
+ attn = q @ k.transpose(-2, -1)
173
+
174
+ attn = attn.softmax(dim=-1)
175
+ attn = self.attn_drop(attn)
176
+
177
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
178
+ x = self.proj(x)
179
+ x = self.proj_drop(x)
180
+ return x
181
+
182
+
183
+
184
+ class LayerScale(nn.Module):
185
+ def __init__(
186
+ self,
187
+ dim: int,
188
+ init_values: Union[float, Tensor] = 1e-5,
189
+ inplace: bool = False,
190
+ ) -> None:
191
+ super().__init__()
192
+ self.inplace = inplace
193
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
194
+
195
+ def forward(self, x: Tensor) -> Tensor:
196
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
197
+
198
+
199
+ class Block(nn.Module):
200
+ def __init__(
201
+ self,
202
+ dim: int,
203
+ num_heads: int,
204
+ mlp_ratio: float = 4.0,
205
+ qkv_bias: bool = False,
206
+ drop: float = 0.0,
207
+ attn_drop: float = 0.0,
208
+ init_values=None,
209
+ drop_path: float = 0.0,
210
+ act_layer: Callable[..., nn.Module] = nn.GELU,
211
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
212
+ attn_class: Callable[..., nn.Module] = Attention,
213
+ ffn_layer: Callable[..., nn.Module] = Mlp,
214
+ ) -> None:
215
+ super().__init__()
216
+ self.norm1 = norm_layer(dim)
217
+ self.attn = attn_class(
218
+ dim,
219
+ num_heads=num_heads,
220
+ qkv_bias=qkv_bias,
221
+ attn_drop=attn_drop,
222
+ proj_drop=drop,
223
+ )
224
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
225
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
226
+
227
+ self.norm2 = norm_layer(dim)
228
+ mlp_hidden_dim = int(dim * mlp_ratio)
229
+ self.mlp = ffn_layer(
230
+ in_features=dim,
231
+ hidden_features=mlp_hidden_dim,
232
+ act_layer=act_layer,
233
+ drop=drop,
234
+ )
235
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
236
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
237
+
238
+ self.sample_drop_ratio = drop_path
239
+
240
+ def forward(self, x: Tensor) -> Tensor:
241
+ #pdb.set_trace()
242
+ def attn_residual_func(x: Tensor) -> Tensor:
243
+ return self.ls1(self.attn(self.norm1(x)))
244
+
245
+ def ffn_residual_func(x: Tensor) -> Tensor:
246
+ return self.ls2(self.mlp(self.norm2(x)))
247
+
248
+ if self.training and self.sample_drop_ratio > 0.1:
249
+ x = drop_add_residual_stochastic_depth(
250
+ x,
251
+ residual_func=attn_residual_func,
252
+ sample_drop_ratio=self.sample_drop_ratio,
253
+ )
254
+ x = drop_add_residual_stochastic_depth(
255
+ x,
256
+ residual_func=ffn_residual_func,
257
+ sample_drop_ratio=self.sample_drop_ratio,
258
+ )
259
+ elif self.training and self.sample_drop_ratio > 0.0:
260
+ x = x + self.drop_path1(attn_residual_func(x))
261
+ x = x + self.drop_path1(ffn_residual_func(x))
262
+ else:
263
+ x = x + attn_residual_func(x)
264
+ x = x + ffn_residual_func(x)
265
+ return x
266
+
267
+
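A quick shape check for the pre-norm transformer block above (a sketch, not part of the uploaded file): token tensors of shape (B, N, C) pass through with their shape unchanged.

# Illustration only: one Block preserves the (B, N, C) token shape.
blk = Block(dim=768, num_heads=12, init_values=1e-5)
tokens = torch.randn(2, 197, 768)
assert blk(tokens).shape == (2, 197, 768)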
268
+ def make_2tuple(x):
269
+ if isinstance(x, tuple):
270
+ assert len(x) == 2
271
+ return x
272
+
273
+ assert isinstance(x, int)
274
+ return (x, x)
275
+
276
+
277
+ class PatchEmbed(nn.Module):
278
+ """
279
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
280
+
281
+ Args:
282
+ img_size: Image size.
283
+ patch_size: Patch token size.
284
+ in_chans: Number of input image channels.
285
+ embed_dim: Number of linear projection output channels.
286
+ norm_layer: Normalization layer.
287
+ """
288
+
289
+ def __init__(
290
+ self,
291
+ img_size: Union[int, Tuple[int, int]] = 224,
292
+ patch_size: Union[int, Tuple[int, int]] = 16,
293
+ in_chans: int = 3,
294
+ embed_dim: int = 768,
295
+ norm_layer: Optional[Callable] = None,
296
+ ) -> None:
297
+ super().__init__()
298
+
299
+ image_HW = make_2tuple(img_size)
300
+ patch_HW = make_2tuple(patch_size)
301
+ patch_grid_size = (
302
+ image_HW[0] // patch_HW[0],
303
+ image_HW[1] // patch_HW[1],
304
+ )
305
+
306
+ self.img_size = image_HW
307
+ self.patch_size = patch_HW
308
+ self.patches_resolution = patch_grid_size
309
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
310
+
311
+ self.in_chans = in_chans
312
+ self.embed_dim = embed_dim
313
+
314
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
315
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
316
+
317
+
318
+ def forward(self, x: Tensor) -> Tensor:
319
+ _, _, H, W = x.shape
320
+ patch_H, patch_W = self.patch_size
321
+
322
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
323
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
324
+
325
+ x = self.proj(x)
326
+ x = x.flatten(2).transpose(1, 2)
327
+ x = self.norm(x)
328
+ return x
329
+
330
+ def flops(self) -> float:
331
+ Ho, Wo = self.patches_resolution
332
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
333
+ if self.norm is not None:
334
+ flops += Ho * Wo * self.embed_dim
335
+ return flops
336
+
337
+
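A shape sanity check for the patch embedding above (sketch only, not part of the uploaded file): a 224x224 RGB image with 16x16 patches yields 14x14 = 196 tokens of dimension 768.

# Illustration only: (B, C, H, W) -> (B, N, D) with N = (224 // 16) ** 2 = 196.
pe = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
tokens = pe(torch.randn(2, 3, 224, 224))
assert tokens.shape == (2, 196, 768)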
338
+ class DinoVisionTransformer(nn.Module):
339
+ """Vision Transformer
340
+
341
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
342
+ - https://arxiv.org/abs/2010.11929
343
+ """
344
+
345
+ def __init__(
346
+ self,
347
+ img_size=224,
348
+ patch_size=16,
349
+ in_chans=3,
350
+ num_classes=0,
351
+ global_pool="token",
352
+ embed_dim=1024,
353
+ depth=24,
354
+ num_heads=16,
355
+ mlp_ratio=4.0,
356
+ qkv_bias=True,
357
+ representation_size=None,
358
+ drop_rate=0.0,
359
+ attn_drop_rate=0.0,
360
+ drop_path_rate=0.0,
361
+ weight_init="",
362
+ init_values=1.,
363
+ embed_layer=PatchEmbed,
364
+ norm_layer=None,
365
+ act_layer=None,
366
+ block_fn=Block,
367
+ ffn_layer="mlp",
368
+ drop_path_uniform=False,
369
+ patch_drop=0.0,
370
+ sin_cos_embeddings=False,
371
+ local_crops_size=96,
372
+ multiple_pos_embeddings=False,
373
+ ):
374
+ """
375
+ Args:
376
+ img_size (int, tuple): input image size
377
+ patch_size (int, tuple): patch size
378
+ in_chans (int): number of input channels
379
+ num_classes (int): number of classes for classification head
380
+ global_pool (str): type of global pooling for final sequence (default: 'token')
381
+ embed_dim (int): embedding dimension
382
+ depth (int): depth of transformer
383
+ num_heads (int): number of attention heads
384
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
385
+ qkv_bias (bool): enable bias for qkv if True
386
+ representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
387
+ drop_rate (float): dropout rate
388
+ attn_drop_rate (float): attention dropout rate
389
+ drop_path_rate (float): stochastic depth rate
390
+ weight_init: (str): weight init scheme
391
+ init_values: (float): layer-scale init values
392
+ embed_layer (nn.Module): patch embedding layer
393
+ norm_layer: (nn.Module): normalization layer
394
+ act_layer: (nn.Module): MLP activation layer
395
+ """
396
+ super().__init__()
397
+ assert global_pool in ("", "avg", "token")
398
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
399
+ act_layer = act_layer or nn.GELU
400
+
401
+ self.num_classes = num_classes
402
+ self.global_pool = global_pool
403
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
404
+ self.num_tokens = 1
405
+ self.grad_checkpointing = False
406
+ self.sin_cos_embeddings = sin_cos_embeddings
407
+ self.multiple_pos_embeddings = multiple_pos_embeddings
408
+
409
+ self.patch_embed = embed_layer(
410
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim
411
+ )
412
+ num_patches = self.patch_embed.num_patches
413
+
414
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
415
+ if self.sin_cos_embeddings:
416
+ self.pos_embed = torch.Tensor(())
417
+ logger.info("using sin-cos fixed embeddings")
418
+ pass
419
+ elif self.multiple_pos_embeddings:
420
+ logger.info("using multiple position embeddings (one for global one for local)")
421
+ self.pos_embeds = nn.ParameterDict()
422
+ self.pos_embeds[str(img_size)] = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
423
+ n_local_patches = (local_crops_size // patch_size) ** 2
424
+ self.pos_embeds[str(local_crops_size)] = nn.Parameter(torch.zeros(1, n_local_patches, embed_dim))
425
+ self.pos_embed = None
426
+ else:
427
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
428
+ self.pos_drop = nn.Dropout(p=drop_rate)
429
+
430
+ if drop_path_uniform is True:
431
+ dpr = [drop_path_rate] * depth
432
+ else:
433
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
434
+
435
+ if ffn_layer == "mlp":
436
+ #print("using MLP layer as FFN")
437
+ ffn_layer = Mlp
438
+ elif ffn_layer == "swiglu":
439
+ #print("using SwiGLU layer as FFN")
440
+ ffn_layer = SwiGLUFFN
441
+ elif ffn_layer == "identity":
442
+ #print("using Identity layer as FFN")
443
+ def f(*args, **kwargs):
444
+ return nn.Identity()
445
+ ffn_layer = f
446
+ else:
447
+ raise NotImplementedError
448
+
449
+ self.blocks = nn.ModuleList(
450
+ [
451
+ block_fn(
452
+ dim=embed_dim,
453
+ num_heads=num_heads,
454
+ mlp_ratio=mlp_ratio,
455
+ qkv_bias=qkv_bias,
456
+ drop=drop_rate,
457
+ attn_drop=attn_drop_rate,
458
+ drop_path=dpr[i],
459
+ norm_layer=norm_layer,
460
+ act_layer=act_layer,
461
+ ffn_layer=ffn_layer,
462
+ init_values=init_values,
463
+ )
464
+ for i in range(depth)
465
+ ]
466
+ )
467
+
468
+ use_fc_norm = self.global_pool == "avg"
469
+ self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
470
+
471
+ # Representation layer. Used for original ViT models w/ in21k pretraining.
472
+ self.representation_size = representation_size
473
+ self.pre_logits = nn.Identity()
474
+ if representation_size:
475
+ self._reset_representation(representation_size)
476
+
477
+ # Classifier Head
478
+ self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
479
+ final_chs = self.representation_size if self.representation_size else self.embed_dim
480
+ self.head = nn.Linear(final_chs, num_classes) if num_classes > 0 else nn.Identity()
481
+
482
+ self.mask_generator = MaskingGenerator(
483
+ input_size=(img_size // patch_size, img_size // patch_size),
484
+ max_num_patches=0.5 * img_size // patch_size * img_size // patch_size,
485
+ )
486
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
487
+
488
+ # if weight_init != "skip":
489
+ # self.init_weights(weight_init)
490
+
491
+ def _reset_representation(self, representation_size):
492
+ self.representation_size = representation_size
493
+ if self.representation_size:
494
+ self.pre_logits = nn.Sequential(
495
+ OrderedDict([("fc", nn.Linear(self.embed_dim, self.representation_size)), ("act", nn.Tanh())])
496
+ )
497
+ else:
498
+ self.pre_logits = nn.Identity()
499
+
500
+ def init_weights(self, mode=""):
501
+ assert mode in ("jax", "jax_nlhb", "moco", "")
502
+ head_bias = -math.log(self.num_classes) if "nlhb" in mode else 0.0
503
+ if self.pos_embed is not None:
504
+ trunc_normal_(self.pos_embed, std=0.02)
505
+ elif self.pos_embeds:
506
+ for v in self.pos_embeds.values():
507
+ trunc_normal_(v, std=0.02)
508
+ nn.init.normal_(self.cls_token, std=1e-6)
509
+ named_apply(get_init_weights_vit(mode, head_bias), self)
510
+
511
+ def _init_weights(self, m):
512
+ # this fn left here for compat with downstream users
513
+ init_weights_vit_timm(m)
514
+
515
+ @torch.jit.ignore()
516
+ def load_pretrained(self, checkpoint_path, prefix=""):
517
+ _load_weights(self, checkpoint_path, prefix)
518
+
519
+ @torch.jit.ignore
520
+ def no_weight_decay(self):
521
+ return {"pos_embed", "cls_token", "dist_token"}
522
+
523
+ @torch.jit.ignore
524
+ def group_matcher(self, coarse=False):
525
+ return dict(
526
+ stem=r"^cls_token|pos_embed|patch_embed", # stem and embed
527
+ blocks=[(r"^blocks\.(\d+)", None), (r"^norm", (99999,))],
528
+ )
529
+
530
+ @torch.jit.ignore
531
+ def set_grad_checkpointing(self, enable=True):
532
+ self.grad_checkpointing = enable
533
+
534
+ @torch.jit.ignore
535
+ def get_classifier(self):
536
+ return self.head
537
+
538
+ def reset_classifier(self, num_classes: int, global_pool=None, representation_size=None):
539
+ self.num_classes = num_classes
540
+ if global_pool is not None:
541
+ assert global_pool in ("", "avg", "token")
542
+ self.global_pool = global_pool
543
+ if representation_size is not None:
544
+ self._reset_representation(representation_size)
545
+ final_chs = self.representation_size if self.representation_size else self.embed_dim
546
+ self.head = nn.Linear(final_chs, num_classes) if num_classes > 0 else nn.Identity()
547
+
548
+ def forward_head(self, x, pre_logits: bool = False):
549
+ if self.global_pool:
550
+ x = x[:, 1:].mean(dim=1) if self.global_pool == "avg" else x[:, 0]
551
+ x = self.fc_norm(x)
552
+ x = self.pre_logits(x)
553
+ return x if pre_logits else self.head(x)
554
+
555
+ def interpolate_pos_encoding(self, x, w, h):
556
+ if self.sin_cos_embeddings:
557
+
558
+ w0 = w // self.patch_embed.patch_size[0]
559
+ step_coef = (w0-1) / 3.14
560
+ omega_coef = 10000
561
+ sin_cos_embed = get_2d_sincos_pos_embed_cached_device(
562
+ embed_dim=x.shape[-1], grid_size=w0, step_coef=step_coef, omega_coef=omega_coef, device=x.device, cls_token=True
563
+ )
564
+
565
+ return sin_cos_embed
566
+ elif self.multiple_pos_embeddings:
567
+
568
+ _m = sum((v.mean() * 0 for v in self.pos_embeds.values()))
569
+ pos_embed = self.pos_embeds[str(w)] + _m
570
+ class_pos_embed = torch.zeros_like(pos_embed[:1,:1])
571
+ return torch.cat((class_pos_embed, pos_embed), dim=1)
572
+ else:
573
+ npatch = x.shape[1] - 1
574
+ N = self.pos_embed.shape[1] - 1
575
+ if npatch == N and w == h:
576
+ return self.pos_embed
577
+ class_pos_embed = self.pos_embed[:, 0]
578
+ patch_pos_embed = self.pos_embed[:, 1:]
579
+ dim = x.shape[-1]
580
+ w0 = w // self.patch_embed.patch_size[0]
581
+ h0 = h // self.patch_embed.patch_size[0]
582
+ # we add a small number to avoid floating point error in the interpolation
583
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
584
+ w0, h0 = w0 + 0.1, h0 + 0.1
585
+
586
+ patch_pos_embed = nn.functional.interpolate(
587
+ patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
588
+ scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
589
+ mode="bicubic", align_corners=True, recompute_scale_factor=True
590
+ )
591
+
592
+ assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
593
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
594
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
595
+
596
+ def mask_patches_with_probability_p(self, x, mask_ratio_tuple, p):
597
+ B, N, _ = x.shape
598
+ n_samples_masked = int(B * p)
599
+ mask_ratio_min, mask_ratio_max = mask_ratio_tuple
600
+ masks = torch.stack(
601
+ [
602
+ torch.BoolTensor(self.mask_generator(int(N * random.uniform(mask_ratio_min, mask_ratio_max))))
603
+ for _ in range(0, n_samples_masked)
604
+ ]
605
+ + [torch.BoolTensor(self.mask_generator(0)) for _ in range(n_samples_masked, B)]
606
+ ).to(
607
+ x.device
608
+ )
609
+ masks = masks[torch.randperm(B, device=x.device)].flatten(1)
610
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
611
+
612
+ return x, masks
613
+
614
+ def mask_patches_with_probability_p_upperbound(self, x, mask_ratio_tuple, p):
615
+ B, N, _ = x.shape
616
+ n_samples_masked = int(B * p)
617
+ probs = torch.linspace(*mask_ratio_tuple, n_samples_masked + 1)
618
+ upperbound = 0
619
+ masks_list = []
620
+ for i in range(0, n_samples_masked):
621
+ prob_min = probs[i]
622
+ prob_max = probs[i+1]
623
+ masks_list.append(torch.BoolTensor(self.mask_generator(int(N * random.uniform(prob_min, prob_max)))))
624
+ upperbound += int(N * prob_max)
625
+ for i in range(n_samples_masked, B):
626
+ masks_list.append(torch.BoolTensor(self.mask_generator(0)))
627
+ masks = torch.stack(masks_list).to(x.device)
628
+ masks = masks[torch.randperm(B, device=x.device)].flatten(1)
629
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
630
+
631
+ return x, masks, upperbound
632
+
633
+ def prepare_tokens(self, x, mask_ratio_tuple=(0.0, 0.0), mask_sample_probability=0.0, ibot_balanced_masking=False):
634
+ B, nc, w, h = x.shape
635
+ x = self.patch_embed(x)
636
+ masks = None
637
+ n_masked_patches_upperbound = None
638
+ cls_token = self.cls_token
639
+ do_ibot = max(mask_ratio_tuple) > 0.0 and mask_sample_probability > 0.0
640
+ if do_ibot:
641
+ if ibot_balanced_masking:
642
+ logger.debug("using balanced masking")
643
+ x, masks, n_masked_patches_upperbound = self.mask_patches_with_probability_p_upperbound(
644
+ x, mask_ratio_tuple=mask_ratio_tuple, p=mask_sample_probability
645
+ )
646
+ else:
647
+ logger.debug("not using balanced masking")
648
+ x, masks = self.mask_patches_with_probability_p(
649
+ x, mask_ratio_tuple=mask_ratio_tuple, p=mask_sample_probability
650
+ )
651
+ else:
652
+ cls_token = cls_token + 0 * self.mask_token # hack to use the mask_token param to not crash ddp...
653
+
654
+ x = torch.cat((cls_token.expand(x.shape[0], -1, -1), x), dim=1)
655
+ x = self.pos_drop(x + self.interpolate_pos_encoding(x, w, h))
656
+
657
+ return x, masks, n_masked_patches_upperbound
658
+
659
+ def forward_features(self, x, mask_ratio_tuple=(0.0, 0.0), mask_sample_probability=0.0, ibot_balanced_masking=False):
660
+ x, masks, n_masked_patches_upperbound = self.prepare_tokens(x, mask_ratio_tuple, mask_sample_probability, ibot_balanced_masking)
661
+
662
+ for blk in self.blocks:
663
+ x = blk(x)
664
+
665
+ x_norm = self.norm(x)
666
+ return {
667
+ "x_norm_clstoken": x_norm[:, 0],
668
+ "x_norm_patchtokens": x_norm[:, 1:],
669
+ "x_prenorm": x,
670
+ "masks": masks,
671
+ "n_masked_patches_upperbound": n_masked_patches_upperbound,
672
+ }
673
+
674
+ def get_intermediate_layers(self, x, n=1):
675
+ x, _, _ = self.prepare_tokens(x)
676
+ # we return the output tokens from the `n` last blocks
677
+ output = []
678
+ for i, blk in enumerate(self.blocks):
679
+ x = blk(x)
680
+ if len(self.blocks) - i <= n:
681
+ output.append(self.norm(x))
682
+ return output
683
+
684
+ def forward(self, *args, is_training=False, **kwargs):
685
+ ret = self.forward_features(*args, **kwargs)
686
+ if is_training:
687
+ return ret
688
+ else:
689
+ return ret["x_norm_clstoken"]
690
+
691
+
692
+
693
+ class AdaptivePadding(nn.Module):
694
+ """Applies padding to input (if needed) so that input can get fully covered
695
+ by filter you specified. It support two modes "same" and "corner". The
696
+ "same" mode is same with "SAME" padding mode in TensorFlow, pad zero around
697
+ input. The "corner" mode would pad zero to bottom right.
698
+ Args:
699
+ kernel_size (int | tuple): Size of the kernel:
700
+ stride (int | tuple): Stride of the filter. Default: 1:
701
+ dilation (int | tuple): Spacing between kernel elements.
702
+ Default: 1.
703
+ padding (str): Support "same" and "corner", "corner" mode
704
+ would pad zero to bottom right, and "same" mode would
705
+ pad zero around input. Default: "corner".
706
+ Example:
707
+ >>> kernel_size = 16
708
+ >>> stride = 16
709
+ >>> dilation = 1
710
+ >>> input = torch.rand(1, 1, 15, 17)
711
+ >>> adap_pad = AdaptivePadding(
712
+ >>> kernel_size=kernel_size,
713
+ >>> stride=stride,
714
+ >>> dilation=dilation,
715
+ >>> padding="corner")
716
+ >>> out = adap_pad(input)
717
+ >>> assert (out.shape[2], out.shape[3]) == (16, 32)
718
+ >>> input = torch.rand(1, 1, 16, 17)
719
+ >>> out = adap_pad(input)
720
+ >>> assert (out.shape[2], out.shape[3]) == (16, 32)
721
+ """
722
+
723
+ def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'):
724
+
725
+ super(AdaptivePadding, self).__init__()
726
+
727
+ assert padding in ('same', 'corner')
728
+
729
+ kernel_size = to_2tuple(kernel_size)
730
+ stride = to_2tuple(stride)
731
+ dilation = to_2tuple(dilation)
732
+
733
+ self.padding = padding
734
+ self.kernel_size = kernel_size
735
+ self.stride = stride
736
+ self.dilation = dilation
737
+
738
+ def get_pad_shape(self, input_shape):
739
+ input_h, input_w = input_shape
740
+ kernel_h, kernel_w = self.kernel_size
741
+ stride_h, stride_w = self.stride
742
+ output_h = math.ceil(input_h / stride_h)
743
+ output_w = math.ceil(input_w / stride_w)
744
+ pad_h = max((output_h - 1) * stride_h +
745
+ (kernel_h - 1) * self.dilation[0] + 1 - input_h, 0)
746
+ pad_w = max((output_w - 1) * stride_w +
747
+ (kernel_w - 1) * self.dilation[1] + 1 - input_w, 0)
748
+ return pad_h, pad_w
749
+
750
+ def forward(self, x):
751
+ pad_h, pad_w = self.get_pad_shape(x.size()[-2:])
752
+ if pad_h > 0 or pad_w > 0:
753
+ if self.padding == 'corner':
754
+ x = F.pad(x, [0, pad_w, 0, pad_h])
755
+ elif self.padding == 'same':
756
+ x = F.pad(x, [
757
+ pad_w // 2, pad_w - pad_w // 2, pad_h // 2,
758
+ pad_h - pad_h // 2
759
+ ])
760
+ return x
761
+
762
+
763
+
764
+ class SSLVisionTransformer(DinoVisionTransformer):
765
+ """Vision Transformer.
766
+ """
767
+
768
+ def __init__(self,
769
+ interpolate_mode='bicubic',
770
+ init_cfg=None,
771
+ pretrained=None,
772
+ img_size=224,
773
+ patch_size=16,
774
+ #embed_dim=1024,
775
+ #depth=24,
776
+ #num_heads=16,
777
+ mlp_ratio=4,
778
+ qkv_bias=True,
779
+ init_values=1.,
780
+ out_indices=(4, 11, 17, 23),
781
+ final_norm=False,
782
+ with_cls_token=True,
783
+ output_cls_token=True,
784
+ frozen_stages=100,
785
+ *args, **kwargs):
786
+ super(SSLVisionTransformer, self).__init__(*args, **kwargs)
787
+
788
+ if output_cls_token:
789
+ assert with_cls_token is True, f'with_cls_token must be True if ' \
790
+ f'set output_cls_token to True, but got {with_cls_token}'
791
+
792
+ assert not (init_cfg and pretrained), \
793
+ 'init_cfg and pretrained cannot be set at the same time'
794
+ if isinstance(pretrained, str):
795
+ warnings.warn('DeprecationWarning: pretrained is deprecated, '
796
+ 'please use "init_cfg" instead')
797
+ self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
798
+ elif pretrained is not None:
799
+ raise TypeError('pretrained must be a str or None')
800
+
801
+
802
+ if len(self.blocks)==1:
803
+ self.blocks = self.blocks[0]
804
+ if isinstance(out_indices, int):
805
+ if out_indices == -1:
806
+ out_indices = len(self.blocks) - 1
807
+ self.out_indices = [out_indices]
808
+ elif isinstance(out_indices, list) or isinstance(out_indices, tuple):
809
+ self.out_indices = out_indices
810
+ else:
811
+ raise TypeError('out_indices must be type of int, list or tuple')
812
+
813
+ self.interpolate_mode = interpolate_mode
814
+ self.pretrained = pretrained
815
+ self.frozen_stages = frozen_stages
816
+ self.detach = False
817
+ self.with_cls_token = with_cls_token
818
+ self.output_cls_token = output_cls_token
819
+ self.final_norm = final_norm
820
+ self.patch_size = self.patch_embed.patch_size
821
+ self.adapad = AdaptivePadding(kernel_size=self.patch_size, stride=self.patch_size, padding='same')
822
+ if pretrained:
823
+ self.init_weights(pretrained)
824
+
825
+ self._freeze_stages()
826
+
827
+ @staticmethod
828
+ def resize_pos_embed(pos_embed, input_shape, pos_shape, mode):
829
+ """Resize pos_embed weights.
830
+ Resize pos_embed using bicubic interpolate method.
831
+ Args:
832
+ pos_embed (torch.Tensor): Position embedding weights.
833
+ input_shape (tuple): Tuple for (downsampled input image height,
834
+ downsampled input image width).
835
+ pos_shape (tuple): The resolution of downsampled origin training
836
+ image.
837
+ mode (str): Algorithm used for upsampling:
838
+ ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
839
+ ``'trilinear'``. Default: ``'nearest'``
840
+ Return:
841
+ torch.Tensor: The resized pos_embed of shape [B, L_new, C]
842
+ """
843
+ assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
844
+ pos_h, pos_w = pos_shape
845
+ cls_token_weight = pos_embed[:, 0]
846
+ pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
847
+ pos_embed_weight = pos_embed_weight.reshape(
848
+ 1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
849
+ pos_embed_weight = resize(
850
+ pos_embed_weight, size=input_shape, align_corners=False, mode=mode)
851
+ cls_token_weight = cls_token_weight.unsqueeze(1)
852
+ pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
853
+ pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1)
854
+ return pos_embed
855
+
856
+ def init_weights(self, pretrained):
857
+ print("init_weights", pretrained)
858
+ if (isinstance(self.init_cfg, dict)
859
+ and self.init_cfg.get('type') == 'Pretrained'):
860
+
861
+ checkpoint = torch.load(pretrained, map_location='cpu')
862
+ if 'state_dict' in checkpoint:
863
+ # timm checkpoint
864
+ state_dict = checkpoint['state_dict']
865
+ elif 'model' in checkpoint:
866
+ # deit checkpoint
867
+ state_dict = checkpoint['model']
868
+ elif 'teacher' in checkpoint:
869
+ # dino eval checkpoint
870
+ state_dict = checkpoint['teacher']
871
+ else:
872
+ state_dict = checkpoint
873
+
874
+ if len([k for k in state_dict.keys() if 'teacher.backbone.' in k]) > 0:
875
+ state_dict = {k.replace('teacher.backbone.', ''):v for k,v in state_dict.items() if 'teacher.backbone' in k}
876
+ if len([k for k in state_dict.keys() if 'backbone.' in k]) > 0:
877
+ state_dict = {k.replace('backbone.', ''):v for k,v in state_dict.items()}
878
+
879
+ if 'pos_embed' in state_dict.keys():
880
+ if self.pos_embed.shape != state_dict['pos_embed'].shape:
881
+ print(f'Resize the pos_embed shape from '
882
+ f'{state_dict["pos_embed"].shape} to '
883
+ f'{self.pos_embed.shape}')
884
+ h, w = (224, 224) # self.img_size
885
+ pos_size = int(
886
+ math.sqrt(state_dict['pos_embed'].shape[1] - 1))
887
+ state_dict['pos_embed'] = self.resize_pos_embed(
888
+ state_dict['pos_embed'],
889
+ (h // self.patch_size[0], w // self.patch_size[1]),
890
+ (pos_size, pos_size), self.interpolate_mode)
891
+ self.load_state_dict(state_dict)
892
+ else:
893
+ super(SSLVisionTransformer, self).init_weights()
894
+
895
+
896
+ def forward(self, x):
897
+
898
+ with torch.set_grad_enabled(not self.detach):
899
+ _, _, old_w, old_h = x.shape
900
+ xx = self.adapad(x)
901
+
902
+ x = F.pad(x, (0, xx.shape[-1] - x.shape[-1], 0, xx.shape[-2] - x.shape[-2]))
903
+ B, nc, w, h = x.shape
904
+
905
+ x, _, _ = self.prepare_tokens(x)
906
+ # we return the output tokens from the `n` last blocks
907
+ outs = []
908
+ for i, blk in enumerate(self.blocks):
909
+ x = blk(x)
910
+ if i in self.out_indices:
911
+ if self.with_cls_token:
912
+ out = x[:, 1:]
913
+ else:
914
+ out = x
915
+ B, _, C = out.shape
916
+ out = out.reshape(B, w // self.patch_size[0], h // self.patch_size[1],
917
+ C).permute(0, 3, 1, 2).contiguous()
918
+ if self.output_cls_token:
919
+ out = [out, x[:, 0]]
920
+ else:
921
+ out = [out]
922
+ if self.final_norm:
923
+ out = [self.norm(o) for o in out]
924
+ if self.detach:
925
+ out = [o.detach() for o in out]
926
+ outs.append(out)
927
+ return tuple(outs)
928
+
929
+ def train(self, mode=True):
930
+ super(SSLVisionTransformer, self).train(mode)
931
+ self.detach = False
932
+ self._freeze_stages()
933
+
934
+ def _freeze_stages(self):
935
+ """Freeze stages param and norm stats."""
936
+ if self.frozen_stages >= 0:
937
+ self.patch_embed.eval()
938
+ for m in [self.patch_embed]:
939
+ for param in m.parameters():
940
+ param.requires_grad = False
941
+ self.cls_token.requires_grad = False
942
+ self.pos_embed.requires_grad = False
943
+ self.mask_token.requires_grad = False
944
+
945
+ if self.frozen_stages >= len(self.blocks) - 1:
946
+ self.norm.eval()
947
+ for param in self.norm.parameters():
948
+ param.requires_grad = False
949
+ self.detach = True
950
+
951
+ for i, layer in enumerate(self.blocks):
952
+ if i <= self.frozen_stages:
953
+ layer.eval()
954
+ for param in layer.parameters():
955
+ param.requires_grad = False
956
+
957
+
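A minimal forward-pass sketch for the wrapper above (illustration only, not part of the uploaded file): with the default ViT-L/16 settings, the backbone returns one [feature map, cls token] pair per entry in out_indices.

# Illustration only: the frozen backbone yields a [B x 1024 x H/16 x W/16 map,
# B x 1024 cls token] pair for each of the four tapped blocks.
backbone = SSLVisionTransformer(out_indices=(4, 11, 17, 23))
feats = backbone(torch.randn(1, 3, 224, 224))
for fmap, cls_tok in feats:
    print(fmap.shape, cls_tok.shape)  # (1, 1024, 14, 14) and (1, 1024)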
models/dpt_head.py ADDED
@@ -0,0 +1,536 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ import torch
8
+ from torch import nn
9
+ import torchvision
10
+
11
+ import math
+
+ from models.backbone import resize
12
+
13
+ def kaiming_init(module: nn.Module,
14
+ a: float = 0,
15
+ mode: str = 'fan_out',
16
+ nonlinearity: str = 'relu',
17
+ bias: float = 0,
18
+ distribution: str = 'normal') -> None:
19
+ assert distribution in ['uniform', 'normal']
20
+ if hasattr(module, 'weight') and module.weight is not None:
21
+ if distribution == 'uniform':
22
+ nn.init.kaiming_uniform_(
23
+ module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
24
+ else:
25
+ nn.init.kaiming_normal_(
26
+ module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
27
+ if hasattr(module, 'bias') and module.bias is not None:
28
+ nn.init.constant_(module.bias, bias)
29
+
30
+ class ConvModule(nn.Module):
31
+ """A conv block that bundles conv/norm/activation layers.
32
+ This block simplifies the usage of convolution layers, which are commonly
33
+ used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
34
+ It is based upon three build methods: `build_conv_layer()`,
35
+ `build_norm_layer()` and `build_activation_layer()`.
36
+ Besides, we add some additional features in this module.
37
+ 1. Automatically set `bias` of the conv layer.
38
+ 2. Spectral norm is supported.
39
+ 3. More padding modes are supported. Before PyTorch 1.5, nn.Conv2d only
40
+ supports zero and circular padding, and we add "reflect" padding mode.
41
+ Args:
42
+ in_channels (int): Number of channels in the input feature map.
43
+ Same as that in ``nn._ConvNd``.
44
+ out_channels (int): Number of channels produced by the convolution.
45
+ Same as that in ``nn._ConvNd``.
46
+ kernel_size (int | tuple[int]): Size of the convolving kernel.
47
+ Same as that in ``nn._ConvNd``.
48
+ stride (int | tuple[int]): Stride of the convolution.
49
+ Same as that in ``nn._ConvNd``.
50
+ padding (int | tuple[int]): Zero-padding added to both sides of
51
+ the input. Same as that in ``nn._ConvNd``.
52
+ dilation (int | tuple[int]): Spacing between kernel elements.
53
+ Same as that in ``nn._ConvNd``.
54
+ groups (int): Number of blocked connections from input channels to
55
+ output channels. Same as that in ``nn._ConvNd``.
56
+ bias (bool | str): If specified as `auto`, it will be decided by the
57
+ norm_cfg. Bias will be set as True if `norm_cfg` is None, otherwise
58
+ False. Default: "auto".
59
+ conv_cfg (dict): Config dict for convolution layer. Default: None,
60
+ which means using conv2d.
61
+ norm_cfg (dict): Config dict for normalization layer. Default: None.
62
+ act_cfg (dict): Config dict for activation layer.
63
+ Default: dict(type='ReLU').
64
+ inplace (bool): Whether to use inplace mode for activation.
65
+ Default: True.
66
+ with_spectral_norm (bool): Whether to use spectral norm in the conv module.
67
+ Default: False.
68
+ padding_mode (str): If the `padding_mode` has not been supported by
69
+ current `Conv2d` in PyTorch, we will use our own padding layer
70
+ instead. Currently, we support ['zeros', 'circular'] with official
71
+ implementation and ['reflect'] with our own implementation.
72
+ Default: 'zeros'.
73
+ order (tuple[str]): The order of conv/norm/activation layers. It is a
74
+ sequence of "conv", "norm" and "act". Common examples are
75
+ ("conv", "norm", "act") and ("act", "conv", "norm").
76
+ Default: ('conv', 'norm', 'act').
77
+ """
78
+
79
+ _abbr_ = 'conv_block'
80
+
81
+ def __init__(self,
82
+ in_channels,
83
+ out_channels,
84
+ kernel_size,
85
+ stride = 1,
86
+ padding = 0,
87
+ dilation = 1,
88
+ groups = 1,
89
+ bias = 'auto',
90
+ conv_cfg = None,
91
+ norm_cfg = None,
92
+ act_cfg = dict(type='ReLU'),
93
+ inplace= True,
94
+ with_spectral_norm = False,
95
+ padding_mode = 'zeros',
96
+ order = ('conv', 'norm', 'act')):
97
+ super().__init__()
98
+ assert conv_cfg is None or isinstance(conv_cfg, dict)
99
+ assert norm_cfg is None or isinstance(norm_cfg, dict)
100
+ assert act_cfg is None or isinstance(act_cfg, dict)
101
+ official_padding_mode = ['zeros', 'circular']
102
+ self.conv_cfg = conv_cfg
103
+ self.norm_cfg = norm_cfg
104
+ self.act_cfg = act_cfg
105
+ self.inplace = inplace
106
+ self.with_spectral_norm = with_spectral_norm
107
+ self.with_explicit_padding = padding_mode not in official_padding_mode
108
+ self.order = order
109
+ assert isinstance(self.order, tuple) and len(self.order) == 3
110
+ assert set(order) == {'conv', 'norm', 'act'}
111
+
112
+ self.with_norm = norm_cfg is not None
113
+ self.with_activation = act_cfg is not None
114
+ # if the conv layer is before a norm layer, bias is unnecessary.
115
+ if bias == 'auto':
116
+ bias = not self.with_norm
117
+ self.with_bias = bias
118
+
119
+ if self.with_explicit_padding:
120
+ pad_cfg = dict(type=padding_mode)
121
+ self.padding_layer = build_padding_layer(pad_cfg, padding)
122
+ # to do Camille put back
123
+
124
+ # reset padding to 0 for conv module
125
+ conv_padding = 0 if self.with_explicit_padding else padding
126
+ # build convolution layer
127
+ self.conv = nn.Conv2d( #build_conv_layer(#conv_cfg,
128
+ in_channels,
129
+ out_channels,
130
+ kernel_size,
131
+ stride=stride,
132
+ padding=conv_padding,
133
+ dilation=dilation,
134
+ groups=groups,
135
+ bias=bias)
136
+ # export the attributes of self.conv to a higher level for convenience
137
+ self.in_channels = self.conv.in_channels
138
+ self.out_channels = self.conv.out_channels
139
+ self.kernel_size = self.conv.kernel_size
140
+ self.stride = self.conv.stride
141
+ self.padding = padding
142
+ self.dilation = self.conv.dilation
143
+ self.transposed = self.conv.transposed
144
+ self.output_padding = self.conv.output_padding
145
+ self.groups = self.conv.groups
146
+
147
+ if self.with_spectral_norm:
148
+ self.conv = nn.utils.spectral_norm(self.conv)
149
+
150
+ self.norm_name = None # type: ignore
151
+
152
+ # build activation layer
153
+ if self.with_activation:
154
+ act_cfg_ = act_cfg.copy() # type: ignore
155
+ # nn.Tanh has no 'inplace' argument
156
+ if act_cfg_['type'] not in [
157
+ 'Tanh', 'PReLU', 'Sigmoid', 'HSigmoid', 'Swish', 'GELU'
158
+ ]:
159
+ act_cfg_.setdefault('inplace', inplace)
160
+ self.activate = nn.ReLU() # build_activation_layer(act_cfg_)
161
+
162
+ # Use msra init by default
163
+ torch.manual_seed(1)
164
+ self.init_weights()
165
+
166
+ @property
167
+ def norm(self):
168
+ if self.norm_name:
169
+ return getattr(self, self.norm_name)
170
+ else:
171
+ return None
172
+
173
+ def init_weights(self):
174
+ # 1. It is mainly for customized conv layers with their own
175
+ # initialization manners by calling their own ``init_weights()``,
176
+ # and we do not want ConvModule to override the initialization.
177
+ # 2. For customized conv layers without their own initialization
178
+ # manners (that is, they don't have their own ``init_weights()``)
179
+ # and PyTorch's conv layers, they will be initialized by
180
+ # this method with default ``kaiming_init``.
181
+ # Note: For PyTorch's conv layers, they will be overwritten by our
182
+ # initialization implementation using default ``kaiming_init``.
183
+ if not hasattr(self.conv, 'init_weights'):
184
+ if self.with_activation and self.act_cfg['type'] == 'LeakyReLU':
185
+ nonlinearity = 'leaky_relu'
186
+ a = self.act_cfg.get('negative_slope', 0.01)
187
+ else:
188
+ nonlinearity = 'relu'
189
+ a = 0
190
+ kaiming_init(self.conv, a=a, nonlinearity=nonlinearity)
191
+ if self.with_norm:
192
+ constant_init(self.norm, 1, bias=0)
193
+
194
+ def forward(self,
195
+ x: torch.Tensor,
196
+ activate: bool = True,
197
+ norm: bool = True,
198
+ debug: bool = False) -> torch.Tensor:
199
+
200
+ for layer in self.order:
201
+ if debug==True:
202
+ breakpoint()
203
+ if layer == 'conv':
204
+ if self.with_explicit_padding:
205
+ x = self.padding_layer(x)
206
+ x = self.conv(x)
207
+ elif layer == 'norm' and norm and self.with_norm:
208
+ x = self.norm(x)
209
+ elif layer == 'act' and activate and self.with_activation:
210
+ x = self.activate(x)
211
+ return x
212
+
213
+
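A small sketch of the conv/norm/act bundle above (illustration only, not part of the uploaded file): with act_cfg=None and norm_cfg=None it reduces to a plain, biased convolution.

# Illustration only: no norm/activation configured -> plain Conv2d with bias.
m = ConvModule(64, 128, kernel_size=3, padding=1, act_cfg=None)
print(m(torch.randn(1, 64, 32, 32)).shape)  # torch.Size([1, 128, 32, 32])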
214
+ class Interpolate(nn.Module):
215
+ def __init__(self, scale_factor, mode, align_corners=False):
216
+ super(Interpolate, self).__init__()
217
+ self.interp = nn.functional.interpolate
218
+ self.scale_factor = scale_factor
219
+ self.mode = mode
220
+ self.align_corners = align_corners
221
+
222
+ def forward(self, x):
223
+ x = self.interp(
224
+ x,
225
+ scale_factor=self.scale_factor,
226
+ mode=self.mode,
227
+ align_corners=self.align_corners)
228
+ return x
229
+
230
+ class HeadDepth(nn.Module):
231
+ def __init__(self, features, classify=False, n_bins=256):
232
+ super(HeadDepth, self).__init__()
233
+ self.head = nn.Sequential(
234
+ nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
235
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
236
+ nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
237
+ nn.ReLU(),
238
+ nn.Conv2d(32, 1 if not classify else n_bins, kernel_size=1, stride=1, padding=0),
239
+ )
240
+ def forward(self, x):
241
+ x = self.head(x)
242
+ return x
243
+
244
+
245
+ class ReassembleBlocks(nn.Module):
246
+ """ViTPostProcessBlock, process cls_token in ViT backbone output and
247
+ rearrange the feature vector to feature map.
248
+ Args:
249
+ in_channels (int): ViT feature channels. Default: 768.
250
+ out_channels (List): output channels of each stage.
251
+ Default: [96, 192, 384, 768].
252
+ readout_type (str): Type of readout operation. Default: 'ignore'.
253
+ patch_size (int): The patch size. Default: 16.
254
+ init_cfg (dict, optional): Initialization config dict. Default: None.
255
+ """
256
+ def __init__(self,
257
+ in_channels=1024, #768,
258
+ out_channels=[128, 256, 512, 1024], #[96, 192, 384, 768],
259
+ readout_type='project', # 'ignore',
260
+ patch_size=16):
261
+ super(ReassembleBlocks, self).__init__()#init_cfg)
262
+
263
+ assert readout_type in ['ignore', 'add', 'project']
264
+ self.readout_type = readout_type
265
+ self.patch_size = patch_size
266
+
267
+ self.projects = nn.ModuleList([
268
+ ConvModule(
269
+ in_channels=in_channels,
270
+ out_channels=out_channel,
271
+ kernel_size=1,
272
+ act_cfg=None,
273
+ ) for out_channel in out_channels
274
+ ])
275
+
276
+ self.resize_layers = nn.ModuleList([
277
+ nn.ConvTranspose2d(
278
+ in_channels=out_channels[0],
279
+ out_channels=out_channels[0],
280
+ kernel_size=4,
281
+ stride=4,
282
+ padding=0),
283
+ nn.ConvTranspose2d(
284
+ in_channels=out_channels[1],
285
+ out_channels=out_channels[1],
286
+ kernel_size=2,
287
+ stride=2,
288
+ padding=0),
289
+ nn.Identity(),
290
+ nn.Conv2d(
291
+ in_channels=out_channels[3],
292
+ out_channels=out_channels[3],
293
+ kernel_size=3,
294
+ stride=2,
295
+ padding=1)
296
+ ])
297
+ if self.readout_type == 'project':
298
+ self.readout_projects = nn.ModuleList()
299
+ for _ in range(len(self.projects)):
300
+ self.readout_projects.append(
301
+ nn.Sequential(
302
+ nn.Linear(2 * in_channels, in_channels),
303
+ nn.GELU()))
304
+ #build_activation_layer(dict(type='GELU'))))
305
+
306
+ def forward(self, inputs):
307
+ assert isinstance(inputs, list)
308
+ out = []
309
+ for i, x in enumerate(inputs):
310
+ assert len(x) == 2
311
+ x, cls_token = x[0], x[1]
312
+ feature_shape = x.shape
313
+ if self.readout_type == 'project':
314
+ x = x.flatten(2).permute((0, 2, 1))
315
+ readout = cls_token.unsqueeze(1).expand_as(x)
316
+ x = self.readout_projects[i](torch.cat((x, readout), -1))
317
+ x = x.permute(0, 2, 1).reshape(feature_shape)
318
+ elif self.readout_type == 'add':
319
+ x = x.flatten(2) + cls_token.unsqueeze(-1)
320
+ x = x.reshape(feature_shape)
321
+ else:
322
+ pass
323
+ x = self.projects[i](x)
324
+ x = self.resize_layers[i](x)
325
+ out.append(x)
326
+ return out
327
+
328
+
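A shape sketch for the reassembly stage above (illustration only, not part of the uploaded file): with the 'project' readout, each [feature map, cls token] pair is fused and resized to its pyramid level.

# Illustration only: four ViT stages at 14x14 become a 56/28/14/7 feature pyramid.
rb = ReassembleBlocks(in_channels=1024, out_channels=[128, 256, 512, 1024])
stages = [[torch.randn(1, 1024, 14, 14), torch.randn(1, 1024)] for _ in range(4)]
print([o.shape for o in rb(stages)])
# [(1, 128, 56, 56), (1, 256, 28, 28), (1, 512, 14, 14), (1, 1024, 7, 7)]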
329
+ class PreActResidualConvUnit(nn.Module):
330
+ """ResidualConvUnit, pre-activate residual unit.
331
+ Args:
332
+ in_channels (int): number of channels in the input feature map.
333
+ act_cfg (dict): dictionary to construct and config activation layer.
334
+ norm_cfg (dict): dictionary to construct and config norm layer.
335
+ stride (int): stride of the first block. Default: 1
336
+ dilation (int): dilation rate for convs layers. Default: 1.
337
+ init_cfg (dict, optional): Initialization config dict. Default: None.
338
+ """
339
+
340
+ def __init__(self,
341
+ in_channels,
342
+ act_cfg,
343
+ norm_cfg,
344
+ stride=1,
345
+ dilation=1,
346
+ init_cfg=None):
347
+ super(PreActResidualConvUnit, self).__init__()#init_cfg)
348
+ self.conv1 = ConvModule(
349
+ in_channels,
350
+ in_channels,
351
+ 3,
352
+ stride=stride,
353
+ padding=dilation,
354
+ dilation=dilation,
355
+ norm_cfg=norm_cfg,
356
+ act_cfg=act_cfg,
357
+ bias=False,
358
+ order=('act', 'conv', 'norm'))
359
+ self.conv2 = ConvModule(
360
+ in_channels,
361
+ in_channels,
362
+ 3,
363
+ padding=1,
364
+ norm_cfg=norm_cfg,
365
+ act_cfg=act_cfg,
366
+ bias=False,
367
+ order=('act', 'conv', 'norm'))
368
+ def forward(self, inputs):
369
+ inputs_ = inputs.clone()
370
+ x = self.conv1(inputs)
371
+ x = self.conv2(x)
372
+ return x + inputs_
373
+
374
+
375
+ class FeatureFusionBlock(nn.Module):
376
+ """FeatureFusionBlock, merge feature map from different stages.
377
+ Args:
378
+ in_channels (int): Input channels.
379
+ act_cfg (dict): The activation config for ResidualConvUnit.
380
+ norm_cfg (dict): Config dict for normalization layer.
381
+ expand (bool): Whether expand the channels in post process block.
382
+ Default: False.
383
+ align_corners (bool): align_corner setting for bilinear upsample.
384
+ Default: True.
385
+ init_cfg (dict, optional): Initialization config dict. Default: None.
386
+ """
387
+
388
+ def __init__(self,
389
+ in_channels,
390
+ act_cfg,
391
+ norm_cfg,
392
+ expand=False,
393
+ align_corners=True,
394
+ init_cfg=None):
395
+ super(FeatureFusionBlock, self).__init__()#init_cfg)
396
+ self.in_channels = in_channels
397
+ self.expand = expand
398
+ self.align_corners = align_corners
399
+ self.out_channels = in_channels
400
+ if self.expand:
401
+ self.out_channels = in_channels // 2
402
+ self.project = ConvModule(
403
+ self.in_channels,
404
+ self.out_channels,
405
+ kernel_size=1,
406
+ act_cfg=None,
407
+ bias=True)
408
+ self.res_conv_unit1 = PreActResidualConvUnit(
409
+ in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg)
410
+ self.res_conv_unit2 = PreActResidualConvUnit(
411
+ in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg)
412
+
413
+ def forward(self, *inputs):
414
+ x = inputs[0]
415
+
416
+ if len(inputs) == 2:
417
+ if x.shape != inputs[1].shape:
418
+ res = resize(
419
+ inputs[1],
420
+ size=(x.shape[2], x.shape[3]),
421
+ mode='bilinear',
422
+ align_corners=False)
423
+ else:
424
+ res = inputs[1]
425
+ x = x + self.res_conv_unit1(res)
426
+ x = self.res_conv_unit2(x)
427
+ x = resize( x, scale_factor=2, mode='bilinear', align_corners=self.align_corners)
428
+ x = self.project(x)
429
+ return x
430
+
431
+ class DPTHead(nn.Module):
432
+ """Vision Transformers for Dense Prediction.
433
+ This head implements `DPT <https://arxiv.org/abs/2103.13413>`_.
434
+ Args:
435
+ embed_dims (int): The embed dimension of the ViT backbone.
436
+ Default: 768.
437
+ post_process_channels (List): Out channels of post process conv
438
+ layers. Default: [96, 192, 384, 768].
439
+ readout_type (str): Type of readout operation. Default: 'ignore'.
440
+ patch_size (int): The patch size. Default: 16.
441
+ expand_channels (bool): Whether expand the channels in post process
442
+ block. Default: False.
443
+ """
444
+
445
+ def __init__(self,
446
+ in_channels=(1024, 1024, 1024, 1024),
447
+ channels=256,
448
+ embed_dims=1024,
449
+ post_process_channels=[128, 256, 512, 1024],
450
+ readout_type='project',
451
+ patch_size=16,
452
+ expand_channels=False,
453
+ min_depth = 0.001,
454
+ classify=False,
455
+ n_bins=256,
456
+ **kwargs):
457
+ super(DPTHead, self).__init__(**kwargs)
458
+ torch.manual_seed(1)
459
+ self.channels = channels
460
+ self.norm_cfg = None
461
+ self.min_depth = min_depth
462
+ self.max_depth = 10
463
+ self.n_bins = n_bins
464
+ self.classify = classify
465
+ self.in_channels = in_channels
466
+ self.expand_channels = expand_channels
467
+ self.reassemble_blocks = ReassembleBlocks(in_channels=embed_dims, # Camille 23-06-26
468
+ out_channels=post_process_channels) # Camille 23-06-26
469
+
470
+ self.post_process_channels = [
471
+ channel * math.pow(2, i) if expand_channels else channel
472
+ for i, channel in enumerate(post_process_channels)
473
+ ]
474
+ self.convs = nn.ModuleList()
475
+ for channel in self.post_process_channels:
476
+ self.convs.append(
477
+ ConvModule(
478
+ channel,
479
+ self.channels,
480
+ kernel_size=3,
481
+ padding=1,
482
+ act_cfg=None,
483
+ bias=False))
484
+ self.fusion_blocks = nn.ModuleList()
485
+ self.act_cfg = {'type': 'ReLU'}
486
+ for _ in range(len(self.convs)):
487
+ self.fusion_blocks.append(
488
+ FeatureFusionBlock(self.channels, self.act_cfg, self.norm_cfg))
489
+ self.fusion_blocks[0].res_conv_unit1 = None
490
+ torch.manual_seed(1)
491
+ self.project = ConvModule(
492
+ self.channels,
493
+ self.channels,
494
+ kernel_size=3,
495
+ padding=1,
496
+ norm_cfg=None)
497
+ self.num_fusion_blocks = len(self.fusion_blocks)
498
+ self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers)
499
+ self.num_post_process_channels = len(self.post_process_channels)
500
+ assert self.num_fusion_blocks == self.num_reassemble_blocks
501
+ assert self.num_reassemble_blocks == self.num_post_process_channels
502
+ #self.conv_depth = HeadDepth(self.channels)
503
+ self.conv_depth = HeadDepth(self.channels, self.classify, self.n_bins)
504
+ self.relu = nn.ReLU()
505
+ self.sigmoid = nn.Sigmoid()
506
+
507
+
508
+ def forward(self, inputs):
509
+
510
+ assert len(inputs) == self.num_reassemble_blocks
511
+ x = [inp for inp in inputs]
512
+
513
+ x = self.reassemble_blocks(x)
514
+ x = [self.convs[i](feature) for i, feature in enumerate(x)]
515
+ out = self.fusion_blocks[0](x[-1])
516
+
517
+ for i in range(1, len(self.fusion_blocks)):
518
+ out = self.fusion_blocks[i](out, x[-(i + 1)])
519
+
520
+ out = self.project(out)
521
+ if self.classify:
522
+ logit = self.conv_depth(out)
523
+
524
+ #if self.bins_strategy == 'UD':
525
+ bins = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=inputs[0][0].device)
526
+ #linear strategy
527
+ logit = torch.relu(logit)
528
+ eps = 0.1
529
+ logit = logit + eps
530
+ logit = logit / logit.sum(dim=1, keepdim=True)
531
+ out = torch.einsum('ikmn,k->imn', [logit, bins]).unsqueeze(dim=1) #+ self.min_depth
532
+ else:
533
+ out = self.relu(self.conv_depth(out)) + self.min_depth
534
+
535
+ return out
536
+
models/regressor.py ADDED
@@ -0,0 +1,69 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torchvision
9
+
10
+ class RNet(nn.Module):
11
+ def __init__(
12
+ self,
13
+ n_channels=3,
14
+ n_classes=13,
15
+ n_pix=256,
16
+ filters=(8, 16, 32, 64, 64, 128),
17
+ pool=(2, 2),
18
+ kernel_size=(3, 3),
19
+ n_meta=0,
20
+ ) -> None:
21
+ super(RNet, self).__init__()
22
+
23
+ def conv_block(in_filters, out_filters, kernel_size):
24
+ layers = nn.Sequential(
25
+ # first conv is across channels, size=1
26
+ nn.Conv2d(in_filters, out_filters, kernel_size=(1, 1), padding="same"),
27
+ nn.BatchNorm2d(out_filters),
28
+ nn.ReLU(),
29
+ nn.Conv2d(
30
+ out_filters, out_filters, kernel_size=kernel_size, padding="same"
31
+ ),
32
+ )
33
+ return layers
34
+
35
+ def fc_block(in_features, out_features):
36
+ layers = nn.Sequential(
37
+ nn.Linear(in_features=in_features, out_features=out_features),
38
+ #nn.BatchNorm1d(out_features),
39
+ #nn.InstanceNorm1d(out_features),
40
+ nn.ReLU(),
41
+ )
42
+ return layers
43
+
44
+ self.pool = nn.MaxPool2d(2, 2)
45
+ self.input_layer = conv_block(n_channels, filters[0], kernel_size)
46
+ self.conv_block1 = conv_block(filters[0], filters[1], kernel_size)
47
+ self.conv_block2 = conv_block(filters[1], filters[2], kernel_size)
48
+ self.conv_block3 = conv_block(filters[2], filters[3], kernel_size)
49
+ self.conv_block4 = conv_block(filters[3], filters[4], kernel_size)
50
+ self.conv_block5 = conv_block(filters[4], filters[5], kernel_size)
51
+ n_pool = 5
52
+ self.fc1 = fc_block(in_features= int(filters[5] * (n_pix / 2**n_pool) ** 2), out_features=64)
53
+ self.fc2 = fc_block(in_features=64 + n_meta, out_features=64)
54
+ self.fc3 = fc_block(in_features=64, out_features=32)
55
+ self.fc4 = nn.Linear(in_features=32, out_features=n_classes)
56
+
57
+ def forward(self, x):
58
+ x1 = self.pool(self.input_layer(x))
59
+ x2 = self.pool(self.conv_block1(x1))
60
+ x3 = self.pool(self.conv_block2(x2))
61
+ x4 = self.pool(self.conv_block3(x3))
62
+ x4b = self.pool(self.conv_block4(x4))
63
+ x5 = self.conv_block5(x4b)
64
+ x6 = torch.flatten(x5, 1) # flatten all dimensions except batch
65
+ x7 = self.fc1(x6)
66
+ x9 = self.fc2(x7)
67
+ x10 = self.fc3(x9)
68
+ x11 = self.fc4(x10)
69
+ return x11
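A usage sketch for the regressor above (illustration only, not part of the uploaded file): a 256x256 crop maps to n_classes scalar outputs.

# Illustration only: five pooling stages reduce 256x256 to 8x8 before the MLP head.
net = RNet(n_channels=3, n_classes=13, n_pix=256)
out = net(torch.randn(2, 3, 256, 256))
print(out.shape)  # torch.Size([2, 13])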