cheng-hust commited on
Commit
55489ff
·
verified ·
1 Parent(s): a25d5d4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -2
app.py CHANGED
@@ -7,6 +7,7 @@ import cv2
7
  import torchvision.transforms as transforms
8
  from PIL import Image, ImageDraw
9
  from src.core import YAMLConfig
 
10
  from pathlib import Path
11
 
12
  transformer = transforms.Compose([
@@ -38,9 +39,12 @@ transformer = transforms.Compose([
38
  # return tmodel
39
 
40
  class Model(nn.Module):
41
- def __init__(self, confg=None, ckpt="") -> None:
42
  super().__init__()
43
- self.cfg = YAMLConfig(confg, resume=ckpt)
 
 
 
44
  if ckpt:
45
  checkpoint = torch.load(ckpt, map_location='cpu')
46
  if 'ema' in checkpoint:
@@ -63,6 +67,7 @@ class Model(nn.Module):
63
 
64
 
65
  model = Model(confg='./configs/rtdetr/rtdetr_r101vd_6x_coco.yml',ckpt="./checkpointcope12.pth")
 
66
 
67
  #img = cv2.imread('./j.jpg',cv2.IMREAD_GRAYSCALE)
68
  #img = Image.open('./a.jpg').convert('RGB').resize((640,640))
@@ -75,6 +80,7 @@ def detect(img,thr=0.2):
75
  size = torch.tensor([[t_img.shape[2], t_img.shape[3]]])
76
  #print(t_img.shape)
77
  labels, boxes, scores=model(t_img,size)
 
78
  #img_path = Path('./a.jpg')
79
  draw = ImageDraw.Draw(img)
80
  thrh = thr
 
7
  import torchvision.transforms as transforms
8
  from PIL import Image, ImageDraw
9
  from src.core import YAMLConfig
10
+ from src2.core import YAMLConfig
11
  from pathlib import Path
12
 
13
  transformer = transforms.Compose([
 
39
  # return tmodel
40
 
41
  class Model(nn.Module):
42
+ def __init__(self, confg=None, ckpt="",cope=True) -> None:
43
  super().__init__()
44
+ if cope:
45
+ self.cfg = src.core.YAMLConfig(confg, resume=ckpt)
46
+ else:
47
+ self.cfg = src2.core.YAMLConfig(confg, resume=ckpt)
48
  if ckpt:
49
  checkpoint = torch.load(ckpt, map_location='cpu')
50
  if 'ema' in checkpoint:
 
67
 
68
 
69
  model = Model(confg='./configs/rtdetr/rtdetr_r101vd_6x_coco.yml',ckpt="./checkpointcope12.pth")
70
+ model2 = Model(confg='./configs/rtdetr/rtdetr_r101vd_6x_coco.yml',ckpt="./checkpoint_init.pth",cope=False)
71
 
72
  #img = cv2.imread('./j.jpg',cv2.IMREAD_GRAYSCALE)
73
  #img = Image.open('./a.jpg').convert('RGB').resize((640,640))
 
80
  size = torch.tensor([[t_img.shape[2], t_img.shape[3]]])
81
  #print(t_img.shape)
82
  labels, boxes, scores=model(t_img,size)
83
+ labels2, boxes2, scores2=model2(t_img,size)
84
  #img_path = Path('./a.jpg')
85
  draw = ImageDraw.Draw(img)
86
  thrh = thr