File size: 1,735 Bytes
a104d3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch
from torch import nn

def conv4x4(in_c, out_c):
    """Build one downsampling stage: 4x4 stride-2 conv -> BatchNorm -> LeakyReLU(0.1).

    With kernel 4, stride 2 and padding 1 the spatial size goes from H to
    floor(H / 2).

    Args:
        in_c: number of input channels.
        out_c: number of output channels.

    Returns:
        nn.Sequential whose submodules are indexed 0 (conv), 1 (norm),
        2 (activation).
    """
    # The conv keeps its default bias even though BatchNorm follows; changing
    # that would alter the module's parameter set / state_dict keys.
    layers = [
        nn.Conv2d(in_c, out_c, kernel_size=4, stride=2, padding=1),
        nn.BatchNorm2d(out_c),
        nn.LeakyReLU(negative_slope=0.1, inplace=True),
    ]
    return nn.Sequential(*layers)


def deconv4x4(in_c, out_c):
    """Build one upsampling stage: 4x4 stride-2 transposed conv -> BatchNorm -> LeakyReLU(0.1).

    With kernel 4, stride 2 and padding 1 the spatial size goes from H to 2 * H.

    Args:
        in_c: number of input channels.
        out_c: number of output channels.

    Returns:
        nn.Sequential whose submodules are indexed 0 (transposed conv),
        1 (norm), 2 (activation).
    """
    upsample = nn.ConvTranspose2d(in_c, out_c, kernel_size=4, stride=2, padding=1)
    norm = nn.BatchNorm2d(out_c)
    act = nn.LeakyReLU(negative_slope=0.1, inplace=True)
    return nn.Sequential(upsample, norm, act)


class Hear_Net(nn.Module):
    """U-Net style network: 5 downsampling stages, 4 upsampling stages with skips.

    Encoder: five stride-2 ``conv4x4`` stages (6 -> 64 -> 128 -> 256 -> 512 -> 512
    channels). Decoder: four stride-2 ``deconv4x4`` stages, each followed by
    channel concatenation with the encoder feature map of matching depth,
    then a bilinear upsample back to the input resolution, a 3x3 conv to 3
    channels and ``tanh``.

    Input:  (B, 6, H, W), e.g. (B, 6, 256, 256)
    Output: (B, 3, H, W), values in (-1, 1)

    H and W must be at least 32 so the deepest encoder stage is non-empty.
    They need not be multiples of 32: when floor-halving in the encoder makes
    a decoder output differ in size from its skip tensor, the decoder output
    is bilinearly resized to the skip tensor's size before concatenation.
    For multiples of 32 no resize happens and results are unchanged.
    """

    def __init__(self):
        super().__init__()
        self.down1 = conv4x4(6, 64)
        self.down2 = conv4x4(64, 128)
        self.down3 = conv4x4(128, 256)
        self.down4 = conv4x4(256, 512)
        self.down5 = conv4x4(512, 512)

        # Decoder input channels are doubled by the skip concatenation.
        self.up1 = deconv4x4(512, 512)
        self.up2 = deconv4x4(512*2, 256)
        self.up3 = deconv4x4(256*2, 128)
        self.up4 = deconv4x4(128*2, 64)
        self.up5 = nn.Conv2d(64*2, 3, kernel_size=3, stride=1, padding=1)

    @staticmethod
    def _merge(skip, up):
        """Concatenate ``skip`` and ``up`` on channels, resizing ``up`` to ``skip``'s H, W if needed."""
        if up.shape[-2:] != skip.shape[-2:]:
            up = nn.functional.interpolate(
                up, size=skip.shape[-2:], mode='bilinear', align_corners=True)
        return torch.cat((skip, up), dim=1)

    def forward(self, x):  # input:(B,6,H,W), e.g. (B,6,256,256)
        c1 = self.down1(x)
        c2 = self.down2(c1)
        c3 = self.down3(c2)
        c4 = self.down4(c3)
        c5 = self.down5(c4)

        m1 = self._merge(c4, self.up1(c5))
        m2 = self._merge(c3, self.up2(m1))
        m3 = self._merge(c2, self.up3(m2))
        m4 = self._merge(c1, self.up4(m3))

        # Upsample to exactly the input size. For even H, W this equals the
        # previous scale_factor=2: with align_corners=True the interpolation
        # scale is derived from the input/output sizes either way.
        out = nn.functional.interpolate(
            m4, size=x.shape[-2:], mode='bilinear', align_corners=True)
        out = self.up5(out)
        return torch.tanh(out)  # output:(B,3,H,W)


if __name__ == '__main__':
    # Smoke test: one forward pass on a random batch, printing the output
    # shape (expected torch.Size([5, 3, 256, 256])). no_grad skips recording
    # an autograd graph, which is not needed just to inspect the shape.
    y_cat = torch.randn(5, 6, 256, 256)
    hear = Hear_Net()
    with torch.no_grad():
        y_st = hear(y_cat)
    print(y_st.shape)