juansensio commited on
Commit
b22f21e
·
verified ·
1 Parent(s): 0c0385a

Upload utils.py

Browse files
Files changed (1) hide show
  1. utils.py +61 -0
utils.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytorch_lightning as pl
2
+ import torch
3
+ import torch.nn as nn
4
+ from models.backbone import SSLVisionTransformer
5
+ from models.dpt_head import DPTHead
6
+
7
+
8
+ class SSLAE(nn.Module):
9
+ def __init__(self, pretrained=None, classify=True, n_bins=256, huge=False):
10
+ super().__init__()
11
+ if huge == True:
12
+ self.backbone = SSLVisionTransformer(
13
+ embed_dim=1280,
14
+ num_heads=20,
15
+ out_indices=(9, 16, 22, 29),
16
+ depth=32,
17
+ pretrained=pretrained,
18
+ )
19
+ self.decode_head = DPTHead(
20
+ classify=classify,
21
+ in_channels=(1280, 1280, 1280, 1280),
22
+ embed_dims=1280,
23
+ post_process_channels=[160, 320, 640, 1280],
24
+ )
25
+ else:
26
+ self.backbone = SSLVisionTransformer(pretrained=pretrained)
27
+ self.decode_head = DPTHead(classify=classify, n_bins=256)
28
+
29
+ def forward(self, x):
30
+ x = self.backbone(x)
31
+ x = self.decode_head(x)
32
+ return x
33
+
34
+
35
+ class SSLModule(pl.LightningModule):
36
+ def __init__(self, ssl_path="compressed_SSLbaseline.pth"):
37
+ super().__init__()
38
+
39
+ if "huge" in ssl_path:
40
+ self.chm_module_ = SSLAE(classify=True, huge=True).eval()
41
+ else:
42
+ self.chm_module_ = SSLAE(classify=True, huge=False).eval()
43
+
44
+ if "compressed" in ssl_path:
45
+ ckpt = torch.load(ssl_path, map_location="cpu")
46
+ self.chm_module_ = torch.quantization.quantize_dynamic(
47
+ self.chm_module_,
48
+ {torch.nn.Linear, torch.nn.Conv2d, torch.nn.ConvTranspose2d},
49
+ dtype=torch.qint8,
50
+ )
51
+ self.chm_module_.load_state_dict(ckpt, strict=False)
52
+ else:
53
+ ckpt = torch.load(ssl_path, map_location="cpu")
54
+ state_dict = ckpt["state_dict"]
55
+ self.chm_module_.load_state_dict(state_dict)
56
+
57
+ self.chm_module = lambda x: 10 * self.chm_module_(x)
58
+
59
+ def forward(self, x):
60
+ x = self.chm_module(x)
61
+ return x