import cv2
import gradio as gr
import numpy as np
import supervision as sv
import torch
from PIL import Image
from transformers import (
    RTDetrForObjectDetection,
    RTDetrImageProcessor,
    VitPoseForPoseEstimation,
    VitPoseImageProcessor,
    pipeline,
)
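
# Pipeline overview:
#   1. RT-DETR detects person bounding boxes.
#   2. ViTPose estimates 17 COCO keypoints for each detected person.
#   3. Depth Anything V2 produces a relative depth map of the scene.
#   4. MaskFormer (trained on ADE20K) assigns a semantic class to every pixel.
# A hand (wrist keypoint) is considered to be touching an object when a nearby
# non-person pixel is close to it in depth.
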
KEYPOINT_LABEL_MAP = {
    0: "Nose",
    1: "L_Eye",
    2: "R_Eye",
    3: "L_Ear",
    4: "R_Ear",
    5: "L_Shoulder",
    6: "R_Shoulder",
    7: "L_Elbow",
    8: "R_Elbow",
    9: "L_Wrist",
    10: "R_Wrist",
    11: "L_Hip",
    12: "R_Hip",
    13: "L_Knee",
    14: "R_Knee",
    15: "L_Ankle",
    16: "R_Ankle",
}
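# The 17 keypoints above follow the standard COCO ordering that ViTPose
# predicts. The wrist keypoints (indices 9 and 10) are used further down as a
# proxy for the hands.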


class InteractionDetector:
    def __init__(self):
        self.person_detector = None
        self.person_processor = None
        self.pose_model = None
        self.pose_processor = None
        self.depth_model = None
        self.segmentation_model = None
        # Maximum depth difference (in the 0-255 range of the normalized depth
        # map) between a wrist and a nearby pixel for the hand to count as
        # "touching" it
        self.interaction_threshold = 2
        self.load_models()

    def load_models(self):
        """Load all required models."""
        # Person detection model
        self.person_processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd_coco_o365")
        self.person_detector = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd_coco_o365")
        # Pose estimation model
        self.pose_processor = VitPoseImageProcessor.from_pretrained("nielsr/vitpose-base-simple")
        self.pose_model = VitPoseForPoseEstimation.from_pretrained("nielsr/vitpose-base-simple")
        # Depth estimation model
        self.depth_model = pipeline("depth-estimation", model="depth-anything/Depth-Anything-V2-Small-hf")
        # Semantic segmentation model
        self.segmentation_model = pipeline("image-segmentation", model="facebook/maskformer-swin-base-ade")
        self.segmentation_id2label = self.segmentation_model.model.config.id2label
        self.segmentation_label2id = {v: k for k, v in self.segmentation_id2label.items()}
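
    # Note: all of the models above load on CPU by default, which is what the
    # rest of this file assumes. Moving them to a GPU (an optional change, not
    # done here) would also require moving the model inputs to the same device.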

    def get_nearest_pixel_class(self, joint, depth_map, segmentation_map):
        """
        Find the non-person pixel closest in depth to a given joint coordinate.

        Args:
            joint: (x, y) coordinates of the joint
            depth_map: Depth map (uint8, 0-255)
            segmentation_map: Semantic segmentation map of ADE20K class ids

        Returns:
            tuple: class id of the nearest pixel, depth distance to that pixel
        """
        PERSON_ID = 12  # "person" class id in ADE20K

        # Manhattan distance of every pixel from the joint, used to restrict
        # the search to a local window around the hand
        rows, cols = np.indices(depth_map.shape)
        dist_coord = np.abs(rows - joint[1]) + np.abs(cols - joint[0])

        # Depth difference to the joint; cast to a signed type first so the
        # uint8 subtraction cannot wrap around
        depth_dist = np.abs(depth_map.astype(np.int16) - int(depth_map[joint[1], joint[0]]))

        # Exclude person pixels and anything farther than 50 px from the joint
        depth_dist[(segmentation_map == PERSON_ID) | (dist_coord > 50)] = 255

        min_dist = np.unravel_index(np.argmin(depth_dist), depth_dist.shape)
        return int(segmentation_map[min_dist]), float(depth_dist[min_dist])

    def detect_persons(self, image: Image.Image):
        """Detect persons in the image."""
        inputs = self.person_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = self.person_detector(**inputs)
        results = self.person_processor.post_process_object_detection(
            outputs,
            target_sizes=torch.tensor([(image.height, image.width)]),
            threshold=0.3,
        )
        # Label 0 is "person" for this COCO-trained RT-DETR checkpoint
        boxes = results[0]["boxes"][results[0]["labels"] == 0]
        scores = results[0]["scores"][results[0]["labels"] == 0]
        return boxes.cpu().numpy(), scores.cpu().numpy()

    def detect_keypoints(self, image: Image.Image):
        """Detect pose keypoints for every detected person."""
        boxes, scores = self.detect_persons(image)
        # VitPoseImageProcessor expects boxes in COCO (x, y, w, h) format,
        # while RT-DETR post-processing returns (x1, y1, x2, y2) corners
        coco_boxes = boxes.copy()
        coco_boxes[:, 2] -= coco_boxes[:, 0]
        coco_boxes[:, 3] -= coco_boxes[:, 1]
        pixel_values = self.pose_processor(image, boxes=[coco_boxes], return_tensors="pt").pixel_values
        with torch.no_grad():
            outputs = self.pose_model(pixel_values)
        pose_results = self.pose_processor.post_process_pose_estimation(outputs, boxes=[coco_boxes])[0]
        return pose_results, boxes, scores

    def estimate_depth(self, image: Image.Image):
        """Estimate a relative depth map for the image (uint8 array, 0-255)."""
        with torch.no_grad():
            depth_map = np.array(self.depth_model(image)["depth"])
        return depth_map

    def segment_image(self, image: Image.Image):
        """Run semantic segmentation and return a per-pixel map of ADE20K class ids."""
        with torch.no_grad():
            segmentation_map = self.segmentation_model(image)
        result = np.zeros(np.array(image).shape[:2], dtype=np.uint8)
        print("Found", [segment["label"] for segment in segmentation_map])
        # Paint larger masks first so that smaller, overlapping objects are not
        # wiped out by the bigger regions behind them
        for cls_item in sorted(segmentation_map, key=lambda segment: np.sum(segment["mask"]), reverse=True):
            result[np.array(cls_item["mask"]) > 0] = self.segmentation_label2id[cls_item["label"]]
        return result

    def detect_wall_interaction(self, image: Image.Image):
        """Detect whether each person's hands are touching a nearby object (e.g. a wall)."""
        # Gather all the intermediate outputs
        pose_results, boxes, scores = self.detect_keypoints(image)
        depth_map = self.estimate_depth(image)
        segmentation_map = self.segment_image(image)

        interactions = []
        for person_idx, pose_result in enumerate(pose_results):
            # The wrist keypoints stand in for the hands (see KEYPOINT_LABEL_MAP)
            right_hand = pose_result["keypoints"][10].numpy().astype(int)
            left_hand = pose_result["keypoints"][9].numpy().astype(int)

            # Find the nearest non-person pixel (in depth) around each hand
            right_cls, r_distance = self.get_nearest_pixel_class(right_hand[:2], depth_map, segmentation_map)
            left_cls, l_distance = self.get_nearest_pixel_class(left_hand[:2], depth_map, segmentation_map)

            # A hand is "touching" when that pixel is within the depth threshold
            right_touching = r_distance < self.interaction_threshold
            left_touching = l_distance < self.interaction_threshold

            interactions.append({
                "person_id": person_idx,
                "right_hand_touching_object": self.segmentation_id2label[right_cls],
                "left_hand_touching_object": self.segmentation_id2label[left_cls],
                "right_hand_touching": right_touching,
                "left_hand_touching": left_touching,
                "right_hand_distance": r_distance,
                "left_hand_distance": l_distance,
            })
        return interactions, pose_results, segmentation_map, depth_map

    def visualize_results(self, image: Image.Image, interactions, pose_results):
        """Visualize detection results."""
        # Create the base visualization from the original image
        vis_image = np.array(image).copy()

        # Draw the pose skeletons
        edge_annotator = sv.EdgeAnnotator(color=sv.Color.GREEN, thickness=2)
        key_points = sv.KeyPoints(
            xy=torch.cat([pose_result["keypoints"].unsqueeze(0) for pose_result in pose_results]).cpu().numpy()
        )
        vis_image = edge_annotator.annotate(scene=vis_image, key_points=key_points)

        # Add interaction indicators: a filled circle on each touching hand
        for interaction in interactions:
            person_id = interaction["person_id"]
            pose_result = pose_results[person_id]
            if interaction["right_hand_touching"]:
                cv2.circle(vis_image,
                           tuple(map(int, pose_result["keypoints"][10][:2])),
                           10, (0, 0, 255), -1)
            if interaction["left_hand_touching"]:
                cv2.circle(vis_image,
                           tuple(map(int, pose_result["keypoints"][9][:2])),
                           10, (0, 0, 255), -1)
        return Image.fromarray(vis_image)

    def process_image(self, input_image):
        """Process an image and return the visualizations plus interaction details."""
        if input_image is None:
            # Match the number of Gradio outputs: image, segmentation, depth, text
            return None, None, None, ""

        # Convert to PIL Image if necessary
        if isinstance(input_image, np.ndarray):
            image = Image.fromarray(input_image)
        else:
            image = input_image
        image = image.resize((1280, 720))

        # Detect interactions
        interactions, pose_results, segmentation_map, depth_map = self.detect_wall_interaction(image)

        # Visualize results
        result_image = self.visualize_results(image, interactions, pose_results)

        # Build the interaction information text
        info_text = []
        for interaction in interactions:
            info_text.append(f"\nPerson {interaction['person_id'] + 1}:")
            if interaction["right_hand_touching"]:
                info_text.append(f"Right hand is touching {interaction['right_hand_touching_object']}")
            if interaction["left_hand_touching"]:
                info_text.append(f"Left hand is touching {interaction['left_hand_touching_object']}")
            info_text.append(f"Right hand distance to nearest object: {interaction['right_hand_distance']:.2f}")
            info_text.append(f"Left hand distance to nearest object: {interaction['left_hand_distance']:.2f}")

        # Color the segmentation map for display
        mask = np.zeros((*segmentation_map.shape, 3), dtype=np.uint8)
        colors = np.random.randint(0, 255, size=(100, 3))
        for cl_id in np.unique(segmentation_map):
            mask_array = np.array(segmentation_map == cl_id)
            color = colors[cl_id % len(colors)]
            mask[mask_array] = color

        return result_image, mask, depth_map, "\n".join(info_text)
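
# Programmatic usage (a minimal sketch, independent of the Gradio UI below;
# the path refers to one of the bundled example images):
#
#   detector = InteractionDetector()
#   image = Image.open("images/276757975.jpg")
#   vis, seg_vis, depth, info = detector.process_image(image)
#   print(info)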


def create_gradio_interface():
    """Create the Gradio interface."""
    detector = InteractionDetector()

    with gr.Blocks() as interface:
        gr.Markdown("# Object Interaction Detection")
        gr.Markdown("Upload an image to detect when people are touching objects.")

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input Image")
                process_button = gr.Button("Detect Interactions")
            with gr.Column():
                output_image = gr.Image(label="Detection Results")
                interaction_info = gr.Textbox(
                    label="Interaction Information",
                    lines=10,
                    placeholder="Interaction details will appear here...",
                )
                segmentation_im = gr.Image(label="Segmentation Results")
                depth_im = gr.Image(label="Depth Results")

        process_button.click(
            fn=detector.process_image,
            inputs=input_image,
            outputs=[output_image, segmentation_im, depth_im, interaction_info],
        )

        gr.Examples(
            examples=[
                "images/1-8ea4418f.jpg",
                "images/276757975.jpg",
            ],
            inputs=input_image,
        )
    return interface


interface = create_gradio_interface()

if __name__ == "__main__":
    interface.launch(debug=True)
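
# To run locally (assuming this file is saved as the Space's app.py):
#     python app.py
# then open the local URL that Gradio prints.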