import spaces
import gradio as gr
import cv2
import numpy
import torch
from pathlib import Path
from PIL import Image
from transformers import RTDetrForObjectDetection, RTDetrImageProcessor

from draw_boxes import draw_bounding_boxes

# Directory where the streamed video segments are written.
current_dir = Path(__file__).parent

image_processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd")
model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd")
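
# The generator below reads the uploaded video, runs batched RT-DETR inference on
# every third frame, and yields each annotated video segment as soon as it has been
# encoded, so results appear before the whole clip is processed.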
@spaces.GPU
def stream_object_detection(video, conf_threshold):
    cap = cv2.VideoCapture(video)

    # Encode each chunk with H.264 so the .ts segments can be played back as they arrive.
    video_codec = cv2.VideoWriter_fourcc(*"x264")  # type: ignore
    fps = int(cap.get(cv2.CAP_PROP_FPS))

    # Only every third frame is processed, so each chunk holds fps // 3 frames.
    desired_fps = fps // 3
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    iterating, frame = cap.read()

    n_frames = 0
    n_chunks = 0

    name = str(current_dir / f"output_{n_chunks}.ts")
    segment_file = cv2.VideoWriter(name, video_codec, fps, (width, height))  # type: ignore
    batch = []

    while iterating:
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        if n_frames % 3 == 0:
            batch.append(frame)

        if len(batch) == desired_fps:
            inputs = image_processor(images=batch, return_tensors="pt")

            with torch.no_grad():
                outputs = model(**inputs)
            boxes = image_processor.post_process_object_detection(
                outputs,
                target_sizes=torch.tensor([(height, width)] * len(batch)),
                threshold=conf_threshold)

            for array, box in zip(batch, boxes):
                pil_image = draw_bounding_boxes(Image.fromarray(array), box, model, conf_threshold)
                frame = numpy.array(pil_image)
                # Convert RGB back to BGR for OpenCV before writing.
                frame = frame[:, :, ::-1].copy()
                segment_file.write(frame)

            # Finish this segment, yield it to the client, and start the next one.
            batch = []
            segment_file.release()
            n_frames = 0
            n_chunks += 1
            yield name

            name = str(current_dir / f"output_{n_chunks}.ts")
            segment_file = cv2.VideoWriter(name, video_codec, fps, (width, height))  # type: ignore

        iterating, frame = cap.read()
        n_frames += 1

    segment_file.release()
    yield name
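
# Gradio UI: a single Video component doubles as the upload source and the
# streamed output, with a slider controlling the detection confidence threshold.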
css=""".my-group {max-width: 600px !important; max-height: 600 !important;}
.my-column {display: flex !important; justify-content: center !important; align-items: center !important};"""
with gr.Blocks(css=css) as app:
    gr.HTML(
        """
        <h1 style='text-align: center'>
        Video Object Detection with RT-DETR
        </h1>
        """)
    gr.HTML(
        """
        <h3 style='text-align: center'>
        <a href='https://arxiv.org/abs/2304.08069' target='_blank'>arXiv</a> | <a href='https://huggingface.co/PekingU/rtdetr_r101vd_coco_o365' target='_blank'>github</a>
        </h3>
        """)
    with gr.Column(elem_classes=["my-column"]):
        with gr.Group(elem_classes=["my-group"]):
            video = gr.Video(label="Video Source")
            conf_threshold = gr.Slider(
                label="Confidence Threshold",
                minimum=0.0,
                maximum=1.0,
                step=0.05,
                value=0.30,
            )

    video.upload(
        fn=stream_object_detection,
        inputs=[video, conf_threshold],
        outputs=[video],
    )

if __name__ == '__main__':
    app.launch()