joselobenitezg commited on
Commit
d04cd0a
·
1 Parent(s): a92daf2

working on pose

Browse files
Files changed (1) hide show
  1. inference/pose.py +168 -0
inference/pose.py CHANGED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import torch
2
+ # import numpy as np
3
+ # from PIL import Image
4
+ # from torchvision import transforms
5
+ # from config import LABELS_TO_IDS
6
+ # from utils.vis_utils import visualize_mask_with_overlay
7
+
8
+ # # Example usage
9
+ # TASK = 'pose'
10
+ # VERSION = 'sapiens_1b'
11
+
12
+ # model_path = get_model_path(TASK, VERSION)
13
+ # print(model_path)
14
+
15
+ # model = torch.jit.load(model_path)
16
+ # model.eval()
17
+ # model.to("cuda")
18
+
19
+ # def get_pose(image, pose_estimator, input_shape=(3, 1024, 768), device="cuda"):
20
+ # # Preprocess the image
21
+ # img = preprocess_image(image, input_shape)
22
+
23
+ # # Run the model
24
+ # with torch.no_grad():
25
+ # heatmap = pose_estimator(img.to(device))
26
+
27
+ # # Post-process the output
28
+ # keypoints, keypoint_scores = udp_decode(heatmap[0].cpu().float().numpy(),
29
+ # input_shape[1:],
30
+ # (input_shape[1] // 4, input_shape[2] // 4))
31
+
32
+ # # Scale keypoints to original image size
33
+ # scale_x = image.width / input_shape[2]
34
+ # scale_y = image.height / input_shape[1]
35
+ # keypoints[:, 0] *= scale_x
36
+ # keypoints[:, 1] *= scale_y
37
+
38
+ # # Visualize the keypoints on the original image
39
+ # pose_image = visualize_keypoints(image, keypoints, keypoint_scores)
40
+ # return pose_image
41
+
42
+ # def preprocess_image(image, input_shape):
43
+ # # Resize and normalize the image
44
+ # img = image.resize((input_shape[2], input_shape[1]))
45
+ # img = np.array(img).transpose(2, 0, 1)
46
+ # img = torch.from_numpy(img).float()
47
+ # img = img[[2, 1, 0], ...] # RGB to BGR
48
+ # mean = torch.tensor([123.675, 116.28, 103.53]).view(3, 1, 1)
49
+ # std = torch.tensor([58.395, 57.12, 57.375]).view(3, 1, 1)
50
+ # img = (img - mean) / std
51
+ # return img.unsqueeze(0)
52
+
53
+ # def udp_decode(heatmap, img_size, heatmap_size):
54
+ # # This is a simplified version. You might need to implement the full UDP decode logic
55
+ # h, w = heatmap_size
56
+ # keypoints = np.zeros((heatmap.shape[0], 2))
57
+ # keypoint_scores = np.zeros(heatmap.shape[0])
58
+
59
+ # for i in range(heatmap.shape[0]):
60
+ # hm = heatmap[i]
61
+ # idx = np.unravel_index(np.argmax(hm), hm.shape)
62
+ # keypoints[i] = [idx[1] * img_size[1] / w, idx[0] * img_size[0] / h]
63
+ # keypoint_scores[i] = hm[idx]
64
+
65
+ # return keypoints, keypoint_scores
66
+
67
+ # def visualize_keypoints(image, keypoints, keypoint_scores, threshold=0.3):
68
+ # draw = ImageDraw.Draw(image)
69
+ # for (x, y), score in zip(keypoints, keypoint_scores):
70
+ # if score > threshold:
71
+ # draw.ellipse([(x-2, y-2), (x+2, y+2)], fill='red', outline='red')
72
+ # return image
73
+
74
+ # from utils.vis_utils import resize_image
75
+ # pil_image = Image.open('/home/user/app/assets/image.webp')
76
+
77
+ # if pil_image.mode == 'RGBA':
78
+ # pil_image = pil_image.convert('RGB')
79
+
80
+ # output_pose = get_pose(resized_pil_image, model)
81
+
82
+ # output_pose
83
+ import torch
84
+ import numpy as np
85
+ from PIL import Image, ImageDraw
86
+ from torchvision import transforms
87
+ from config import SAPIENS_LITE_MODELS_PATH
88
+
89
+ def load_model(task, version):
90
+ try:
91
+ model_path = SAPIENS_LITE_MODELS_PATH[task][version]
92
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
93
+ model = torch.jit.load(model_path)
94
+ model.eval()
95
+ model.to(device)
96
+ return model, device
97
+ except KeyError as e:
98
+ print(f"Error: Tarea o versión inválida. {e}")
99
+ return None, None
100
+
101
+ def preprocess_image(image, input_shape):
102
+ img = image.resize((input_shape[2], input_shape[1]))
103
+ img = np.array(img).transpose(2, 0, 1)
104
+ img = torch.from_numpy(img).float()
105
+ img = img[[2, 1, 0], ...] # RGB to BGR
106
+ mean = torch.tensor([123.675, 116.28, 103.53]).view(3, 1, 1)
107
+ std = torch.tensor([58.395, 57.12, 57.375]).view(3, 1, 1)
108
+ img = (img - mean) / std
109
+ return img.unsqueeze(0)
110
+
111
+ def udp_decode(heatmap, img_size, heatmap_size):
112
+ h, w = heatmap_size
113
+ keypoints = np.zeros((heatmap.shape[0], 2))
114
+ keypoint_scores = np.zeros(heatmap.shape[0])
115
+
116
+ for i in range(heatmap.shape[0]):
117
+ hm = heatmap[i]
118
+ idx = np.unravel_index(np.argmax(hm), hm.shape)
119
+ keypoints[i] = [idx[1] * img_size[1] / w, idx[0] * img_size[0] / h]
120
+ keypoint_scores[i] = hm[idx]
121
+
122
+ return keypoints, keypoint_scores
123
+
124
+ def visualize_keypoints(image, keypoints, keypoint_scores, threshold=0.3):
125
+ draw = ImageDraw.Draw(image)
126
+ for (x, y), score in zip(keypoints, keypoint_scores):
127
+ if score > threshold:
128
+ draw.ellipse([(x-2, y-2), (x+2, y+2)], fill='red', outline='red')
129
+ return image
130
+
131
+ def process_image_or_video(input_data, task='pose', version='sapiens_1b'):
132
+ model, device = load_model(task, version)
133
+ if model is None or device is None:
134
+ return None
135
+
136
+ input_shape = (3, 1024, 768)
137
+
138
+ def process_frame(frame):
139
+ if isinstance(frame, np.ndarray):
140
+ frame = Image.fromarray(frame)
141
+
142
+ if frame.mode == 'RGBA':
143
+ frame = frame.convert('RGB')
144
+
145
+ img = preprocess_image(frame, input_shape)
146
+
147
+ with torch.no_grad():
148
+ heatmap = model(img.to(device))
149
+
150
+ keypoints, keypoint_scores = udp_decode(heatmap[0].cpu().float().numpy(),
151
+ input_shape[1:],
152
+ (input_shape[1] // 4, input_shape[2] // 4))
153
+
154
+ scale_x = frame.width / input_shape[2]
155
+ scale_y = frame.height / input_shape[1]
156
+ keypoints[:, 0] *= scale_x
157
+ keypoints[:, 1] *= scale_y
158
+
159
+ pose_image = visualize_keypoints(frame, keypoints, keypoint_scores)
160
+ return pose_image
161
+
162
+ if isinstance(input_data, np.ndarray): # Video frame
163
+ return process_frame(input_data)
164
+ elif isinstance(input_data, Image.Image): # Imagen
165
+ return process_frame(input_data)
166
+ else:
167
+ print("Tipo de entrada no soportado. Por favor, proporcione una imagen PIL o un frame de video numpy.")
168
+ return None