|
import spaces |
|
import gradio as gr |
|
import cv2 |
|
from PIL import Image |
|
|
|
from transformers import RTDetrForObjectDetection, RTDetrImageProcessor |
|
|
|
from draw_boxes import draw_bounding_boxes |
|
|
|
# RT-DETR (ResNet-50-vd backbone) checkpoint from the Hub; the processor
# handles image resizing/normalization and detection post-processing.
image_processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd")
model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd")
|
|
|
@spaces.GPU
def stream_object_detection(video, conf_threshold):
    """Run RT-DETR detection on a video and stream annotated segments.

    The source is subsampled to one frame in three; kept frames are batched
    to one second's worth (``desired_fps``) and run through the model in a
    single forward pass.  Each completed batch is written to a ``.ts``
    segment whose path is yielded so the UI can play results progressively.

    Args:
        video: Filesystem path of the uploaded video.
        conf_threshold: Confidence threshold in [0, 1]; detections below it
            are dropped.

    Yields:
        str: Path of the most recently finished segment file.
    """
    cap = cv2.VideoCapture(video)

    # H.264 so the browser can play the streamed segments.
    video_codec = cv2.VideoWriter_fourcc(*"x264")
    fps = int(cap.get(cv2.CAP_PROP_FPS))
    desired_fps = fps // 3  # we keep only every third frame
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    iterating, frame = cap.read()

    n_frames = 0
    n_chunks = 0
    # Write segments next to this script (previously `current_dir` was an
    # undefined name, so the function raised NameError on first use).
    current_dir = Path(__file__).parent
    name = str(current_dir / f"output_{n_chunks}.ts")
    # Write at desired_fps, not the source fps: only every third frame is
    # written, so the source fps would make playback run 3x too fast.
    segment_file = cv2.VideoWriter(name, video_codec, desired_fps, (width, height))
    batch = []

    while iterating:
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        if n_frames % 3 == 0:
            batch.append(frame)
        if len(batch) == desired_fps:
            inputs = image_processor(images=batch, return_tensors="pt")

            with torch.no_grad():
                outputs = model(**inputs)

            # target_sizes expects (height, width) per image; the previous
            # frame.shape[::-1] produced (channels, width, height).
            boxes = image_processor.post_process_object_detection(
                outputs,
                target_sizes=torch.tensor([(height, width)] * len(batch)),
                threshold=conf_threshold,
            )

            for array, box in zip(batch, boxes):
                # Draw this frame's own detections (previously boxes[0] was
                # drawn on every frame) and honor the user's slider value
                # (previously a hard-coded 0.3).
                pil_image = draw_bounding_boxes(
                    Image.fromarray(array), box, model, conf_threshold
                )
                frame = np.array(pil_image)
                # RGB -> BGR for OpenCV's writer.
                frame = frame[:, :, ::-1].copy()
                segment_file.write(frame)

            segment_file.release()
            n_frames = 0
            n_chunks += 1
            yield name
            name = str(current_dir / f"output_{n_chunks}.ts")
            segment_file = cv2.VideoWriter(name, video_codec, desired_fps, (width, height))
            # Reset the batch; previously it was never cleared, so the
            # `len(batch) == desired_fps` condition fired at most once.
            batch = []

        iterating, frame = cap.read()
        n_frames += 1

    segment_file.release()
    cap.release()  # free the capture handle
    yield name
|
|
|
|
|
# Constrain the input group's size and center the column contents.
# Fixes: `max-height: 600` was missing its `px` unit, and the second rule
# closed its brace before the final `;` (`!important};`), leaving a stray
# semicolon outside the rule.
css = """.my-group {max-width: 600px !important; max-height: 600px !important;}
.my-column {display: flex !important; justify-content: center !important; align-items: center !important;}"""
|
|
|
|
|
# Build the Gradio UI: a title, paper/model links, and one centered column
# holding the video input and a confidence slider.  Uploading a video runs
# the streaming detection generator.
with gr.Blocks(css=css) as app:
    gr.HTML(
        """
    <h1 style='text-align: center'>
    Video Object Detection with RT-DETR
    </h1>
    """)
    gr.HTML(
        """
    <h3 style='text-align: center'>
    <a href='https://arxiv.org/abs/2304.08069' target='_blank'>arXiv</a> | <a href='https://huggingface.co/PekingU/rtdetr_r101vd_coco_o365' target='_blank'>github</a>
    </h3>
    """)
    with gr.Column(elem_classes=["my-column"]):
        with gr.Group(elem_classes=["my-group"]):
            video = gr.Video(label="Video Source")
            # Forwarded to post_process_object_detection in the handler.
            conf_threshold = gr.Slider(
                label="Confidence Threshold",
                minimum=0.0,
                maximum=1.0,
                step=0.05,
                value=0.30,
            )
    # NOTE(review): the generator's yielded segment paths are streamed back
    # into the same `video` component that holds the input — presumably
    # intentional (annotated playback replaces the upload), but confirm a
    # separate output component isn't wanted.
    video.upload(
        fn=stream_object_detection,
        inputs=[video, conf_threshold],
        outputs=[video],
    )
|
|
|
# Launch the demo server only when executed as a script, not on import.
if __name__ == "__main__":
    app.launch()
|
|