|
import gradio as gr |
|
import torch.nn as nn |
|
import torch |
|
from src.core.config import BaseConfig |
|
from src.core.yaml_utils import load_config, merge_config, create, merge_dict |
|
import cv2 |
|
import torchvision.transforms as transforms |
|
from PIL import Image, ImageDraw |
|
from src.core import YAMLConfig |
|
from pathlib import Path |
|
|
|
# Preprocessing for the detector: a single ToTensor step, which (per the
# torchvision docs) turns a PIL image into a (C, H, W) float tensor, scaling
# 8-bit pixel values into [0, 1]. No normalization is applied here.
transformer = transforms.Compose([transforms.ToTensor()])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Model(nn.Module):
    """Inference wrapper around the model and postprocessor built by ``YAMLConfig``.

    The network weights are restored from ``ckpt``. When the checkpoint has an
    ``'ema'`` entry, its ``'module'`` state dict is used; otherwise ``'model'``.
    Both the model and the postprocessor are then put through ``deploy()``.

    Args:
        confg: Path to the YAML config handed to ``YAMLConfig``.
        ckpt: Path to the checkpoint file. It is required; an empty value raises.

    Raises:
        AttributeError: If ``ckpt`` is empty/falsy. This happens after the
            config object has been constructed.
    """

    def __init__(self, confg=None, ckpt="") -> None:
        super().__init__()
        self.cfg = YAMLConfig(confg, resume=ckpt)

        # Loading weights from a checkpoint is the only supported mode.
        if not ckpt:
            raise AttributeError('only support resume to load model.state_dict by now.')

        # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
        checkpoint = torch.load(ckpt, map_location='cpu')
        # Prefer the EMA weights when the checkpoint carries them.
        state = checkpoint['ema']['module'] if 'ema' in checkpoint else checkpoint['model']
        self.cfg.model.load_state_dict(state)

        # deploy() comes from src.core; presumably it converts modules to their
        # inference form. Confirm there.
        self.model = self.cfg.model.deploy()
        self.postprocessor = self.cfg.postprocessor.deploy()

    def forward(self, images, orig_target_sizes):
        """Run the network on ``images`` and post-process with ``orig_target_sizes``."""
        raw_outputs = self.model(images)
        return self.postprocessor(raw_outputs, orig_target_sizes)
|
|
|
|
|
# Build the detector once, at import time; every request handled by `detect`
# reuses this instance.
model = Model(
    confg='./configs/rtdetr/rtdetr_r101vd_6x_coco.yml',
    ckpt="./checkpointcope12.pth",
)
|
|
|
|
|
|
|
|
|
|
|
def detect(img, thr):
    """Run the detector on one image and draw the detections scoring above ``thr``.

    Args:
        img: Image array from the Gradio ``"image"`` input, passed to
            ``PIL.Image.fromarray`` (presumably ``uint8`` HxWxC; confirm against
            the Gradio component settings). ``None`` when no image was supplied.
        thr: Score threshold from the slider. Only detections whose score is
            strictly greater than ``thr`` are drawn.

    Returns:
        The input resized to 640x640, with a red box and a blue label id drawn
        for each kept detection. Returns ``None`` if ``img`` is ``None``.
    """
    if img is None:
        # Gradio passes None when the user submits without an image.
        return None

    # convert("RGB") guards against grayscale/RGBA arrays, which would give
    # ToTensor a channel count the model does not expect.
    img = Image.fromarray(img).convert("RGB").resize((640, 640))
    t_img = transformer(img).unsqueeze(0)

    # The postprocessor multiplies xyxy boxes by orig_target_sizes.repeat(1, 2),
    # so the order must be (width, height), which is exactly PIL's ``img.size``.
    # (The image is square here, so the result matches the old (H, W) order.)
    size = torch.tensor([img.size])

    # Pure inference: skip building an autograd graph for every request.
    with torch.no_grad():
        labels, boxes, scores = model(t_img, size)

    draw = ImageDraw.Draw(img)
    # Iterate over the batch dimension (a single image here).
    for lab, box, scr in zip(labels, boxes, scores):
        keep = scr > thr
        # Pair each kept box with its OWN label (not the batch index).
        for label, b in zip(lab[keep], box[keep]):
            coords = b.tolist()  # plain floats for PIL
            draw.rectangle(coords, outline='red')
            draw.text((coords[0], coords[1]), text=str(label.item()), fill='blue')

    return img
|
|
|
# Gradio UI. Inputs: an image and a score-threshold slider (0 to 1, default 0.2).
# Output: the annotated image returned by `detect`.
# The slider label previously used typographic quotes (“thr”), which is a
# SyntaxError; it now uses plain ASCII quotes.
interface = gr.Interface(
    fn=detect,
    inputs=["image", gr.Slider(label="thr", value=0.2, maximum=1, minimum=0)],
    outputs="image",
    title="rt-cope detect",
)

interface.launch()