File size: 3,421 Bytes
d04cd0a
 
 
 
 
 
 
 
 
 
b05da2d
 
 
d04cd0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b05da2d
d04cd0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
607956f
 
d04cd0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import torch
import numpy as np
from PIL import Image, ImageDraw
from torchvision import transforms
from config import SAPIENS_LITE_MODELS_PATH

def load_model(task, version):
    """Load a Sapiens-Lite TorchScript model and move it to the best available device.

    Args:
        task: first-level key into ``SAPIENS_LITE_MODELS_PATH`` (e.g. 'pose').
        version: second-level key into ``SAPIENS_LITE_MODELS_PATH[task]``
            (e.g. 'sapiens_1b').

    Returns:
        ``(model, device)`` with the model in eval mode on ``device``, or
        ``(None, None)`` (after printing an error) if task/version is unknown.
    """
    # Keep the try minimal: only the lookup is expected to raise KeyError.
    try:
        model_path = SAPIENS_LITE_MODELS_PATH[task][version]
    except KeyError as e:
        print(f"Error: Tarea o versión inválida. {e}")
        return None, None

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if device.type == "cuda" and torch.cuda.get_device_properties(0).major >= 8:
        # Compute capability 8.x+ (Ampere and newer) supports TF32 matmul/conv,
        # trading a little precision for much faster inference.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
    # map_location remaps storages at load time, so a model serialized with
    # CUDA tensors can still be loaded on a CPU-only machine.
    model = torch.jit.load(model_path, map_location=device)
    model.eval()
    model.to(device)
    return model, device

def preprocess_image(image, input_shape):
    """Resize and normalise a PIL image into a batched float tensor for the model.

    Args:
        image: PIL image in any mode; non-RGB modes are converted to RGB first.
        input_shape: (channels, height, width) of the model input.

    Returns:
        float32 tensor of shape (1, 3, height, width), channels in RGB order,
        normalised per channel with the mean/std below.
    """
    # Modes such as L, P or RGBA would otherwise yield a 2-D or 4-channel array
    # and break the CHW transpose / 3-channel normalisation below.
    if image.mode != 'RGB':
        image = image.convert('RGB')
    resized = image.resize((input_shape[2], input_shape[1]))  # PIL size is (width, height)
    img = torch.from_numpy(np.array(resized).transpose(2, 0, 1)).float()  # HWC -> CHW
    # No channel flip: PIL already yields RGB, the same order as these statistics
    # (255 x ImageNet RGB mean 0.485/0.456/0.406 and std 0.229/0.224/0.225).
    # Flipping to BGR here would subtract the red mean from the blue channel.
    mean = torch.tensor([123.675, 116.28, 103.53]).view(3, 1, 1)
    std = torch.tensor([58.395, 57.12, 57.375]).view(3, 1, 1)
    return ((img - mean) / std).unsqueeze(0)

def udp_decode(heatmap, img_size, heatmap_size):
    """Decode per-keypoint heatmaps into model-input pixel coordinates.

    Simplified UDP (Unbiased Data Processing) decoding: each keypoint is placed
    at its heatmap's argmax (no sub-pixel refinement), and grid indices are
    mapped so that 0..dim-1 span the full input extent, i.e.
    ``coord = idx / (heatmap_dim - 1) * input_dim``.

    Args:
        heatmap: array of shape (K, H_hm, W_hm), one 2-D heatmap per keypoint.
        img_size: (height, width) of the model input.
        heatmap_size: (height, width) of the heatmap grid used for scaling.

    Returns:
        keypoints: float64 array (K, 2) of (x, y) in ``img_size`` coordinates.
        keypoint_scores: float64 array (K,) holding each heatmap's peak value.
    """
    heatmap = np.asarray(heatmap)
    num_kpts = heatmap.shape[0]
    if num_kpts == 0:
        # reshape(0, -1) is ambiguous, so handle the empty case explicitly.
        return np.zeros((0, 2)), np.zeros(0)

    h, w = heatmap_size
    flat = heatmap.reshape(num_kpts, -1)
    peak = flat.argmax(axis=1)  # first maximum per keypoint, as np.argmax does
    ys, xs = np.unravel_index(peak, heatmap.shape[1:])
    keypoint_scores = flat[np.arange(num_kpts), peak].astype(np.float64)

    # UDP scaling; max(..., 1) guards against a degenerate 1-pixel grid.
    scale_x = img_size[1] / max(w - 1, 1)
    scale_y = img_size[0] / max(h - 1, 1)
    keypoints = np.stack([xs * scale_x, ys * scale_y], axis=1).astype(np.float64)
    return keypoints, keypoint_scores

def visualize_keypoints(image, keypoints, keypoint_scores, threshold=0.3):
    """Mark every sufficiently confident keypoint with a small red dot.

    The drawing is done in place on *image*, which is also returned.

    Args:
        image: PIL image to draw on.
        keypoints: iterable of (x, y) pixel coordinates.
        keypoint_scores: confidence per keypoint, aligned with *keypoints*.
        threshold: keypoints whose score is not strictly above this are skipped.

    Returns:
        The same *image* object that was passed in.
    """
    canvas = ImageDraw.Draw(image)
    radius = 2
    confident = (
        (px, py)
        for (px, py), s in zip(keypoints, keypoint_scores)
        if s > threshold
    )
    for px, py in confident:
        bounding_box = [(px - radius, py - radius), (px + radius, py + radius)]
        canvas.ellipse(bounding_box, fill='red', outline='red')
    return image

# Most recently loaded (model, device) pair, keyed by (task, version). Video
# input calls process_image_or_video once per frame, so without this cache the
# TorchScript model would be re-read from disk and re-initialised every frame.
_MODEL_CACHE = {}


def _get_cached_model(task, version):
    """Return ``(model, device)`` for *task*/*version*, loading it only once.

    Only successful loads are cached, so an invalid task/version still goes
    through ``load_model`` (and its error message) on every call. At most one
    model is kept so switching versions does not accumulate models in memory.
    """
    key = (task, version)
    if key not in _MODEL_CACHE:
        model, device = load_model(task, version)
        if model is None or device is None:
            return None, None
        _MODEL_CACHE.clear()
        _MODEL_CACHE[key] = (model, device)
    return _MODEL_CACHE[key]


def process_image_or_video(input_data, task='pose', version='sapiens_1b'):
    """Run keypoint inference on one image or video frame and draw the result.

    Args:
        input_data: a PIL image, or a numpy array accepted by
            ``Image.fromarray`` (e.g. a single video frame).
        task: model task key forwarded to ``load_model``.
        version: model version key forwarded to ``load_model``.

    Returns:
        A new RGB PIL image with confident keypoints drawn as red dots (the
        caller's image is left untouched), or None if the input type is
        unsupported or the model could not be loaded.
    """
    # Validate the input before paying for a model load.
    if not isinstance(input_data, (np.ndarray, Image.Image)):
        print("Tipo de entrada no soportado. Por favor, proporcione una imagen PIL o un frame de video numpy.")
        return None

    model, device = _get_cached_model(task, version)
    if model is None or device is None:
        return None

    input_shape = (3, 1024, 768)  # (channels, height, width) of the model input
    # Assumes the model emits heatmaps at 1/4 of the input resolution.
    heatmap_size = (input_shape[1] // 4, input_shape[2] // 4)

    frame = Image.fromarray(input_data) if isinstance(input_data, np.ndarray) else input_data
    # convert() always returns a new image: it normalises any mode (L, P, RGBA,
    # ...) to 3-channel RGB and guarantees drawing never touches the caller's image.
    frame = frame.convert('RGB')

    img = preprocess_image(frame, input_shape)
    with torch.no_grad():
        heatmap = model(img.to(device))

    keypoints, keypoint_scores = udp_decode(heatmap[0].cpu().float().numpy(),
                                            input_shape[1:],
                                            heatmap_size)

    # Map from model-input coordinates back to the original frame size.
    keypoints[:, 0] *= frame.width / input_shape[2]
    keypoints[:, 1] *= frame.height / input_shape[1]

    return visualize_keypoints(frame, keypoints, keypoint_scores)