|
import gradio as gr |
|
import torch.nn as nn |
|
import torch |
|
from src.core.config import BaseConfig |
|
from src.core.yaml_utils import load_config, merge_config, create, merge_dict |
|
import cv2 |
|
import torchvision.transforms as transforms |
|
from PIL import Image, ImageDraw |
|
from src.core import YAMLConfig |
|
from pathlib import Path |
|
import src |
|
|
|
# Preprocessing applied to each PIL image before inference; the pipeline
# currently holds a single ToTensor step.
transformer = transforms.Compose([transforms.ToTensor()])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Model(nn.Module):
    """Inference wrapper that chains a configured model and its postprocessor.

    The network and postprocessor are both taken from a ``YAMLConfig`` built
    from ``confg``; weights come from the checkpoint file ``ckpt``.
    """

    def __init__(self, confg=None, ckpt="", cope=True) -> None:
        """Build the config, load checkpoint weights and switch to deploy form.

        Args:
            confg: Config path handed to ``src.core.YAMLConfig``.
            ckpt: Checkpoint path. Required; an empty value is rejected.
            cope: Accepted for call-site compatibility; not read by this
                constructor.

        Raises:
            AttributeError: If ``ckpt`` is empty.
        """
        super().__init__()

        self.cfg = src.core.YAMLConfig(confg, resume=ckpt)

        # Guard clause: weights can only come from a checkpoint file.
        if not ckpt:
            raise AttributeError('only support resume to load model.state_dict by now.')

        checkpoint = torch.load(ckpt, map_location='cpu')
        # Use the EMA weights when the checkpoint carries them, otherwise the
        # plain model weights.
        state = checkpoint['ema']['module'] if 'ema' in checkpoint else checkpoint['model']

        self.cfg.model.load_state_dict(state)

        self.model = self.cfg.model.deploy()
        self.postprocessor = self.cfg.postprocessor.deploy()

    def forward(self, images, orig_target_sizes):
        """Run the model on ``images`` and postprocess with ``orig_target_sizes``."""
        raw_outputs = self.model(images)
        return self.postprocessor(raw_outputs, orig_target_sizes)
|
|
|
|
|
# Build the three detectors once at import time. Each pairs a config file
# with its checkpoint; detect() picks one of them per request.
model = Model(
    confg='./configs/rtdetr/rtdetr_r101vd_6x_coco.yml',
    ckpt="./checkpoint_init.pth",
)

model2 = Model(
    confg='./configs/rtdetr/rtdetr_r101vd_6x_cococope12.yml',
    ckpt="./checkpointcope12.pth",
    cope=False,
)

model3 = Model(
    confg='./configs/rtdetr/rtdetr_r101vd_6x_cocococo.yml',
    ckpt="./rtdetrCOCO.pth",
    cope=False,
)
|
|
|
|
|
|
|
|
|
|
|
# Class id -> display name used for the 'aitod' and 'ten_classes' models.
_AITOD_LABELS = {
    8: 'wind-mill', 7: 'person', 6: 'vehicle', 5: 'swimming-pool',
    4: 'ship', 3: 'storage-tank', 2: 'bridge', 1: 'airplane',
}

# COCO category id -> display name (ids are non-contiguous).
_COCO_LABELS = {
    1: 'person', 2: 'bicycle', 3: 'car', 4: 'motorcycle', 5: 'airplane', 6: 'bus', 7: 'train', 8: 'truck',
    9: 'boat', 10: 'traffic light', 11: 'fire hydrant', 13: 'stop sign', 14: 'parking meter', 15: 'bench',
    16: 'bird', 17: 'cat', 18: 'dog', 19: 'horse', 20: 'sheep', 21: 'cow', 22: 'elephant', 23: 'bear', 24: 'zebra',
    25: 'giraffe', 27: 'backpack', 28: 'umbrella', 31: 'handbag', 32: 'tie', 33: 'suitcase', 34: 'frisbee',
    35: 'skis', 36: 'snowboard', 37: 'sports ball', 38: 'kite', 39: 'baseball bat', 40: 'baseball glove',
    41: 'skateboard', 42: 'surfboard', 43: 'tennis racket', 44: 'bottle', 46: 'wine glass', 47: 'cup',
    48: 'fork', 49: 'knife', 50: 'spoon', 51: 'bowl', 52: 'banana', 53: 'apple', 54: 'sandwich', 55: 'orange',
    56: 'broccoli', 57: 'carrot', 58: 'hot dog', 59: 'pizza', 60: 'donut', 61: 'cake', 62: 'chair', 63: 'couch',
    64: 'potted plant', 65: 'bed', 67: 'dining table', 70: 'toilet', 72: 'tv', 73: 'laptop', 74: 'mouse',
    75: 'remote', 76: 'keyboard', 77: 'cell phone', 78: 'microwave', 79: 'oven', 80: 'toaster', 81: 'sink',
    82: 'refrigerator', 84: 'book', 85: 'clock', 86: 'vase', 87: 'scissors', 88: 'teddy bear', 89: 'hair drier',
    90: 'toothbrush',
}

# Box outline colour per class id (1..8).
_BOX_COLORS = {
    8: 'burlyWood', 7: 'red', 6: 'blue', 5: 'green',
    4: 'yellow', 3: 'cyan', 2: 'magenta', 1: 'orange',
}


def detect(img, thr=0.2, trained_dataset='aitod'):
    """Run one detector on an image and draw the detections on it.

    The image is resized to 640x640, passed through the detector selected by
    ``trained_dataset``, and every detection whose score is above ``thr`` is
    drawn as an outlined box with its class name.

    Args:
        img: Image as an array accepted by ``PIL.Image.fromarray``, or
            ``None`` when no image was supplied.
        thr: Score threshold; only detections with ``score > thr`` are drawn.
        trained_dataset: ``'aitod'`` selects ``model``, ``'ten_classes'``
            selects ``model2``, anything else selects ``model3`` and is
            labelled with the COCO name table.

    Returns:
        The annotated 640x640 ``PIL.Image``, or ``None`` if ``img`` is
        ``None``.
    """
    # The UI can submit without an image; there is nothing to annotate then.
    if img is None:
        return None

    img = Image.fromarray(img).resize((640, 640))
    t_img = transformer(img).unsqueeze(0)
    size = torch.tensor([[t_img.shape[2], t_img.shape[3]]])

    if trained_dataset == 'aitod':
        detector = model
    elif trained_dataset == 'ten_classes':
        detector = model2
    else:
        detector = model3
    # Keep the label table in step with the detector actually used: the
    # fallback detector (model3) always gets the COCO names.
    use_coco = detector is model3

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        labels, boxes, scores = detector(t_img, size)

    draw = ImageDraw.Draw(img)

    for lab, box, scr in zip(labels, boxes, scores):
        keep = scr > thr
        kept_labels = lab[keep].tolist()
        kept_boxes = box[keep].tolist()

        for label_i, b in zip(kept_labels, kept_boxes):
            # Cycle through the eight colours for ids outside the table.
            cycled_color = _BOX_COLORS[label_i % 8 + 1]
            if use_coco:
                # NOTE(review): this keeps the original ``label_i + 1``
                # lookup, which presumes 0-based labels; confirm against the
                # postprocessor's category remapping. Unknown ids fall back
                # to the numeric label instead of raising KeyError.
                name = _COCO_LABELS.get(label_i + 1, str(label_i))
                color = cycled_color
            else:
                name = _AITOD_LABELS.get(label_i, str(label_i))
                color = _BOX_COLORS.get(label_i, cycled_color)

            draw.rectangle(b, outline=color)
            draw.text((b[0], b[1]), text=name, fill='blue')

    return img
|
|
|
# Web UI: inputs are listed in the positional order of
# detect(img, thr, trained_dataset).
interface = gr.Interface(
    fn=detect,
    inputs=[
        "image",
        gr.Slider(label="thr", value=0.2, maximum=1, minimum=0),
        gr.Radio(['aitod', 'ten_classes', 'COCO'], value='aitod'),
    ],
    outputs="image",
    title="degraded hust small object detect",
)

interface.launch()