Spaces:
Sleeping
Sleeping
| import spaces | |
| import gradio as gr | |
| import cv2 | |
| from PIL import Image | |
| import torch | |
| import time | |
| import numpy as np | |
| from transformers import RTDetrForObjectDetection, RTDetrImageProcessor | |
| from draw_boxes import draw_bounding_boxes | |
# Both the pre/post-processor and the detector are loaded once at import time
# from the same Hugging Face checkpoint, so the id is spelled in one place.
_RTDETR_CHECKPOINT = "PekingU/rtdetr_r50vd"

image_processor = RTDetrImageProcessor.from_pretrained(_RTDETR_CHECKPOINT)
model = RTDetrForObjectDetection.from_pretrained(_RTDETR_CHECKPOINT)
def _write_segment(name, batch, video_codec, fps, frame_size, conf_threshold):
    """Detect objects in *batch* and write the annotated frames to file *name*.

    Args:
        name: Output file name for this segment.
        batch: Non-empty list of RGB frames (numpy arrays as produced by
            ``cv2.cvtColor(..., cv2.COLOR_BGR2RGB)``).
        video_codec: FourCC code passed to ``cv2.VideoWriter``.
        fps: Frame rate the segment is written at.
        frame_size: ``(width, height)`` tuple passed to ``cv2.VideoWriter``.
        conf_threshold: Score threshold forwarded to post-processing and to
            ``draw_bounding_boxes``.
    """
    inputs = image_processor(images=batch, return_tensors="pt")
    print(f"starting batch of size {len(batch)}")
    start = time.time()
    with torch.no_grad():
        outputs = model(**inputs)
    end = time.time()
    print("time taken ", end - start)

    # Post-processing expects one (height, width) pair per image; a frame's
    # numpy shape is (height, width, channels), so shape[:2] is already in
    # the right order.
    target_sizes = torch.tensor([array.shape[:2] for array in batch])
    boxes = image_processor.post_process_object_detection(
        outputs,
        target_sizes=target_sizes,
        threshold=conf_threshold,
    )

    segment_file = cv2.VideoWriter(name, video_codec, fps, frame_size)  # type: ignore
    try:
        for array, box in zip(batch, boxes):
            pil_image = draw_bounding_boxes(
                Image.fromarray(array), box, model, conf_threshold
            )
            # Convert RGB back to BGR, the channel order OpenCV writes.
            segment_file.write(np.array(pil_image)[:, :, ::-1].copy())
    finally:
        # Always finalize the file, even if drawing/writing fails midway.
        segment_file.release()


def stream_object_detection(video, conf_threshold):
    """Stream annotated video segments for an uploaded video.

    Every 5th frame of *video* is kept. Once a batch of kept frames is full,
    it is run through the detector, the detections are drawn, and the result
    is written to ``output_<n>.ts`` and yielded. Frames left over at the end
    of the video are flushed as a final, shorter segment.

    Args:
        video: Video source accepted by ``cv2.VideoCapture`` (a file path).
        conf_threshold: Minimum detection score to keep.

    Yields:
        The file name of each finished segment, in order.
    """
    subsample = 5
    cap = cv2.VideoCapture(video)
    try:
        video_codec = cv2.VideoWriter_fourcc(*"x264")  # type: ignore
        fps = int(cap.get(cv2.CAP_PROP_FPS))
        # Clamp to 1 so sources reporting fewer than `subsample` fps still
        # produce batches (a batch size of 0 could never be reached).
        desired_fps = max(1, fps // subsample)
        batch_size = 2 * desired_fps
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        n_frames = 0
        n_chunks = 0
        batch = []
        iterating, frame = cap.read()
        while iterating:
            if n_frames % subsample == 0:
                # Only the frames that are kept need colour conversion.
                batch.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            if len(batch) == batch_size:
                name = f"output_{n_chunks}.ts"
                # Written at the subsampled rate so playback speed matches
                # the source (only 1 in `subsample` frames is kept).
                _write_segment(
                    name, batch, video_codec, desired_fps,
                    (width, height), conf_threshold,
                )
                # Start a fresh batch; without this the size check above
                # would never be true again after the first segment.
                batch = []
                n_chunks += 1
                yield name
            iterating, frame = cap.read()
            n_frames += 1

        # Flush the trailing partial batch instead of dropping it. When there
        # is nothing left, no (empty) segment is produced.
        if batch:
            name = f"output_{n_chunks}.ts"
            _write_segment(
                name, batch, video_codec, desired_fps,
                (width, height), conf_threshold,
            )
            yield name
    finally:
        # Runs on normal exhaustion and when the consumer closes the
        # generator early.
        cap.release()
# Custom CSS for the Blocks layout: caps the size of the video group and
# centers the column's content. Lengths need a unit (a bare `600` is invalid
# CSS and the declaration is ignored), and each rule ends with `;}`.
css = """.my-group {max-width: 600px !important; max-height: 600px !important;}
.my-column {display: flex !important; justify-content: center !important; align-items: center !important;}"""
# Gradio UI: a title, reference links, and a single video component that is
# both the upload target and the streamed output of `stream_object_detection`.
with gr.Blocks(css=css) as app:
    gr.HTML(
        """
        <h1 style='text-align: center'>
        Video Object Detection with RT-DETR
        </h1>
        """)
    # The second link targets a Hugging Face model page, so it is labelled
    # accordingly (it was previously labelled "github").
    # NOTE(review): the linked checkpoint (rtdetr_r101vd_coco_o365) differs
    # from the one this app loads (rtdetr_r50vd) -- confirm which is intended.
    gr.HTML(
        """
        <h3 style='text-align: center'>
        <a href='https://arxiv.org/abs/2304.08069' target='_blank'>arXiv</a> | <a href='https://huggingface.co/PekingU/rtdetr_r101vd_coco_o365' target='_blank'>model</a>
        </h3>
        """)
    with gr.Column(elem_classes=["my-column"]):
        with gr.Group(elem_classes=["my-group"]):
            video = gr.Video(label="Video Source", streaming=True, autoplay=True)
            conf_threshold = gr.Slider(
                label="Confidence Threshold",
                minimum=0.0,
                maximum=1.0,
                step=0.05,
                value=0.30,
            )
    # Uploading a video starts detection; each yielded segment is streamed
    # back into the same video component.
    video.upload(
        fn=stream_object_detection,
        inputs=[video, conf_threshold],
        outputs=[video],
    )

if __name__ == '__main__':
    app.launch()