# Copyright (c) 2023-2024, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

import torch
import torch.nn as nn
import timm
from accelerate.logging import get_logger

logger = get_logger(__name__)

class XUNet(nn.Module):
    def __init__(self, model_name="swin_base_patch4_window12_384_in22k", encoder_feat_dim=384):
        super().__init__()
        # Pretrained timm backbone used as the encoder.
        self.encoder = timm.create_model(model_name, pretrained=True)
        # Strip the classification head so only the feature extractor remains.
        # For a swin backbone, delete `head` and `norm` instead:
        # del self.encoder.head
        # del self.encoder.norm
        # ResNet-style backbone (the decoder widths below assume a 512-channel
        # final stage, e.g. resnet18/resnet34):
        del self.encoder.global_pool
        del self.encoder.fc
        # Decoder layers: each block upsamples 2x and halves the channel count.
        # Widths for a resnet50-style encoder (2048-channel final stage):
        # self.upconv4 = self.upconv_block(2048, 1024)
        # self.upconv3 = self.upconv_block(1024, 512)
        # self.upconv2 = self.upconv_block(512, 256)
        # self.upconv1 = self.upconv_block(256, 64)
        self.upconv4 = self.upconv_block(512, 256)  # s32 -> s16
        self.upconv3 = self.upconv_block(256, 128)  # s16 -> s8
        self.upconv2 = self.upconv_block(128, 64)   # s8 -> s4
        # self.upconv1 = self.upconv_block(64, 64)  # s4 -> s2 (unused)
        self.out_conv = nn.Conv2d(64, encoder_feat_dim, kernel_size=1)

    def upconv_block(self, in_channels, out_channels):
        # 2x spatial upsampling via transposed convolution, followed by ReLU.
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        # Encoder: collect intermediate feature maps from the timm backbone.
        enc_output = self.encoder.forward_intermediates(x, stop_early=True, intermediates_only=True)
        # for e in enc_output:
        #     print(e.shape, x.shape)
        # `enc_output` is a list of per-stage feature maps, ordered from
        # shallow (high resolution) to deep (low resolution).
        enc_out4 = enc_output[4]  # s32; adjust indices for other backbones
        enc_out3 = enc_output[3]  # s16
        enc_out2 = enc_output[2]  # s8
        enc_out1 = enc_output[1]  # s4
        # enc_out0 = enc_output[0]  # s2 (unused)
        # Decoder: upsample and add skip connections, UNet-style.
        x = self.upconv4(enc_out4)
        x = x + enc_out3  # s16, skip connection
        x = self.upconv3(x)
        x = x + enc_out2  # s8
        x = self.upconv2(x)
        x = x + enc_out1  # s4
        # x = self.upconv1(x)
        # x = x + enc_out0  # s2
        x = self.out_conv(x)
        return x
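
# A minimal usage sketch, assuming a ResNet-style backbone whose stage widths
# match the decoder above (e.g. "resnet34", an assumption, not the module's
# default) and an input size divisible by 32:
#
#     net = XUNet(model_name="resnet34", encoder_feat_dim=384)
#     feat = net(torch.rand(1, 3, 384, 384))  # -> [1, 384, 96, 96], an s4 feature map
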
class XnetWrapper(nn.Module):
    """
    XnetWrapper around the original XUNet implementation, hacked with a
    modulation hook (the `mod` input is currently unused).
    """
    def __init__(self, model_name: str, modulation_dim: Optional[int] = None, freeze: bool = True, encoder_feat_dim: int = 384):
        super().__init__()
        self.modulation_dim = modulation_dim
        self.model = XUNet(model_name=model_name, encoder_feat_dim=encoder_feat_dim)
        if freeze:
            if modulation_dim is not None:
                raise ValueError("Modulated SwinUnetWrapper requires training, freezing is not allowed.")
            self._freeze()

    def _freeze(self):
        logger.warning("======== Freezing SwinUnetWrapper ========")
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(self, image: torch.Tensor, mod: Optional[torch.Tensor] = None):
        # image: [N, C, H, W], RGB in [0, 1] scale, properly sized
        # mod: [N, D] or None (reserved for modulation; not used yet)
        outs = self.model(image)
        # Flatten the s4 feature map into a token sequence: [N, H/4 * W/4, C].
        ret = outs.permute(0, 2, 3, 1).flatten(1, 2)
        return ret
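

# A minimal smoke-test sketch, assuming a ResNet-style timm backbone
# ("resnet34" is an assumption; any backbone with stage widths 64/128/256/512
# matches the decoder) and a timm version providing `forward_intermediates`.
if __name__ == "__main__":
    wrapper = XnetWrapper(model_name="resnet34", freeze=True)  # downloads pretrained weights
    image = torch.rand(2, 3, 384, 384)  # RGB in [0, 1], size divisible by 32
    with torch.no_grad():
        tokens = wrapper(image)
    # Tokens on the s4 grid: [N, (H/4) * (W/4), encoder_feat_dim].
    print(tokens.shape)  # expected: torch.Size([2, 9216, 384])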