import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image, ImageDraw

from src.core import YAMLConfig

# Preprocessing: ToTensor() converts a PIL image to a CHW float tensor in [0, 1].
# Resizing to 640x640 happens on the PIL image inside detect() below, so no
# Resize transform is needed here.
transformer = transforms.Compose([
    transforms.ToTensor(),
])


class Model(nn.Module):
    def __init__(self, config=None, ckpt="") -> None:
        super().__init__()

        self.cfg = YAMLConfig(config, resume=ckpt)

        if ckpt:
            checkpoint = torch.load(ckpt, map_location='cpu')
            # Prefer the EMA weights when the checkpoint contains them.
            if 'ema' in checkpoint:
                state = checkpoint['ema']['module']
            else:
                state = checkpoint['model']
        else:
            raise AttributeError('Only loading model.state_dict from a checkpoint is supported for now.')

        # NOTE: load the train-mode state dict, then convert to deploy mode.
        self.cfg.model.load_state_dict(state)

        self.model = self.cfg.model.deploy()
        self.postprocessor = self.cfg.postprocessor.deploy()

    def forward(self, images, orig_target_sizes):
        outputs = self.model(images)
        return self.postprocessor(outputs, orig_target_sizes)
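

# A small hedged helper (not in the original script): it exercises the deployed
# Model wrapper outside Gradio, using the same 640x640 PIL pipeline that
# detect() uses below. The name run_inference is our own.
def run_inference(m: nn.Module, pil_img: Image.Image):
    """Run one PIL image through a deployed Model; returns (labels, boxes, scores)."""
    t = transformer(pil_img.resize((640, 640))).unsqueeze(0)   # [1, 3, 640, 640]
    size = torch.tensor([[t.shape[2], t.shape[3]]])            # (h, w) for the postprocessor
    with torch.no_grad():
        return m(t, size)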
  

model = Model(config='./configs/rtdetr/rtdetr_r101vd_6x_coco.yml', ckpt="./checkpoint_init.pth")
model2 = Model(config='./configs/rtdetr/rtdetr_r101vd_6x_cococope12.yml', ckpt="./checkpointcope12.pth")
model3 = Model(config='./configs/rtdetr/rtdetr_r101vd_6x_cocococo.yml', ckpt="./rtdetrCOCO.pth")



# Class-id -> name maps. The AITOD labels are contiguous ids 1..8; the COCO map
# uses the original (non-contiguous) COCO category ids.
label_dict = {8: 'wind-mill', 7: 'person', 6: 'vehicle', 5: 'swimming-pool',
              4: 'ship', 3: 'storage-tank', 2: 'bridge', 1: 'airplane'}
coco_dict = {1: 'person', 2: 'bicycle', 3: 'car', 4: 'motorcycle', 5: 'airplane', 6: 'bus', 7: 'train', 8: 'truck',
             9: 'boat', 10: 'traffic light', 11: 'fire hydrant', 13: 'stop sign', 14: 'parking meter', 15: 'bench',
             16: 'bird', 17: 'cat', 18: 'dog', 19: 'horse', 20: 'sheep', 21: 'cow', 22: 'elephant', 23: 'bear', 24: 'zebra',
             25: 'giraffe', 27: 'backpack', 28: 'umbrella', 31: 'handbag', 32: 'tie', 33: 'suitcase', 34: 'frisbee',
             35: 'skis', 36: 'snowboard', 37: 'sports ball', 38: 'kite', 39: 'baseball bat', 40: 'baseball glove',
             41: 'skateboard', 42: 'surfboard', 43: 'tennis racket', 44: 'bottle', 46: 'wine glass', 47: 'cup',
             48: 'fork', 49: 'knife', 50: 'spoon', 51: 'bowl', 52: 'banana', 53: 'apple', 54: 'sandwich', 55: 'orange',
             56: 'broccoli', 57: 'carrot', 58: 'hot dog', 59: 'pizza', 60: 'donut', 61: 'cake', 62: 'chair', 63: 'couch',
             64: 'potted plant', 65: 'bed', 67: 'dining table', 70: 'toilet', 72: 'tv', 73: 'laptop', 74: 'mouse',
             75: 'remote', 76: 'keyboard', 77: 'cell phone', 78: 'microwave', 79: 'oven', 80: 'toaster', 81: 'sink',
             82: 'refrigerator', 84: 'book', 85: 'clock', 86: 'vase', 87: 'scissors', 88: 'teddy bear', 89: 'hair drier',
             90: 'toothbrush'}
label_color_dict = {8: 'burlyWood', 7: 'red', 6: 'blue', 5: 'green',
                    4: 'yellow', 3: 'cyan', 2: 'magenta', 1: 'orange'}


def detect(img, thr=0.2, trained_dataset='aitod'):
    # Gradio hands in an HxWxC ndarray; convert to PIL and resize to the
    # model's expected 640x640 input.
    img = Image.fromarray(img).resize((640, 640))
    t_img = transformer(img).unsqueeze(0)                      # [1, 3, 640, 640]
    size = torch.tensor([[t_img.shape[2], t_img.shape[3]]])

    with torch.no_grad():
        if trained_dataset == 'aitod':
            labels, boxes, scores = model(t_img, size)
        elif trained_dataset == 'ten_classes':
            labels, boxes, scores = model2(t_img, size)
        else:
            labels, boxes, scores = model3(t_img, size)

    draw = ImageDraw.Draw(img)

    for i in range(t_img.shape[0]):
        # Keep only detections above the confidence threshold.
        scr = scores[i]
        lab = labels[i][scr > thr]
        box = boxes[i][scr > thr]

        if trained_dataset != 'COCO':
            for idx, b in enumerate(box):
                label_i = lab[idx].item()
                draw.rectangle(list(b), outline=label_color_dict[label_i])
                draw.text((b[0], b[1]), text=label_dict[label_i], fill='blue')
        else:
            for idx, b in enumerate(box):
                label_i = lab[idx].item()
                # Cycle the 8 colors over COCO's larger label space. The model is
                # assumed to emit contiguous 0-based labels, so label_i + 1 maps
                # onto COCO ids; .get() guards the gaps in the COCO id range.
                draw.rectangle(list(b), outline=label_color_dict[label_i % 8 + 1])
                draw.text((b[0], b[1]), text=coco_dict.get(label_i + 1, str(label_i)), fill='blue')

    return img
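

# Optional smoke test, kept commented out; './a.jpg' is a placeholder path:
# import numpy as np
# out = detect(np.array(Image.open('./a.jpg').convert('RGB')), thr=0.3, trained_dataset='COCO')
# out.save('./detected.jpg')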

interface = gr.Interface(
    fn=detect,
    inputs=[
        "image",
        gr.Slider(label="thr", value=0.2, minimum=0, maximum=1),
        gr.Radio(['aitod', 'ten_classes', 'COCO'], value='aitod'),
    ],
    outputs="image",
    title="Degraded HUST Small Object Detection",
)
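
# launch() serves the demo on a local URL by default; launch(share=True) would
# also create a temporary public Gradio link.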
  
interface.launch()