import torch
from torch import nn


def conv4x4(in_c: int, out_c: int) -> nn.Sequential:
    """Build a down-sampling stage: 4x4 stride-2 conv -> BatchNorm -> LeakyReLU.

    With kernel 4, stride 2 and padding 1 the spatial size is halved
    (rounded down).

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels.

    Returns:
        An ``nn.Sequential`` holding the three layers at indices 0, 1 and 2.
    """
    return nn.Sequential(
        nn.Conv2d(in_c, out_c, kernel_size=4, stride=2, padding=1),
        nn.BatchNorm2d(out_c),
        nn.LeakyReLU(0.1, inplace=True),
    )


def deconv4x4(in_c: int, out_c: int) -> nn.Sequential:
    """Build an up-sampling stage: 4x4 stride-2 deconv -> BatchNorm -> LeakyReLU.

    With kernel 4, stride 2 and padding 1 the transposed convolution exactly
    doubles the spatial size.

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels.

    Returns:
        An ``nn.Sequential`` holding the three layers at indices 0, 1 and 2.
    """
    return nn.Sequential(
        nn.ConvTranspose2d(in_c, out_c, kernel_size=4, stride=2, padding=1),
        nn.BatchNorm2d(out_c),
        nn.LeakyReLU(0.1, inplace=True),
    )


class Hear_Net(nn.Module):
    """Encoder-decoder network with skip connections between mirrored stages.

    Five stride-2 encoder stages reduce the input by a factor of 32; four
    transposed-conv decoder stages, each concatenated with the encoder
    feature map of matching resolution, bring it back to half resolution.
    A bilinear 2x upsample followed by a 3x3 convolution and ``tanh``
    produces the output at the input resolution, with values in [-1, 1].

    Spatial sizes that are multiples of 32 (e.g. 256) keep encoder and
    decoder feature maps aligned for the channel-wise concatenations.

    Args:
        in_channels: Channels of the input tensor. Defaults to 6, the value
            that was previously hard-coded.
        out_channels: Channels of the output tensor. Defaults to 3, the
            value that was previously hard-coded.
    """

    def __init__(self, in_channels: int = 6, out_channels: int = 3) -> None:
        super().__init__()
        # Encoder. Attribute names and creation order are kept stable so that
        # existing checkpoints and seeded initialisation remain compatible.
        self.down1 = conv4x4(in_channels, 64)
        self.down2 = conv4x4(64, 128)
        self.down3 = conv4x4(128, 256)
        self.down4 = conv4x4(256, 512)
        self.down5 = conv4x4(512, 512)

        # Decoder. From up2 on, the input width is doubled because the
        # previous decoder output is concatenated with an encoder skip map.
        self.up1 = deconv4x4(512, 512)
        self.up2 = deconv4x4(512 * 2, 256)
        self.up3 = deconv4x4(256 * 2, 128)
        self.up4 = deconv4x4(128 * 2, 64)
        self.up5 = nn.Conv2d(64 * 2, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network.

        Args:
            x: Input of shape ``(B, in_channels, H, W)``; with the defaults
                the reference input is ``(B, 6, 256, 256)``.

        Returns:
            Tensor of shape ``(B, out_channels, H, W)`` (``(B, 3, 256, 256)``
            for the reference input), squashed by ``tanh``.
        """
        # Encoder path: keep every intermediate map for the skip connections.
        c1 = self.down1(x)
        c2 = self.down2(c1)
        c3 = self.down3(c2)
        c4 = self.down4(c3)
        c5 = self.down5(c4)

        # Decoder path: upsample, then concatenate the same-resolution
        # encoder feature map along the channel axis (skip first).
        m1 = torch.cat((c4, self.up1(c5)), dim=1)
        m2 = torch.cat((c3, self.up2(m1)), dim=1)
        m3 = torch.cat((c2, self.up3(m2)), dim=1)
        m4 = torch.cat((c1, self.up4(m3)), dim=1)

        # m4 is at half the input resolution; the last doubling is done with
        # bilinear interpolation instead of a transposed convolution.
        out = nn.functional.interpolate(
            m4, scale_factor=2, mode='bilinear', align_corners=True
        )
        out = self.up5(out)
        return torch.tanh(out)


if __name__ == '__main__':
    # Smoke test: only the output shape is inspected, so skip autograd
    # bookkeeping to save memory and time.
    y_cat = torch.randn(5, 6, 256, 256)
    hear = Hear_Net()
    with torch.no_grad():
        y_st = hear(y_cat)
    print(y_st.shape)