Spaces:
Sleeping
Sleeping
juansensio
commited on
Upload utils.py
Browse files
utils.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pytorch_lightning as pl
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from models.backbone import SSLVisionTransformer
|
5 |
+
from models.dpt_head import DPTHead
|
6 |
+
|
7 |
+
|
8 |
+
class SSLAE(nn.Module):
|
9 |
+
def __init__(self, pretrained=None, classify=True, n_bins=256, huge=False):
|
10 |
+
super().__init__()
|
11 |
+
if huge == True:
|
12 |
+
self.backbone = SSLVisionTransformer(
|
13 |
+
embed_dim=1280,
|
14 |
+
num_heads=20,
|
15 |
+
out_indices=(9, 16, 22, 29),
|
16 |
+
depth=32,
|
17 |
+
pretrained=pretrained,
|
18 |
+
)
|
19 |
+
self.decode_head = DPTHead(
|
20 |
+
classify=classify,
|
21 |
+
in_channels=(1280, 1280, 1280, 1280),
|
22 |
+
embed_dims=1280,
|
23 |
+
post_process_channels=[160, 320, 640, 1280],
|
24 |
+
)
|
25 |
+
else:
|
26 |
+
self.backbone = SSLVisionTransformer(pretrained=pretrained)
|
27 |
+
self.decode_head = DPTHead(classify=classify, n_bins=256)
|
28 |
+
|
29 |
+
def forward(self, x):
|
30 |
+
x = self.backbone(x)
|
31 |
+
x = self.decode_head(x)
|
32 |
+
return x
|
33 |
+
|
34 |
+
|
35 |
+
class SSLModule(pl.LightningModule):
|
36 |
+
def __init__(self, ssl_path="compressed_SSLbaseline.pth"):
|
37 |
+
super().__init__()
|
38 |
+
|
39 |
+
if "huge" in ssl_path:
|
40 |
+
self.chm_module_ = SSLAE(classify=True, huge=True).eval()
|
41 |
+
else:
|
42 |
+
self.chm_module_ = SSLAE(classify=True, huge=False).eval()
|
43 |
+
|
44 |
+
if "compressed" in ssl_path:
|
45 |
+
ckpt = torch.load(ssl_path, map_location="cpu")
|
46 |
+
self.chm_module_ = torch.quantization.quantize_dynamic(
|
47 |
+
self.chm_module_,
|
48 |
+
{torch.nn.Linear, torch.nn.Conv2d, torch.nn.ConvTranspose2d},
|
49 |
+
dtype=torch.qint8,
|
50 |
+
)
|
51 |
+
self.chm_module_.load_state_dict(ckpt, strict=False)
|
52 |
+
else:
|
53 |
+
ckpt = torch.load(ssl_path, map_location="cpu")
|
54 |
+
state_dict = ckpt["state_dict"]
|
55 |
+
self.chm_module_.load_state_dict(state_dict)
|
56 |
+
|
57 |
+
self.chm_module = lambda x: 10 * self.chm_module_(x)
|
58 |
+
|
59 |
+
def forward(self, x):
|
60 |
+
x = self.chm_module(x)
|
61 |
+
return x
|