File size: 1,034 Bytes
e8861c0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
"""by lyuwenyu
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import numpy as np
from src.core import register
# Names exported by `from <this module> import *`.
__all__ = ['RTDETR', ]
@register
class RTDETR(nn.Module):
    """Wrapper module that chains a backbone, an encoder and a decoder.

    The forward pass feeds the input through ``backbone``, then ``encoder``,
    then ``decoder`` (which also receives ``targets``). When ``multi_scale``
    is set and the module is in training mode, the input is first resized to
    a randomly chosen square size.
    """

    # NOTE(review): presumably read by the `register` machinery to decide
    # which constructor arguments are injected -- confirm in src.core.
    __inject__ = ['backbone', 'encoder', 'decoder', ]

    def __init__(self, backbone: nn.Module, encoder, decoder, multi_scale=None):
        """Store the three stages and the optional multi-scale size choices.

        Args:
            backbone: First stage; called with the (possibly resized) input.
            encoder: Second stage; called with the backbone output.
            decoder: Final stage; called with the encoder output and targets.
            multi_scale: Optional candidates handed to ``np.random.choice``
                to pick a square input size during training. ``None`` (or
                any falsy value) disables the resize.
        """
        super().__init__()

        # Assignment order determines submodule registration order, so it is
        # kept exactly: backbone, decoder, encoder.
        self.backbone = backbone
        self.decoder = decoder
        self.encoder = encoder

        self.multi_scale = multi_scale

    def forward(self, x, targets=None):
        """Run backbone -> encoder -> decoder and return the decoder output.

        Args:
            x: Input tensor. The two-element ``size`` used for resizing
                assumes a 4-D (N, C, H, W) tensor -- TODO confirm with caller.
            targets: Forwarded unchanged to the decoder.

        Returns:
            Whatever ``self.decoder`` returns.
        """
        if self.multi_scale and self.training:
            # One side length per call; the input becomes side x side using
            # F.interpolate's default interpolation mode.
            side = np.random.choice(self.multi_scale)
            x = F.interpolate(x, size=[side, side])

        backbone_out = self.backbone(x)
        encoder_out = self.encoder(backbone_out)
        return self.decoder(encoder_out, targets)

    def deploy(self, ):
        """Switch to eval mode and convert every convertible submodule.

        Calls ``convert_to_deploy()`` on each module yielded by
        ``self.modules()`` that exposes that attribute.

        Returns:
            ``self``, so the call can be chained.
        """
        self.eval()

        # Iterate the generator lazily (no snapshot) so that modules created
        # during a conversion are traversed exactly as before.
        for module in self.modules():
            if not hasattr(module, 'convert_to_deploy'):
                continue
            module.convert_to_deploy()

        return self
|