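"""Gradio demo that flags hand-object contact in an image by chaining four models:
RT-DETR person detection, ViTPose keypoint estimation, Depth Anything V2 depth
estimation, and MaskFormer (ADE20K) semantic segmentation."""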
import cv2
import gradio as gr
import numpy as np
import supervision as sv
import torch
from PIL import Image
from transformers import (
RTDetrForObjectDetection,
RTDetrImageProcessor,
VitPoseForPoseEstimation,
VitPoseImageProcessor,
pipeline,
)
KEYPOINT_LABEL_MAP = {
0: "Nose",
1: "L_Eye",
2: "R_Eye",
3: "L_Ear",
4: "R_Ear",
5: "L_Shoulder",
6: "R_Shoulder",
7: "L_Elbow",
8: "R_Elbow",
9: "L_Wrist",
10: "R_Wrist",
11: "L_Hip",
12: "R_Hip",
13: "L_Knee",
14: "R_Knee",
15: "L_Ankle",
16: "R_Ankle",
}
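# Reference table for the 17-keypoint COCO ordering that ViTPose outputs; the
# hard-coded wrist indices (9 and 10) used in detect_wall_interaction come from
# this ordering.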
class InteractionDetector:
def __init__(self):
self.person_detector = None
self.person_processor = None
self.pose_model = None
self.pose_processor = None
self.depth_model = None
self.segmentation_model = None
        self.interaction_threshold = 2  # max depth difference (in 0-255 depth-map units) to count as "touching"
self.load_models()
def load_models(self):
"""Load all required models"""
# Person detection model
self.person_processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd_coco_o365")
self.person_detector = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd_coco_o365")
# Pose estimation model
self.pose_processor = VitPoseImageProcessor.from_pretrained("nielsr/vitpose-base-simple")
self.pose_model = VitPoseForPoseEstimation.from_pretrained("nielsr/vitpose-base-simple")
# Depth estimation model
self.depth_model = pipeline("depth-estimation", model="depth-anything/Depth-Anything-V2-Small-hf")
# Semantic segmentation model
self.segmentation_model = pipeline("image-segmentation", model="facebook/maskformer-swin-base-ade")
self.segmentation_id2label = self.segmentation_model.model.config.id2label
self.segmentation_label2id = {v: k for k, v in self.segmentation_model.model.config.id2label.items()}
def get_nearest_pixel_class(self, joint, depth_map, segmentation_map):
"""
Find the nearest pixel of a specific class to a given joint coordinate
Args:
joint: (x, y) coordinates of the joint
depth_map: Depth map
segmentation_map: Semantic segmentation results
Returns:
tuple: class_name of nearest pixel, distance to that pixel
"""
        PERSON_ID = 12  # "person" in the ADE20K label set used by the segmentation model
        # Manhattan distance of every pixel to the joint (rows index y, columns index x)
        rows, cols = np.meshgrid(
            np.arange(depth_map.shape[0]), np.arange(depth_map.shape[1]), indexing="ij"
        )
        dist_coord = np.abs(rows - joint[1]) + np.abs(cols - joint[0])
        # Depth difference to the joint; cast to a signed type so the uint8 depth
        # image does not wrap around on subtraction
        depth_dist = np.abs(depth_map.astype(np.int16) - int(depth_map[joint[1], joint[0]]))
        # Ignore person pixels and anything farther than 50 pixels in the image plane
        depth_dist[(segmentation_map == PERSON_ID) | (dist_coord > 50)] = 255
        min_dist = np.unravel_index(np.argmin(depth_dist), depth_dist.shape)
        return segmentation_map[min_dist], depth_dist[min_dist]
def detect_persons(self, image: Image.Image):
"""Detect persons in the image"""
inputs = self.person_processor(images=image, return_tensors="pt")
with torch.no_grad():
outputs = self.person_detector(**inputs)
results = self.person_processor.post_process_object_detection(
outputs,
target_sizes=torch.tensor([(image.height, image.width)]),
threshold=0.3
)
        # Keep only "person" detections (class index 0 for this RT-DETR checkpoint)
        boxes = results[0]["boxes"][results[0]["labels"] == 0]
        scores = results[0]["scores"][results[0]["labels"] == 0]
return boxes.cpu().numpy(), scores.cpu().numpy()
def detect_keypoints(self, image: Image.Image):
"""Detect keypoints in the image"""
boxes, scores = self.detect_persons(image)
        # Convert detector boxes from (x1, y1, x2, y2) to COCO (x, y, w, h), which VitPoseImageProcessor expects
        coco_boxes = boxes.copy()
        coco_boxes[:, 2] -= coco_boxes[:, 0]
        coco_boxes[:, 3] -= coco_boxes[:, 1]
        pixel_values = self.pose_processor(image, boxes=[coco_boxes], return_tensors="pt").pixel_values
        with torch.no_grad():
            outputs = self.pose_model(pixel_values)
        pose_results = self.pose_processor.post_process_pose_estimation(outputs, boxes=[coco_boxes])[0]
return pose_results, boxes, scores
def estimate_depth(self, image: Image.Image):
"""Estimate depth for the image"""
        with torch.no_grad():
            # The pipeline's "depth" output is a PIL image normalized to the 0-255 range
            depth_map = np.array(self.depth_model(image)["depth"])
return depth_map
def segment_image(self, image: Image.Image):
"""Perform semantic segmentation on the image"""
        with torch.no_grad():
            segments = self.segmentation_model(image)
        result = np.zeros(np.array(image).shape[:2], dtype=np.uint8)
        print("Found", [segment["label"] for segment in segments])
        # Paint larger masks first so smaller (foreground) regions win any overlaps
        for segment in sorted(segments, key=lambda s: np.sum(s["mask"]), reverse=True):
            result[np.array(segment["mask"]) > 0] = self.segmentation_label2id[segment["label"]]
        return result
def detect_wall_interaction(self, image: Image.Image):
"""Detect if hands are touching walls"""
# Get all necessary information
pose_results, boxes, scores = self.detect_keypoints(image)
depth_map = self.estimate_depth(image)
segmentation_map = self.segment_image(image)
interactions = []
for person_idx, pose_result in enumerate(pose_results):
            # Wrist keypoints stand in for the hands (COCO indices 10 = R_Wrist, 9 = L_Wrist)
            right_hand = pose_result["keypoints"][10].numpy().astype(int)
            left_hand = pose_result["keypoints"][9].numpy().astype(int)
            # Find the closest non-person pixel (in depth) around each wrist
            right_cls, r_distance = self.get_nearest_pixel_class(right_hand[:2], depth_map, segmentation_map)
            left_cls, l_distance = self.get_nearest_pixel_class(left_hand[:2], depth_map, segmentation_map)
            # A hand counts as touching when the depth gap is below the threshold
            right_touching = r_distance < self.interaction_threshold
            left_touching = l_distance < self.interaction_threshold
interactions.append({
"person_id": person_idx,
"right_hand_touching_object": self.segmentation_id2label[right_cls],
"left_hand_touching_object": self.segmentation_id2label[left_cls],
"right_hand_touching": right_touching,
"left_hand_touching": left_touching,
"right_hand_distance": r_distance,
"left_hand_distance": l_distance
})
return interactions, pose_results, segmentation_map, depth_map
def visualize_results(self, image: Image.Image, interactions, pose_results):
"""Visualize detection results"""
# Create base visualization from original image
vis_image = np.array(image).copy()
# Add pose keypoints
edge_annotator = sv.EdgeAnnotator(color=sv.Color.GREEN, thickness=2)
key_points = sv.KeyPoints(
xy=torch.cat([pose_result['keypoints'].unsqueeze(0) for pose_result in pose_results]).cpu().numpy()
)
vis_image = edge_annotator.annotate(scene=vis_image, key_points=key_points)
# Add interaction indicators
for interaction in interactions:
person_id = interaction["person_id"]
pose_result = pose_results[person_id]
# Draw indicators for touching hands
if interaction["right_hand_touching"]:
cv2.circle(vis_image,
tuple(map(int, pose_result["keypoints"][10][:2])),
10, (0, 0, 255), -1)
if interaction["left_hand_touching"]:
cv2.circle(vis_image,
tuple(map(int, pose_result["keypoints"][9][:2])),
10, (0, 0, 255), -1)
return Image.fromarray(vis_image)
def process_image(self, input_image):
"""Process image and return visualization with interaction detection"""
if input_image is None:
return None, ""
# Convert to PIL Image if necessary
if isinstance(input_image, np.ndarray):
image = Image.fromarray(input_image)
else:
image = input_image
image = image.resize((1280, 720))
# Detect interactions
interactions, pose_results, segmentation_map, depth_map = self.detect_wall_interaction(image)
# Visualize results
result_image = self.visualize_results(image, interactions, pose_results)
# Create interaction information text
info_text = []
for interaction in interactions:
info_text.append(f"\nPerson {interaction['person_id'] + 1}:")
if interaction["right_hand_touching"]:
info_text.append(f"Right hand is touching {interaction['right_hand_touching_object']}")
if interaction["left_hand_touching"]:
info_text.append(f"Left hand is touching {interaction['left_hand_touching_object']}")
            info_text.append(f"Right hand depth distance to nearest object: {interaction['right_hand_distance']:.2f}")
            info_text.append(f"Left hand depth distance to nearest object: {interaction['left_hand_distance']:.2f}")
# Add color to segmentation
mask = np.zeros((*segmentation_map.shape, 3), dtype=np.uint8)
colors = np.random.randint(0, 255, size=(100, 3))
for cl_id in np.unique(segmentation_map):
mask_array = np.array(segmentation_map == cl_id)
color = colors[cl_id % len(colors)]
mask[mask_array] = color
return result_image, mask, depth_map, "\n".join(info_text)
def create_gradio_interface():
"""Create Gradio interface"""
detector = InteractionDetector()
with gr.Blocks() as interface:
gr.Markdown("# Object Interaction Detection")
gr.Markdown("Upload an image to detect when people are touching objects.")
with gr.Row():
with gr.Column():
input_image = gr.Image(label="Input Image")
process_button = gr.Button("Detect Interactions")
with gr.Column():
output_image = gr.Image(label="Detection Results")
interaction_info = gr.Textbox(
label="Interaction Information",
lines=10,
placeholder="Interaction details will appear here..."
)
                segmentation_im = gr.Image(label="Segmentation Results")
depth_im = gr.Image(label="Depth Results")
process_button.click(
fn=detector.process_image,
inputs=input_image,
outputs=[output_image, segmentation_im, depth_im, interaction_info]
)
gr.Examples(
examples=[
"images/1-8ea4418f.jpg",
"images/276757975.jpg"
],
inputs=input_image
)
return interface
interface = create_gradio_interface()
if __name__ == "__main__":
interface.launch(debug=True)
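# When executed directly (e.g. `python app.py`), this starts the Gradio server with
# debug logging enabled; importing the module only builds `interface` without launching it.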