cheng-hust's picture
Update app.py
f29c1af verified
raw
history blame
3.06 kB
import gradio as gr
import torch.nn as nn
import torch
from src.core.config import BaseConfig
from src.core.yaml_utils import load_config, merge_config, create, merge_dict
import cv2
import torchvision.transforms as transforms
from PIL import Image, ImageDraw
from src.core import YAMLConfig
from pathlib import Path
# Image -> tensor pipeline applied to every request in detect().
# Only ToTensor is used (PIL image / HxWxC ndarray -> CxHxW float tensor);
# resizing to 640x640 is done with PIL before this pipeline runs, so no
# transforms.Resize step is included here.
transformer = transforms.Compose([transforms.ToTensor()])
#def model(yaml_cfg) -> torch.nn.Module:
# if 'model' in yaml_cfg:
# merge_config(yaml_cfg)
# model = create(yaml_cfg['model'])
# return model
#
#
#
#def get_model(cfg_path='./configs/rtdetr/rtdetr_r101vd_6x_coco.yml',ckpt="./logs/checkpointcope12.pth"):
# yaml_cfg = load_config(cfg_path)
# merge_dict(yaml_cfg, {})
# tmodel = model(yaml_cfg)
# checkpoint = torch.load(ckpt, map_location='cpu')
# if 'ema' in checkpoint:
# state = checkpoint['ema']['module']
# else:
# state = checkpoint['model']
#
# tmodel.load_state_dict(state)
# return tmodel
class Model(nn.Module):
    """Inference wrapper: a network and its post-processor built from a YAML config.

    The network and post-processor come from ``YAMLConfig``. Trained weights are
    loaded from ``ckpt``, and both components are switched to their deploy form
    via ``.deploy()``.

    Args:
        confg: Path to the YAML config passed to ``YAMLConfig``.
        ckpt: Path to a training checkpoint. Required. When the checkpoint
            has an ``'ema'`` entry, the EMA weights (``['ema']['module']``)
            are used; otherwise ``['model']`` is used.

    Raises:
        AttributeError: If ``ckpt`` is empty. ``YAMLConfig`` is still
            constructed before this check.
    """

    def __init__(self, confg=None, ckpt="") -> None:
        super().__init__()
        self.cfg = YAMLConfig(confg, resume=ckpt)
        if not ckpt:
            raise AttributeError('only support resume to load model.state_dict by now.')

        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        checkpoint = torch.load(ckpt, map_location='cpu')
        # Prefer EMA weights when the checkpoint carries them.
        state = checkpoint['ema']['module'] if 'ema' in checkpoint else checkpoint['model']

        # Weights are loaded into the train-mode model first, then converted to deploy mode.
        self.cfg.model.load_state_dict(state)
        self.model = self.cfg.model.deploy()
        self.postprocessor = self.cfg.postprocessor.deploy()

    def forward(self, images, orig_target_sizes):
        """Run the network on ``images`` and post-process with ``orig_target_sizes``.

        Returns whatever the post-processor returns. ``detect`` in this file
        unpacks it as ``(labels, boxes, scores)``.
        """
        raw_outputs = self.model(images)
        return self.postprocessor(raw_outputs, orig_target_sizes)
# Build the detector once at import time; detect() reuses this instance for
# every request instead of reloading the checkpoint.
model = Model(
    confg='./configs/rtdetr/rtdetr_r101vd_6x_coco.yml',
    ckpt="./checkpointcope12.pth",
)
#img = cv2.imread('./j.jpg',cv2.IMREAD_GRAYSCALE)
#img = Image.open('./a.jpg').convert('RGB').resize((640,640))
def detect(img, thr):
    """Run the detector on one image and draw detections scoring above ``thr``.

    Args:
        img: Image as a numpy array (as supplied by the Gradio ``"image"``
            input), or ``None`` when the user submitted without an image.
        thr: Score threshold; only detections with ``score > thr`` are drawn.

    Returns:
        The input converted to RGB and resized to 640x640 as a ``PIL.Image``.
        Kept boxes are outlined in red, each labelled in blue with its own class
        id. Returns ``None`` when ``img`` is ``None``.
    """
    if img is None:
        # Empty submission: nothing to run the detector on.
        return None

    # Force 3 channels: grayscale or RGBA uploads would otherwise yield a
    # 1- or 4-channel tensor the network cannot consume.
    pil_img = Image.fromarray(img).convert('RGB').resize((640, 640))
    t_img = transformer(pil_img).unsqueeze(0)  # add batch dim -> (1, C, H, W)
    # NOTE(review): sizes are passed as (H, W) of the tensor. The input is
    # square, so the order does not matter here -- confirm the post-processor's
    # expected (W, H)/(H, W) order before changing the resize.
    size = torch.tensor([[t_img.shape[2], t_img.shape[3]]])

    # Inference only: skip building an autograd graph on every request.
    with torch.no_grad():
        labels, boxes, scores = model(t_img, size)

    draw = ImageDraw.Draw(pil_img)
    for i in range(t_img.shape[0]):
        keep = scores[i] > thr
        # Pair each kept box with ITS OWN label. The previous code used
        # lab[i] (the batch index), which labelled every box identically.
        for box, label in zip(boxes[i][keep], labels[i][keep]):
            coords = box.tolist()  # plain Python numbers for PIL
            draw.rectangle(coords, outline='red')
            draw.text((coords[0], coords[1]), text=str(label.item()), fill='blue')
    return pil_img
# Gradio UI: an image upload plus a score-threshold slider (0..1, default 0.2)
# feeding detect(), which returns the annotated image.
# The slider label previously used typographic quotes (“thr”), which is a
# SyntaxError in Python and stopped the app from starting.
interface = gr.Interface(
    fn=detect,
    inputs=["image", gr.Slider(label="thr", value=0.2, maximum=1, minimum=0)],
    outputs="image",
    title="rt-cope detect",
)
interface.launch()