# cheng-hust's picture
# Upload 91 files
# e8861c0 verified
"""by lyuwenyu
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import numpy as np
from src.core import register
__all__ = ['RTDETR', ]
@register
class RTDETR(nn.Module):
    """Detector that chains a backbone, an encoder and a decoder.

    In training mode, when ``multi_scale`` is set, each forward pass first
    resizes the input batch to a square whose side length is drawn at random
    from ``multi_scale``.

    Args:
        backbone: Module applied to the (possibly resized) input.
        encoder: Module applied to the backbone output.
        decoder: Module called as ``decoder(features, targets)``; its result
            is what ``forward`` returns.
        multi_scale: Optional sequence of candidate side lengths used for
            training-time resizing. ``None`` (or an empty sequence) disables
            the resizing step.
    """

    # NOTE(review): presumably read by the project's ``register`` decorator to
    # decide which constructor arguments are injected -- confirm in src.core.
    __inject__ = ['backbone', 'encoder', 'decoder', ]

    def __init__(self, backbone: nn.Module, encoder, decoder, multi_scale=None):
        super().__init__()
        self.backbone = backbone
        self.decoder = decoder
        self.encoder = encoder
        self.multi_scale = multi_scale

    def forward(self, x, targets=None):
        """Run ``backbone -> encoder -> decoder`` on ``x``.

        Args:
            x: Input batch. When multi-scale resizing is active it is passed
                to ``F.interpolate`` with a two-element ``size``, so it must
                be a tensor with two spatial dimensions
                (assumes ``(N, C, H, W)`` -- TODO confirm against the caller).
            targets: Forwarded unchanged to the decoder.

        Returns:
            The decoder's output, unmodified.
        """
        if self.multi_scale and self.training:
            # np.random.choice yields a NumPy scalar; F.interpolate documents
            # ``size`` as built-in ints, so convert before building the size.
            sz = int(np.random.choice(self.multi_scale))
            x = F.interpolate(x, size=[sz, sz])

        x = self.backbone(x)
        x = self.encoder(x)
        x = self.decoder(x, targets)

        return x

    def deploy(self, ):
        """Put the model in eval mode and convert submodules for deployment.

        Calls ``convert_to_deploy()`` on every module yielded by
        ``self.modules()`` that defines such an attribute.

        Returns:
            ``self``, so the call can be chained.
        """
        self.eval()
        for m in self.modules():
            if hasattr(m, 'convert_to_deploy'):
                m.convert_to_deploy()
        return self