|
"""by lyuwenyu |
|
""" |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
import random |
|
import numpy as np |
|
|
|
from src.core import register |
|
|
|
|
|
# Public API of this module: only the detector class is exported.
__all__ = ["RTDETR"]
|
|
|
|
|
@register
class RTDETR(nn.Module):
    """Detector that chains a backbone, an encoder and a decoder.

    ``forward`` feeds the input through ``backbone -> encoder -> decoder``.
    During training, it can optionally resize the input batch to a randomly
    chosen square size first ("multi-scale" training).
    """

    # Names of constructor arguments that the ``src.core`` registry is
    # presumably expected to build and inject from config — confirm against
    # ``src.core.register``.
    __inject__ = ['backbone', 'encoder', 'decoder', ]

    def __init__(self, backbone: nn.Module, encoder, decoder, multi_scale=None):
        """Store the sub-modules and the optional multi-scale settings.

        Args:
            backbone: first stage; receives the (possibly resized) input.
            encoder: second stage; receives the backbone output.
            decoder: final stage; receives the encoder output and ``targets``.
            multi_scale: optional sequence of candidate side lengths. When it
                is truthy and the module is in training mode, ``forward``
                resizes the input to ``(s, s)`` for one ``s`` sampled from
                it. ``None`` (or an empty sequence) disables resizing.
        """
        super().__init__()
        self.backbone = backbone
        self.decoder = decoder
        self.encoder = encoder
        self.multi_scale = multi_scale

    def forward(self, x, targets=None):
        """Run ``backbone -> encoder -> decoder`` on ``x``.

        Args:
            x: input batch handed to the backbone. When multi-scale resizing
                is active, it must be 4-D (``F.interpolate`` with a
                2-element ``size`` requires that); presumably
                ``(N, C, H, W)`` images — TODO confirm.
            targets: passed through unchanged to the decoder.

        Returns:
            Whatever ``self.decoder`` returns.
        """
        if self.multi_scale and self.training:
            # One square size per batch. ``np.random.choice`` yields a NumPy
            # scalar; cast to a Python int because some PyTorch releases
            # reject non-``int`` entries in ``F.interpolate``'s ``size``.
            sz = int(np.random.choice(self.multi_scale))
            x = F.interpolate(x, size=[sz, sz])

        x = self.backbone(x)
        x = self.encoder(x)
        return self.decoder(x, targets)

    def deploy(self):
        """Prepare the model for deployment and return it.

        Switches to eval mode, then calls ``convert_to_deploy()`` on every
        module yielded by ``self.modules()`` (which includes ``self``) that
        defines such a method.

        Returns:
            ``self``, so the call can be chained.
        """
        self.eval()
        for m in self.modules():
            if hasattr(m, 'convert_to_deploy'):
                m.convert_to_deploy()
        return self
|
|