Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,968 Bytes
17cd746 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
# Copyright (c) 2023-2024, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import timm
from accelerate.logging import get_logger
# Module-level logger obtained from accelerate, keyed on this module's name;
# XnetWrapper._freeze uses it to announce that the model is being frozen.
logger = get_logger(__name__)
class XUNet(nn.Module):
    """U-Net-style feature extractor built on a pretrained timm backbone.

    The backbone's intermediate feature maps are fused top-down: the deepest
    map is upsampled 2x by a transposed convolution and added to the next
    shallower map, three times in a row; a final 1x1 convolution projects
    the fused map to ``encoder_feat_dim`` channels.

    NOTE(review): the classification head is removed by deleting the
    ``global_pool`` and ``fc`` attributes, which are ResNet-style names.
    The default ``model_name`` is a Swin checkpoint and presumably does not
    expose them -- confirm that callers always pass a ResNet-family name.
    """

    # Channel widths of encoder feature maps 1..4 used when the backbone does
    # not advertise its own widths (these reproduce the original hard-coded
    # 512 -> 256 -> 128 -> 64 decoder).
    _FALLBACK_STAGE_WIDTHS = (64, 128, 256, 512)

    def __init__(self, model_name="swin_base_patch4_window12_384_in22k", encoder_feat_dim=384):
        """Build the pretrained encoder and the additive-skip decoder.

        Args:
            model_name: Name passed to ``timm.create_model`` (with
                ``pretrained=True``).
            encoder_feat_dim: Number of channels of the returned feature map.
        """
        super().__init__()

        self.encoder = timm.create_model(model_name, pretrained=True)
        # Only intermediate features are consumed, so drop the classifier head.
        del self.encoder.global_pool
        del self.encoder.fc

        # Size the decoder from the encoder instead of hard-coding widths, so
        # backbones with wider stages get matching skip-connection shapes.
        c1, c2, c3, c4 = self._stage_widths(self.encoder)
        self.upconv4 = self.upconv_block(c4, c3)
        self.upconv3 = self.upconv_block(c3, c2)
        self.upconv2 = self.upconv_block(c2, c1)
        self.out_conv = nn.Conv2d(c1, encoder_feat_dim, kernel_size=1)

    @staticmethod
    def _stage_widths(encoder):
        """Return the channel widths of encoder feature maps 1..4.

        Reads ``num_chs`` from the entries of ``encoder.feature_info``. When
        that attribute is missing, malformed, or has fewer than five entries,
        the original fixed widths ``(64, 128, 256, 512)`` are returned.
        """
        info = getattr(encoder, "feature_info", None)
        try:
            widths = tuple(int(stage["num_chs"]) for stage in info)
        except (TypeError, KeyError):
            return XUNet._FALLBACK_STAGE_WIDTHS
        if len(widths) < 5:
            return XUNet._FALLBACK_STAGE_WIDTHS
        # Entry 0 is not used by forward(); entries 1..4 feed the decoder.
        return widths[1:5]

    def upconv_block(self, in_channels, out_channels):
        """Return a 2x upsampling stage: transposed conv (k=2, s=2) + ReLU."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Encode ``x`` and decode to a single fused feature map.

        Args:
            x: Input batch handed directly to the encoder.

        Returns:
            Tensor with ``encoder_feat_dim`` channels at the spatial size of
            encoder feature map 1 (labelled stride 4 in the original
            comments -- TODO confirm for the chosen backbone).
        """
        # Expects an indexable sequence of at least five feature maps.
        feats = self.encoder.forward_intermediates(x, stop_early=True, intermediates_only=True)
        skip1, skip2, skip3, deepest = feats[1], feats[2], feats[3], feats[4]

        # Each step doubles the resolution, then adds the matching skip map.
        x = self.upconv4(deepest) + skip3
        x = self.upconv3(x) + skip2
        x = self.upconv2(x) + skip1
        return self.out_conv(x)
class XnetWrapper(nn.Module):
    """
    Wrap an XUNet so that it returns a flat sequence of feature tokens.

    The wrapped model can be frozen (the default): its parameters stop
    requiring gradients and it is kept in eval mode even when the wrapper,
    or a parent module, is switched to training mode.

    ``modulation_dim`` is stored and ``forward`` accepts ``mod``, but neither
    is used by the computation in this class.
    """

    def __init__(self, model_name: str, modulation_dim: int = None, freeze: bool = True, encoder_feat_dim: int = 384):
        """
        Args:
            model_name: Backbone name forwarded to XUNet.
            modulation_dim: Stored on the instance; must be None when
                ``freeze`` is True.
            freeze: If True, freeze the wrapped model after construction.
            encoder_feat_dim: Channel count of the features XUNet returns.

        Raises:
            ValueError: If ``freeze`` is True and ``modulation_dim`` is set.
        """
        super().__init__()
        # Validate first so an invalid configuration fails before the
        # backbone is constructed.
        if freeze and modulation_dim is not None:
            raise ValueError("Modulated SwinUnetWrapper requires training, freezing is not allowed.")
        self.modulation_dim = modulation_dim
        self._frozen = False
        self.model = XUNet(model_name=model_name, encoder_feat_dim=encoder_feat_dim)
        if freeze:
            self._freeze()

    def _freeze(self):
        """Disable gradients for the wrapped model and pin it to eval mode."""
        logger.warning("======== Freezing SwinUnetWrapper ========")
        self._frozen = True
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

    def train(self, mode: bool = True):
        """Set the training mode, keeping a frozen model in eval mode.

        Without this override, calling ``.train()`` on the wrapper or any
        parent would put the frozen model back into training mode, letting
        normalisation layers update their running statistics.
        """
        super().train(mode)
        if getattr(self, "_frozen", False):
            self.model.eval()
        return self

    @torch.compile
    def forward(self, image: torch.Tensor, mod: torch.Tensor = None):
        """Run the wrapped model and flatten its spatial grid into tokens.

        Args:
            image: Batch of images, [N, C, H, W]; per the original notes,
                RGB in [0, 1] and already properly sized.
            mod: [N, D] or None. Accepted for interface compatibility but
                not used here.

        Returns:
            Tensor of shape [N, H' * W', C'] where [N, C', H', W'] is the
            output of the wrapped model.
        """
        feature_map = self.model(image)
        # [N, C', H', W'] -> [N, H', W', C'] -> [N, H' * W', C']
        return feature_map.permute(0, 2, 3, 1).flatten(1, 2)
|