Update app.py
Browse files
app.py
CHANGED
@@ -41,10 +41,9 @@ transformer = transforms.Compose([
|
|
41 |
class Model(nn.Module):
|
42 |
def __init__(self, confg=None, ckpt="",cope=True) -> None:
|
43 |
super().__init__()
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
self.cfg = src2.core.YAMLConfig(confg, resume=ckpt)
|
48 |
if ckpt:
|
49 |
checkpoint = torch.load(ckpt, map_location='cpu')
|
50 |
if 'ema' in checkpoint:
|
@@ -66,8 +65,9 @@ class Model(nn.Module):
|
|
66 |
return self.postprocessor(outputs, orig_target_sizes)
|
67 |
|
68 |
|
69 |
-
model = Model(confg='./configs/rtdetr/rtdetr_r101vd_6x_coco.yml',ckpt="./
|
70 |
-
model2 = Model(confg='./configs/rtdetr/
|
|
|
71 |
|
72 |
#img = cv2.imread('./j.jpg',cv2.IMREAD_GRAYSCALE)
|
73 |
#img = Image.open('./a.jpg').convert('RGB').resize((640,640))
|
@@ -79,8 +79,12 @@ def detect(img,thr=0.2,cope='none'):
|
|
79 |
t_img = transformer(img).unsqueeze(0)#.unsqueeze(0) #[1,1,640,640]
|
80 |
size = torch.tensor([[t_img.shape[2], t_img.shape[3]]])
|
81 |
#print(t_img.shape)
|
82 |
-
|
83 |
-
|
|
|
|
|
|
|
|
|
84 |
#img_path = Path('./a.jpg')
|
85 |
draw = ImageDraw.Draw(img)
|
86 |
thrh = thr
|
@@ -101,6 +105,6 @@ def detect(img,thr=0.2,cope='none'):
|
|
101 |
#save_path = Path('./output') / img_path.name
|
102 |
return img
|
103 |
|
104 |
-
interface = gr.Interface(fn=detect,inputs=["image",gr.Slider(label="thr", value=0.2, maximum=1, minimum=0)],outputs="image",title="rt-cope detect")
|
105 |
|
106 |
interface.launch()
|
|
|
41 |
class Model(nn.Module):
|
42 |
def __init__(self, confg=None, ckpt="",cope=True) -> None:
|
43 |
super().__init__()
|
44 |
+
|
45 |
+
self.cfg = src.core.YAMLConfig(confg, resume=ckpt)
|
46 |
+
|
|
|
47 |
if ckpt:
|
48 |
checkpoint = torch.load(ckpt, map_location='cpu')
|
49 |
if 'ema' in checkpoint:
|
|
|
65 |
return self.postprocessor(outputs, orig_target_sizes)
|
66 |
|
67 |
|
68 |
+
model = Model(confg='./configs/rtdetr/rtdetr_r101vd_6x_coco.yml',ckpt="./checkpoint_init.pth")
|
69 |
+
model2 = Model(confg='./configs/rtdetr/rtdetr_r101vd_6x_coco_cope12.yml',ckpt="./checkpointcope12.pth",cope=False)
|
70 |
+
model3 = Model(confg='./configs/rtdetr/rtdetr_r101vd_6x_coco_cope24.yml',ckpt="./checkpointcope24.pth",cope=False)
|
71 |
|
72 |
#img = cv2.imread('./j.jpg',cv2.IMREAD_GRAYSCALE)
|
73 |
#img = Image.open('./a.jpg').convert('RGB').resize((640,640))
|
|
|
79 |
t_img = transformer(img).unsqueeze(0)#.unsqueeze(0) #[1,1,640,640]
|
80 |
size = torch.tensor([[t_img.shape[2], t_img.shape[3]]])
|
81 |
#print(t_img.shape)
|
82 |
+
if cope == 'none':
|
83 |
+
labels, boxes, scores=model(t_img,size)
|
84 |
+
elif cope == 'cope12':
|
85 |
+
labels, boxes, scores=model2(t_img,size)
|
86 |
+
else
|
87 |
+
labels, boxes, scores=model3(t_img,size)
|
88 |
#img_path = Path('./a.jpg')
|
89 |
draw = ImageDraw.Draw(img)
|
90 |
thrh = thr
|
|
|
105 |
#save_path = Path('./output') / img_path.name
|
106 |
return img
|
107 |
|
108 |
+
interface = gr.Interface(fn=detect,inputs=["image",gr.Slider(label="thr", value=0.2, maximum=1, minimum=0),gr.inputs.Radio(['none','cope12','cope24'])],outputs="image",title="rt-cope detect")
|
109 |
|
110 |
interface.launch()
|