Dref360's picture
use local images
1415779
raw
history blame
10.8 kB
import cv2
import gradio as gr
import numpy as np
import supervision as sv
import torch
from PIL import Image
from transformers import (
RTDetrForObjectDetection,
RTDetrImageProcessor,
VitPoseForPoseEstimation,
VitPoseImageProcessor,
pipeline,
)
# Keypoint index -> joint name for the 17 body keypoints (COCO-17 ordering).
# The index is the position in the list below.
KEYPOINT_LABEL_MAP = dict(
    enumerate(
        [
            "Nose",
            "L_Eye",
            "R_Eye",
            "L_Ear",
            "R_Ear",
            "L_Shoulder",
            "R_Shoulder",
            "L_Elbow",
            "R_Elbow",
            "L_Wrist",
            "R_Wrist",
            "L_Hip",
            "R_Hip",
            "L_Knee",
            "R_Knee",
            "L_Ankle",
            "R_Ankle",
        ]
    )
)
class InteractionDetector:
    """Detect people whose hands are close to (touching) nearby objects.

    The pipeline combines four Hugging Face models loaded in ``load_models``:
    RT-DETR for person boxes, ViTPose for body keypoints, Depth Anything V2 for
    a depth map, and MaskFormer (ADE20K) for a per-pixel semantic class map.

    A hand counts as "touching" when some non-person pixel within a small
    image-space window around the wrist keypoint has a depth value within
    ``interaction_threshold`` of the wrist's own depth value.
    """

    # Keypoint indices of the wrists (see KEYPOINT_LABEL_MAP).
    LEFT_WRIST = 9
    RIGHT_WRIST = 10
    # Label id the RT-DETR COCO head uses for "person".
    PERSON_DETECTION_LABEL = 0
    # Fallback segmentation class id for "person" if the label lookup fails.
    DEFAULT_SEGMENTATION_PERSON_ID = 12

    def __init__(self):
        self.person_detector = None
        self.person_processor = None
        self.pose_model = None
        self.pose_processor = None
        self.depth_model = None
        self.segmentation_model = None
        self.segmentation_id2label = {}
        self.segmentation_label2id = {}
        # Maximum depth-value difference that counts as contact. Expressed in
        # the depth map's own units (presumably 0-255 for the 8-bit depth
        # image the pipeline returns -- TODO confirm).
        self.interaction_threshold = 2
        self.load_models()

    def load_models(self):
        """Load all required models and the segmentation label mappings."""
        # Person detection model
        self.person_processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd_coco_o365")
        self.person_detector = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd_coco_o365")
        # Pose estimation model
        self.pose_processor = VitPoseImageProcessor.from_pretrained("nielsr/vitpose-base-simple")
        self.pose_model = VitPoseForPoseEstimation.from_pretrained("nielsr/vitpose-base-simple")
        # Depth estimation model
        self.depth_model = pipeline("depth-estimation", model="depth-anything/Depth-Anything-V2-Small-hf")
        # Semantic segmentation model
        self.segmentation_model = pipeline("image-segmentation", model="facebook/maskformer-swin-base-ade")
        self.segmentation_id2label = self.segmentation_model.model.config.id2label
        self.segmentation_label2id = {v: k for k, v in self.segmentation_id2label.items()}

    def get_nearest_pixel_class(self, joint, depth_map, segmentation_map, max_pixel_distance=50):
        """Find the non-person pixel whose depth is closest to the joint's depth.

        Only pixels within ``max_pixel_distance`` (Manhattan distance, in
        pixels) of the joint are considered. Pixels segmented as "person" are
        ignored so that a hand does not match its own body.

        Args:
            joint: ``(x, y)`` pixel coordinates of the joint. Coordinates
                outside the image are clamped to the nearest border pixel.
            depth_map: 2-D depth map with the same height/width as
                ``segmentation_map``.
            segmentation_map: 2-D array of semantic class ids.
            max_pixel_distance: Search radius around the joint, in pixels.

        Returns:
            tuple: ``(class_id, depth_distance)`` of the best match, as
            ``(int, float)``. If no eligible pixel lies in the window,
            ``depth_distance`` is ``inf``, so it never counts as touching.
        """
        height, width = depth_map.shape[:2]
        # Keypoints can land outside the frame. Clamp them so indexing neither
        # raises nor silently wraps around through negative indices.
        x = int(np.clip(joint[0], 0, width - 1))
        y = int(np.clip(joint[1], 0, height - 1))
        radius = int(max_pixel_distance)

        # Only the square window around the joint can satisfy the Manhattan
        # constraint, so there is no need to build full-image coordinate grids.
        top, bottom = max(y - radius, 0), min(y + radius + 1, height)
        left, right = max(x - radius, 0), min(x + radius + 1, width)
        rows, cols = np.ogrid[top:bottom, left:right]
        manhattan = np.abs(rows - y) + np.abs(cols - x)

        # Subtract in float. With unsigned 8-bit depth values, 99 - 100 would
        # wrap to 255 instead of giving a difference of 1.
        depth_window = np.asarray(depth_map[top:bottom, left:right], dtype=np.float64)
        depth_dist = np.abs(depth_window - float(depth_map[y, x]))

        person_id = self.segmentation_label2id.get("person", self.DEFAULT_SEGMENTATION_PERSON_ID)
        seg_window = segmentation_map[top:bottom, left:right]
        # inf (not a finite sentinel) so excluded pixels can never tie with,
        # or beat, a real depth difference.
        depth_dist[(seg_window == person_id) | (manhattan > radius)] = np.inf

        row, col = np.unravel_index(np.argmin(depth_dist), depth_dist.shape)
        return int(seg_window[row, col]), float(depth_dist[row, col])

    def detect_persons(self, image: "Image.Image", threshold=0.3):
        """Detect people in ``image``.

        Args:
            image: Input PIL image.
            threshold: Minimum detection score to keep a box.

        Returns:
            tuple: ``(boxes, scores)`` as NumPy arrays. ``boxes`` holds
            absolute corner coordinates ``(x_min, y_min, x_max, y_max)``, as
            produced by ``post_process_object_detection``.
        """
        inputs = self.person_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = self.person_detector(**inputs)
        result = self.person_processor.post_process_object_detection(
            outputs,
            target_sizes=torch.tensor([(image.height, image.width)]),
            threshold=threshold,
        )[0]
        is_person = result["labels"] == self.PERSON_DETECTION_LABEL
        return result["boxes"][is_person].cpu().numpy(), result["scores"][is_person].cpu().numpy()

    @staticmethod
    def _xyxy_to_xywh(boxes):
        """Return a copy of corner-format boxes converted to COCO ``(x, y, w, h)``."""
        coco = np.array(boxes, dtype=np.float32, copy=True).reshape(-1, 4)
        coco[:, 2] -= coco[:, 0]
        coco[:, 3] -= coco[:, 1]
        return coco

    def detect_keypoints(self, image: "Image.Image"):
        """Detect people and estimate their body keypoints.

        Returns:
            tuple: ``(pose_results, boxes, scores)``. ``pose_results`` holds
            one dict per detected person, or is an empty list when nobody is
            detected. ``boxes`` and ``scores`` come from ``detect_persons``
            (corner format).
        """
        boxes, scores = self.detect_persons(image)
        if len(boxes) == 0:
            # The pose processor/model cannot run on an empty box list.
            return [], boxes, scores
        # ViTPose expects COCO-format boxes (x, y, width, height). The detector
        # returns corner coordinates, so convert before cropping.
        coco_boxes = self._xyxy_to_xywh(boxes)
        inputs = self.pose_processor(image, boxes=[coco_boxes], return_tensors="pt")
        with torch.no_grad():
            outputs = self.pose_model(inputs.pixel_values)
        pose_results = self.pose_processor.post_process_pose_estimation(outputs, boxes=[coco_boxes])[0]
        return pose_results, boxes, scores

    def estimate_depth(self, image: "Image.Image"):
        """Return the depth-estimation pipeline's ``"depth"`` output as a NumPy array."""
        with torch.no_grad():
            depth_map = np.array(self.depth_model(image)["depth"])
        return depth_map

    def segment_image(self, image: "Image.Image"):
        """Merge the segmentation pipeline's masks into one 2-D class-id map.

        Masks are painted from largest to smallest, so smaller segments end
        up on top where masks overlap.

        NOTE(review): pixels covered by no mask keep id 0, which is itself a
        real class id in the model's id2label (likely "wall" for ADE20K).
        Confirm this is intended.
        """
        with torch.no_grad():
            segments = self.segmentation_model(image)
        result = np.zeros((image.height, image.width), dtype=np.uint8)
        print("Found", [segment["label"] for segment in segments])
        for segment in sorted(segments, key=lambda s: np.sum(np.asarray(s["mask"])), reverse=True):
            result[np.asarray(segment["mask"]) > 0] = self.segmentation_label2id[segment["label"]]
        return result

    def detect_wall_interaction(self, image: "Image.Image"):
        """Decide, for each detected person, whether each wrist touches something.

        Returns:
            tuple: ``(interactions, pose_results, segmentation_map, depth_map)``.
            ``interactions`` holds one dict per person with the touched class
            name, a touching flag and the depth distance for each hand.
        """
        pose_results, _boxes, _scores = self.detect_keypoints(image)
        depth_map = self.estimate_depth(image)
        segmentation_map = self.segment_image(image)

        interactions = []
        for person_idx, pose_result in enumerate(pose_results):
            keypoints = pose_result["keypoints"].cpu().numpy()
            right_hand = keypoints[self.RIGHT_WRIST][:2].astype(int)
            left_hand = keypoints[self.LEFT_WRIST][:2].astype(int)

            right_cls, r_distance = self.get_nearest_pixel_class(right_hand, depth_map, segmentation_map)
            left_cls, l_distance = self.get_nearest_pixel_class(left_hand, depth_map, segmentation_map)

            interactions.append({
                "person_id": person_idx,
                "right_hand_touching_object": self.segmentation_id2label.get(right_cls, "unknown"),
                "left_hand_touching_object": self.segmentation_id2label.get(left_cls, "unknown"),
                "right_hand_touching": r_distance < self.interaction_threshold,
                "left_hand_touching": l_distance < self.interaction_threshold,
                "right_hand_distance": r_distance,
                "left_hand_distance": l_distance,
            })
        return interactions, pose_results, segmentation_map, depth_map

    def visualize_results(self, image: "Image.Image", interactions, pose_results):
        """Draw pose skeletons and hand-contact markers on a copy of ``image``.

        Args:
            image: RGB input image.
            interactions: Per-person dicts from ``detect_wall_interaction``.
            pose_results: Per-person pose dicts with a ``"keypoints"`` tensor.

        Returns:
            PIL.Image.Image: The annotated image.
        """
        vis_image = np.array(image)
        if len(pose_results) == 0:
            # Nothing to draw; stacking an empty keypoint list would raise.
            return Image.fromarray(vis_image)

        edge_annotator = sv.EdgeAnnotator(color=sv.Color.GREEN, thickness=2)
        key_points = sv.KeyPoints(
            xy=torch.stack([pose_result["keypoints"] for pose_result in pose_results]).cpu().numpy()
        )
        vis_image = edge_annotator.annotate(scene=vis_image, key_points=key_points)

        wrist_flags = ((self.RIGHT_WRIST, "right_hand_touching"), (self.LEFT_WRIST, "left_hand_touching"))
        for interaction in interactions:
            keypoints = pose_results[interaction["person_id"]]["keypoints"]
            for wrist_idx, flag in wrist_flags:
                if interaction[flag]:
                    center = tuple(int(v) for v in keypoints[wrist_idx][:2])
                    # The array is RGB (from PIL), so (0, 0, 255) renders as blue.
                    cv2.circle(vis_image, center, 10, (0, 0, 255), -1)
        return Image.fromarray(vis_image)

    def process_image(self, input_image):
        """Run the full pipeline for the Gradio UI.

        Args:
            input_image: ``numpy.ndarray`` or PIL image, or ``None``.

        Returns:
            tuple: ``(annotated_image, segmentation_colors, depth_map, info_text)``,
            one value per Gradio output component. The three image slots are
            ``None`` when no input is given.
        """
        if input_image is None:
            return None, None, None, ""

        if isinstance(input_image, np.ndarray):
            image = Image.fromarray(input_image)
        else:
            image = input_image
        # Convert to RGB so grayscale/RGBA inputs yield the 3-channel arrays
        # the drawing code expects.
        # NOTE(review): the fixed 1280x720 resize ignores aspect ratio; confirm.
        image = image.convert("RGB").resize((1280, 720))

        interactions, pose_results, segmentation_map, depth_map = self.detect_wall_interaction(image)
        result_image = self.visualize_results(image, interactions, pose_results)

        info_text = []
        for interaction in interactions:
            info_text.append(f"\nPerson {interaction['person_id'] + 1}:")
            if interaction["right_hand_touching"]:
                info_text.append(f"Right hand is touching {interaction['right_hand_touching_object']}")
            if interaction["left_hand_touching"]:
                info_text.append(f"Left hand is touching {interaction['left_hand_touching_object']}")
            # The distance is to the nearest non-person pixel of any class,
            # not specifically a wall.
            info_text.append(f"Right hand distance to nearest object: {interaction['right_hand_distance']:.2f}")
            info_text.append(f"Left hand distance to nearest object: {interaction['left_hand_distance']:.2f}")
        if not interactions:
            info_text.append("No people detected.")

        # Fixed-seed palette: colors stay stable across calls, and every
        # possible uint8 class id has its own entry (no modulo collisions).
        palette = np.random.default_rng(0).integers(0, 256, size=(256, 3), dtype=np.uint8)
        mask = palette[segmentation_map]
        return result_image, mask, depth_map, "\n".join(info_text)
def create_gradio_interface():
    """Build the Gradio Blocks UI for the interaction detector.

    Loads an ``InteractionDetector`` (and therefore all its models) once, then
    wires the button to ``detector.process_image``. Its four outputs map to
    the annotated image, the segmentation map, the depth map and the info
    text, in that order.

    Returns:
        gr.Blocks: The assembled (not yet launched) interface.
    """
    detector = InteractionDetector()
    with gr.Blocks() as interface:
        gr.Markdown("# Object Interaction Detection")
        gr.Markdown("Upload an image to detect when people are touching objects.")
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input Image")
                process_button = gr.Button("Detect Interactions")
            with gr.Column():
                output_image = gr.Image(label="Detection Results")
                interaction_info = gr.Textbox(
                    label="Interaction Information",
                    lines=10,
                    placeholder="Interaction details will appear here...",
                )
                segmentation_im = gr.Image(label="Segmentation Results")
                depth_im = gr.Image(label="Depth Results")
        # The order must match the tuple returned by process_image.
        process_button.click(
            fn=detector.process_image,
            inputs=input_image,
            outputs=[output_image, segmentation_im, depth_im, interaction_info],
        )
        # Example paths are relative to the working directory the app runs from.
        gr.Examples(
            examples=[
                "images/1-8ea4418f.jpg",
                "images/276757975.jpg",
            ],
            inputs=input_image,
        )
    return interface
# Module-level app object, built as soon as this file is imported.
interface = create_gradio_interface()

if __name__ == "__main__":
    # Start the server (with Gradio debug mode) only when run as a script.
    interface.launch(debug=True)