File size: 3,856 Bytes
7c5b9cf
 
 
 
 
 
 
 
 
 
2b88463
7c5b9cf
 
 
9eb39a7
7c5b9cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55489ff
7c5b9cf
9165ca9
 
 
7c5b9cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9165ca9
672536c
 
7c5b9cf
 
 
 
 
672536c
9eb39a7
 
7c5b9cf
 
 
672536c
9165ca9
 
 
ed23be2
9165ca9
7c5b9cf
 
f29c1af
7c5b9cf
 
 
 
 
 
 
0da3b64
2c41e6d
e87ccec
f90adc8
 
 
7c5b9cf
 
 
 
672536c
7c5b9cf
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import gradio as gr
import torch.nn as nn
import torch
from src.core.config import BaseConfig
from src.core.yaml_utils import load_config, merge_config, create, merge_dict
import cv2
import torchvision.transforms as transforms
from PIL import Image, ImageDraw
from src.core import YAMLConfig
from pathlib import Path
import src

# Preprocessing applied to the (already resized) PIL image before inference:
# a single ToTensor step. Resizing to 640x640 happens on the PIL image in
# detect(), not in this pipeline.
transformer = transforms.Compose([transforms.ToTensor()])


#def model(yaml_cfg) -> torch.nn.Module:
#    if 'model' in yaml_cfg:
#        merge_config(yaml_cfg)
#        model = create(yaml_cfg['model'])
#    return model
#
#
#
#def get_model(cfg_path='./configs/rtdetr/rtdetr_r101vd_6x_coco.yml',ckpt="./logs/checkpointcope12.pth"):
#  yaml_cfg = load_config(cfg_path)
#  merge_dict(yaml_cfg, {}) 
#  tmodel = model(yaml_cfg)
#  checkpoint = torch.load(ckpt, map_location='cpu') 
#  if 'ema' in checkpoint:
#      state = checkpoint['ema']['module']
#  else:
#      state = checkpoint['model']
#      
#  tmodel.load_state_dict(state)
#  return tmodel
  
class Model(nn.Module):
    """Inference wrapper bundling an RT-DETR model with its postprocessor.

    Builds the model from a YAML config, loads weights from a checkpoint,
    converts model and postprocessor to deploy mode and exposes a single
    ``forward(images, orig_target_sizes)`` call.

    Args:
        confg: Path to the YAML config handed to ``src.core.YAMLConfig``.
        ckpt: Path to the checkpoint file (required). EMA weights
            (``checkpoint['ema']['module']``) are preferred over
            ``checkpoint['model']`` when present.
        cope: Currently unused; kept so existing call sites keep working.

    Raises:
        AttributeError: if ``ckpt`` is empty (exception type kept for
            backward compatibility).
    """

    def __init__(self, confg=None, ckpt="", cope=True) -> None:
        super().__init__()

        # Fail fast before building the config: without a checkpoint there
        # are no weights to load.
        if not ckpt:
            raise AttributeError('only support resume to load model.state_dict by now.')

        self.cfg = src.core.YAMLConfig(confg, resume=ckpt)

        # torch.load unpickles the file -- only point this at trusted checkpoints.
        checkpoint = torch.load(ckpt, map_location='cpu')
        # Prefer the EMA weights when the checkpoint carries them.
        state = checkpoint['ema']['module'] if 'ema' in checkpoint else checkpoint['model']

        # NOTE load train mode state -> convert to deploy mode
        self.cfg.model.load_state_dict(state)

        self.model = self.cfg.model.deploy()
        self.postprocessor = self.cfg.postprocessor.deploy()

        # This wrapper is only used for inference: make sure every registered
        # submodule (dropout, batch-norm, ...) runs in eval mode regardless of
        # what deploy() does internally.
        self.eval()

    def forward(self, images, orig_target_sizes):
        """Run the detector on ``images`` and postprocess the raw outputs.

        Returns whatever the deploy-mode postprocessor returns; detect() in
        this file unpacks it as ``labels, boxes, scores``.
        """
        outputs = self.model(images)
        return self.postprocessor(outputs, orig_target_sizes)
  

# Load all three detectors once at import time; detect() picks one per call:
# cope == 'aitod' -> model, cope == 'cope12' -> model2, anything else -> model3.
model = Model(
    confg='./configs/rtdetr/rtdetr_r101vd_6x_coco.yml',
    ckpt="./checkpoint_init.pth",
)
model2 = Model(
    confg='./configs/rtdetr/rtdetr_r101vd_6x_cococope12.yml',
    ckpt="./checkpointcope12.pth",
    cope=False,
)
model3 = Model(
    confg='./configs/rtdetr/rtdetr_r101vd_6x_coco.yml',
    ckpt="./rtdetrCOCO.pth",
    cope=False,
)

#img = cv2.imread('./j.jpg',cv2.IMREAD_GRAYSCALE)
#img = Image.open('./a.jpg').convert('RGB').resize((640,640))


# Class id -> display name for the AITOD label set (ids 1-8).
_AITOD_LABELS = {
    1: 'airplane',
    2: 'bridge',
    3: 'storage-tank',
    4: 'ship',
    5: 'swimming-pool',
    6: 'vehicle',
    7: 'person',
    8: 'wind-mill',
}
# Class id -> box outline colour, keyed like _AITOD_LABELS.
_AITOD_COLORS = {
    1: 'orange',
    2: 'magenta',
    3: 'cyan',
    4: 'yellow',
    5: 'green',
    6: 'blue',
    7: 'red',
    8: 'burlyWood',
}
# Outline colour for class ids missing from the AITOD tables.
_FALLBACK_COLOR = 'white'
# Every input image is resized to this size before inference.
_INPUT_SIZE = (640, 640)


def detect(img, thr=0.2, cope='aitod'):
    """Run one of the loaded detectors on an image and draw its detections.

    Args:
        img: Image as a numpy array (what the Gradio "image" input supplies),
            or None when nothing was uploaded.
        thr: Score threshold; only detections with score > thr are drawn.
        cope: Detector selector: 'aitod' -> ``model``, 'cope12' ->
            ``model2``, anything else (e.g. 'COCO') -> ``model3``.

    Returns:
        The input converted to RGB and resized to 640x640, with boxes and
        class names drawn on it; None if ``img`` is None.
    """
    if img is None:
        return None

    # convert('RGB') so RGBA / grayscale uploads still give a 3-channel tensor.
    pil_img = Image.fromarray(img).convert('RGB').resize(_INPUT_SIZE)
    t_img = transformer(pil_img).unsqueeze(0)  # add a leading batch dimension
    # Target size for the postprocessor. Both spatial dims equal 640 after the
    # resize above, so the (h, w) vs (w, h) ordering makes no difference here.
    size = torch.tensor([[t_img.shape[2], t_img.shape[3]]])

    if cope == 'aitod':
        detector = model
    elif cope == 'cope12':
        detector = model2
    else:
        detector = model3

    # Inference only: skip autograd graph construction (saves memory/time).
    with torch.no_grad():
        labels, boxes, scores = detector(t_img, size)

    draw = ImageDraw.Draw(pil_img)
    for lab_row, box_row, score_row in zip(labels, boxes, scores):
        keep = score_row > thr
        # .tolist() hands PIL plain Python ints/floats instead of 0-d tensors.
        for label, box in zip(lab_row[keep].tolist(), box_row[keep].tolist()):
            # .get(): models other than AITOD presumably use different label
            # sets, whose ids previously raised KeyError here.
            draw.rectangle(box, outline=_AITOD_COLORS.get(label, _FALLBACK_COLOR))
            draw.text(
                (box[0], box[1]),
                text=_AITOD_LABELS.get(label, str(label)),
                fill='blue',
            )
    return pil_img

# Gradio UI: image + score threshold + detector choice -> annotated image.
# gr.inputs.* is the legacy namespace (deprecated in Gradio 3, removed in 4);
# use the top-level gr.Radio like the gr.Slider next to it. The radio now
# defaults to 'aitod', matching detect()'s own default; with no selection it
# could pass None, which detect() routes to the COCO model.
interface = gr.Interface(
    fn=detect,
    inputs=[
        "image",
        gr.Slider(label="thr", value=0.2, maximum=1, minimum=0),
        gr.Radio(['aitod', 'cope12', 'COCO'], value='aitod', label="cope"),
    ],
    outputs="image",
    title="rt-cope detect",
)

interface.launch()