cheng-hust commited on
Commit
9165ca9
·
verified ·
1 Parent(s): 9585432

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -9
app.py CHANGED
@@ -41,10 +41,9 @@ transformer = transforms.Compose([
41
  class Model(nn.Module):
42
  def __init__(self, confg=None, ckpt="",cope=True) -> None:
43
  super().__init__()
44
- if cope:
45
- self.cfg = src.core.YAMLConfig(confg, resume=ckpt)
46
- else:
47
- self.cfg = src2.core.YAMLConfig(confg, resume=ckpt)
48
  if ckpt:
49
  checkpoint = torch.load(ckpt, map_location='cpu')
50
  if 'ema' in checkpoint:
@@ -66,8 +65,9 @@ class Model(nn.Module):
66
  return self.postprocessor(outputs, orig_target_sizes)
67
 
68
 
69
- model = Model(confg='./configs/rtdetr/rtdetr_r101vd_6x_coco.yml',ckpt="./checkpointcope12.pth")
70
- model2 = Model(confg='./configs/rtdetr/rtdetr_r101vd_6x_coco.yml',ckpt="./checkpoint_init.pth",cope=False)
 
71
 
72
  #img = cv2.imread('./j.jpg',cv2.IMREAD_GRAYSCALE)
73
  #img = Image.open('./a.jpg').convert('RGB').resize((640,640))
@@ -79,8 +79,12 @@ def detect(img,thr=0.2,cope='none'):
79
  t_img = transformer(img).unsqueeze(0)#.unsqueeze(0) #[1,1,640,640]
80
  size = torch.tensor([[t_img.shape[2], t_img.shape[3]]])
81
  #print(t_img.shape)
82
- labels, boxes, scores=model(t_img,size)
83
- labels2, boxes2, scores2=model2(t_img,size)
 
 
 
 
84
  #img_path = Path('./a.jpg')
85
  draw = ImageDraw.Draw(img)
86
  thrh = thr
@@ -101,6 +105,6 @@ def detect(img,thr=0.2,cope='none'):
101
  #save_path = Path('./output') / img_path.name
102
  return img
103
 
104
- interface = gr.Interface(fn=detect,inputs=["image",gr.Slider(label="thr", value=0.2, maximum=1, minimum=0)],outputs="image",title="rt-cope detect")
105
 
106
  interface.launch()
 
41
  class Model(nn.Module):
42
  def __init__(self, confg=None, ckpt="",cope=True) -> None:
43
  super().__init__()
44
+
45
+ self.cfg = src.core.YAMLConfig(confg, resume=ckpt)
46
+
 
47
  if ckpt:
48
  checkpoint = torch.load(ckpt, map_location='cpu')
49
  if 'ema' in checkpoint:
 
65
  return self.postprocessor(outputs, orig_target_sizes)
66
 
67
 
68
+ model = Model(confg='./configs/rtdetr/rtdetr_r101vd_6x_coco.yml',ckpt="./checkpoint_init.pth")
69
+ model2 = Model(confg='./configs/rtdetr/rtdetr_r101vd_6x_coco_cope12.yml',ckpt="./checkpointcope12.pth",cope=False)
70
+ model3 = Model(confg='./configs/rtdetr/rtdetr_r101vd_6x_coco_cope24.yml',ckpt="./checkpointcope24.pth",cope=False)
71
 
72
  #img = cv2.imread('./j.jpg',cv2.IMREAD_GRAYSCALE)
73
  #img = Image.open('./a.jpg').convert('RGB').resize((640,640))
 
79
  t_img = transformer(img).unsqueeze(0)#.unsqueeze(0) #[1,1,640,640]
80
  size = torch.tensor([[t_img.shape[2], t_img.shape[3]]])
81
  #print(t_img.shape)
82
+ if cope == 'none':
83
+ labels, boxes, scores=model(t_img,size)
84
+ elif cope == 'cope12':
85
+ labels, boxes, scores=model2(t_img,size)
86
+ else
87
+ labels, boxes, scores=model3(t_img,size)
88
  #img_path = Path('./a.jpg')
89
  draw = ImageDraw.Draw(img)
90
  thrh = thr
 
105
  #save_path = Path('./output') / img_path.name
106
  return img
107
 
108
+ interface = gr.Interface(fn=detect,inputs=["image",gr.Slider(label="thr", value=0.2, maximum=1, minimum=0),gr.inputs.Radio(['none','cope12','cope24'])],outputs="image",title="rt-cope detect")
109
 
110
  interface.launch()