cheng-hust commited on
Commit
7c5b9cf
·
verified ·
1 Parent(s): 6c616ed

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -0
app.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch.nn as nn
3
+ import torch
4
+ from src.core.config import BaseConfig
5
+ from src.core.yaml_utils import load_config, merge_config, create, merge_dict
6
+ import cv2
7
+ import torchvision.transforms as transforms
8
+ from PIL import Image, ImageDraw
9
+ from src.core import YAMLConfig
10
+ from pathlib import Path
11
+
12
+ transformer = transforms.Compose([
13
+ transforms.ToTensor(),
14
+ # transforms.Resize([224,224]),
15
+
16
+ ])
17
+
18
+
19
+ #def model(yaml_cfg) -> torch.nn.Module:
20
+ # if 'model' in yaml_cfg:
21
+ # merge_config(yaml_cfg)
22
+ # model = create(yaml_cfg['model'])
23
+ # return model
24
+ #
25
+ #
26
+ #
27
+ #def get_model(cfg_path='./configs/rtdetr/rtdetr_r101vd_6x_coco.yml',ckpt="./logs/checkpointcope12.pth"):
28
+ # yaml_cfg = load_config(cfg_path)
29
+ # merge_dict(yaml_cfg, {})
30
+ # tmodel = model(yaml_cfg)
31
+ # checkpoint = torch.load(ckpt, map_location='cpu')
32
+ # if 'ema' in checkpoint:
33
+ # state = checkpoint['ema']['module']
34
+ # else:
35
+ # state = checkpoint['model']
36
+ #
37
+ # tmodel.load_state_dict(state)
38
+ # return tmodel
39
+
40
+ class Model(nn.Module):
41
+ def __init__(self, confg=None, ckpt="") -> None:
42
+ super().__init__()
43
+ self.cfg = YAMLConfig(confg, resume=ckpt)
44
+ if ckpt:
45
+ checkpoint = torch.load(ckpt, map_location='cpu')
46
+ if 'ema' in checkpoint:
47
+ state = checkpoint['ema']['module']
48
+ else:
49
+ state = checkpoint['model']
50
+ else:
51
+ raise AttributeError('only support resume to load model.state_dict by now.')
52
+
53
+ # NOTE load train mode state -> convert to deploy mode
54
+ self.cfg.model.load_state_dict(state)
55
+
56
+ self.model = self.cfg.model.deploy()
57
+ self.postprocessor = self.cfg.postprocessor.deploy()
58
+ # print(self.postprocessor.deploy_mode)
59
+
60
+ def forward(self, images, orig_target_sizes):
61
+ outputs = self.model(images)
62
+ return self.postprocessor(outputs, orig_target_sizes)
63
+
64
+
65
+ model = Model(confg='./configs/rtdetr/rtdetr_r101vd_6x_coco.yml',ckpt="./logs/checkpointcope12.pth")
66
+
67
+ #img = cv2.imread('./j.jpg',cv2.IMREAD_GRAYSCALE)
68
+ #img = Image.open('./a.jpg').convert('RGB').resize((640,640))
69
+
70
+
71
+
72
+ def detect(img):
73
+ img = img.resize((640,640))
74
+ t_img = transformer(img).unsqueeze(0)#.unsqueeze(0) #[1,1,640,640]
75
+ size = torch.tensor([[t_img.shape[2], t_img.shape[3]]])
76
+ #print(t_img.shape)
77
+ labels, boxes, scores=model(t_img,size)
78
+ #img_path = Path('./a.jpg')
79
+ draw = ImageDraw.Draw(img)
80
+ thrh = 0.1
81
+
82
+ for i in range(t_img.shape[0]):
83
+
84
+ scr = scores[i]
85
+ lab = labels[i][scr > thrh]
86
+ box = boxes[i][scr > thrh]
87
+
88
+
89
+ for b in box:
90
+ draw.rectangle(list(b), outline='red', )
91
+ draw.text((b[0], b[1]), text=str(lab[i]), fill='blue', )
92
+
93
+ #save_path = Path('./output') / img_path.name
94
+ return img
95
+
96
+ interface = gr.Interface(fn=detect,inputs="image",outputs="image",title="rt-cope detect")
97
+
98
+ interface.launch()